diff --git a/.bazelrc b/.bazelrc index 224238d7c0b..f2aa3ac447b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -143,6 +143,11 @@ build:mkl --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl --define=build_with_mkl_dnn_v1_only=true build:mkl -c opt +# config to build OneDNN backend with a user specified threadpool. +build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true +build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0 +build:mkl_threadpool --define=build_with_mkldnn_threadpool=true +build:mkl_threadpool -c opt # This config refers to building with CUDA available. It does not necessarily # mean that we build CUDA op kernels. build:using_cuda --define=using_cuda=true @@ -235,10 +240,15 @@ build:c++17 --cxxopt=-std=c++1z build:c++17 --cxxopt=-stdlib=libc++ build:c++1z --config=c++17 -# Enable using platform specific build settings +# Enable using platform specific build settings, except when cross-compiling for +# mobile platforms. build --enable_platform_specific_config +build:android --noenable_platform_specific_config +build:ios --noenable_platform_specific_config # Suppress C++ compiler warnings, otherwise build logs become 10s of MBs. +build:android --copt=-w +build:ios --copt=-w build:linux --copt=-w build:macos --copt=-w build:windows --copt=/w @@ -258,6 +268,10 @@ build:macos --define=INCLUDEDIR=$(PREFIX)/include # TF_SYSTEM_LIBS do not work on windows. # By default, build TF in C++ 14 mode. +build:android --cxxopt=-std=c++14 +build:android --host_cxxopt=-std=c++14 +build:ios --cxxopt=-std=c++14 +build:ios --host_cxxopt=-std=c++14 build:linux --cxxopt=-std=c++14 build:linux --host_cxxopt=-std=c++14 build:macos --cxxopt=-std=c++14 diff --git a/.github/bot_config.yml b/.github/bot_config.yml new file mode 100644 index 00000000000..88c737f41e2 --- /dev/null +++ b/.github/bot_config.yml @@ -0,0 +1,87 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# +# THIS IS A GENERATED DOCKERFILE. +# +# This file was assembled from multiple pieces, whose use is documented +# throughout. Please refer to the TensorFlow dockerfiles documentation +# for more information. + +# A list of assignees +assignees: + - amahendrakar + - ravikyram + - Saduf2019 +# A list of assignees for compiler folder +compiler_assignees: + - joker-eph +# Cuda Comment +cuda_comment: > + From the template it looks like you are installing **TensorFlow** (TF) prebuilt binaries: + * For TF-GPU - See point 1 + * For TF-CPU - See point 2 + ----------------------------------------------------------------------------------------------- + + **1. Installing **TensorFlow-GPU** (TF) prebuilt binaries** + + + Make sure you are using compatible TF and CUDA versions. + Please refer following TF version and CUDA version compatibility table. + + | TF | CUDA | + + | :-------------: | :-------------: | + + | 2.1.0 - 2.2.0 | 10.1 | + + | 1.13.1 - 2.0 | 10.0 | + + | 1.5.0 - 1.12.0 | 9.0 | + + * If you have above configuration and using _**Windows**_ platform - + * Try adding the CUDA, CUPTI, and cuDNN installation directories to the %PATH% environment variable. + * Refer [windows setup guide](https://www.tensorflow.org/install/gpu#windows_setup). + * If you have above configuration and using _**Ubuntu/Linux**_ platform - + * Try adding the CUDA, CUPTI, and cuDNN installation directories to the $LD_LIBRARY_PATH environment variable. + * Refer [linux setup guide](https://www.tensorflow.org/install/gpu#linux_setup). + * If error still persists then, apparently your CPU model does not support AVX instruction sets. + * Refer [hardware requirements](https://www.tensorflow.org/install/pip#hardware-requirements). + + ----------------------------------------------------------------------------------------------- + + **2. Installing **TensorFlow** (TF) CPU prebuilt binaries** + + + *TensorFlow release binaries version 1.6 and higher are prebuilt with AVX instruction sets.* + + + Therefore on any CPU that does not have these instruction sets, either CPU or GPU version of TF will fail to load. + + Apparently, your CPU model does not support AVX instruction sets. You can still use TensorFlow with the alternatives given below: + + * Try Google Colab to use TensorFlow. + * The easiest way to use TF will be to switch to [google colab](https://colab.sandbox.google.com/notebooks/welcome.ipynb#recent=true). You get pre-installed latest stable TF version. Also you can use ```pip install``` to install any other preferred TF version. + * It has an added advantage since you can you easily switch to different hardware accelerators (cpu, gpu, tpu) as per the task. + * All you need is a good internet connection and you are all set. + * Try to build TF from sources by changing CPU optimization flags. + + *Please let us know if this helps.* + +windows_comment: > + From the stack trace it looks like you are hitting windows path length limit. + * Try to disable path length limit on Windows 10. + * Refer [disable path length limit instructions guide.](https://mspoweruser.com/ntfs-260-character-windows-10/) + + Please let us know if this helps. diff --git a/README.md b/README.md index ba4597af14c..a76b1bfd0b7 100644 --- a/README.md +++ b/README.md @@ -142,6 +142,7 @@ Build Type | Status * [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2) * [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187) * [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190) +* [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp) * [TensorFlow Blog](https://blog.tensorflow.org) * [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml) * [TensorFlow Twitter](https://twitter.com/tensorflow) diff --git a/RELEASE.md b/RELEASE.md index 6c8921cf492..6f3aa94c203 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,41 @@ +# Release 2.3.0 + +## Breaking Changes + +* `tf.image.extract_glimpse` has been updated to correctly process the case + where `centered=False` and `normalized=False`. This is a breaking change as + the output is different from (incorrect) previous versions. Note this + breaking change only impacts `tf.image.extract_glimpse` and + `tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of + `tf.compat.v1.image.extract_glimpse` does not change. The behavior of + exsiting C++ kernel `ExtractGlimpse` does not change as well, so saved + models will not be impacted. + +# Release 2.1.1 + +## Bug Fixes and Other Changes +* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) +* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601) +* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960) +* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770) +* Fixes a versioning bug which causes Keras layers from TF 1.x to be used instead of those from TF 2.x + +# Release 2.0.2 + +## Bug Fixes and Other Changes +* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) +* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601) +* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960) +* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770) + +# Release 1.15.3 + +## Bug Fixes and Other Changes +* Updates `sqlite3` to `3.31.01` to handle [CVE-2019-19880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19880), [CVE-2019-19244](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19244) and [CVE-2019-19645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-19645) +* Updates `curl` to `7.69.1` to handle [CVE-2019-15601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-15601) +* Updates `libjpeg-turbo` to `2.0.4` to handle [CVE-2018-19664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-19664), [CVE-2018-20330](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-20330) and [CVE-2019-13960](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-13960) +* Updates Apache Spark to `2.4.5` to handle [CVE-2019-10099](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-10099), [CVE-2018-17190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-17190) and [CVE-2018-11770](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2018-11770) + # Release 2.2.0 TensorFlow 2.2 discontinues support for Python 2, [previously announced](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ) as following [Python 2's EOL on January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update). diff --git a/SECURITY.md b/SECURITY.md index 6fc2c3aa9cc..f3a6c148b2e 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -64,7 +64,7 @@ your model, and we recommend you run the TensorFlow process in a sandbox. It is possible to write models that are secure in a sense that they can safely process untrusted inputs assuming there are no bugs. There are two main reasons -to not rely on this: first, it is easy to write models which must not be exposed +to not rely on this: First, it is easy to write models which must not be exposed to untrusted inputs, and second, there are bugs in any software system of sufficient complexity. Letting users control inputs could allow them to trigger bugs either in TensorFlow or in dependent libraries. @@ -149,7 +149,7 @@ attack (or worse). Because TensorFlow behaves correctly, this is not a vulnerability in TensorFlow (although it would be a vulnerability of this hypothetical system). -As a general rule, it is incorrect behavior for Tensorflow to access memory it +As a general rule, it is incorrect behavior for TensorFlow to access memory it does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to such behaviors constitute a vulnerability. diff --git a/WORKSPACE b/WORKSPACE index 021ed6d2542..ea741c31c7f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -114,6 +114,14 @@ http_archive( ], ) +http_archive( + name = "person_detect_data", + sha256 = "170542270da256994ce24d1e357f6e84a54fdaf7d28ff2b74725a40b70b082cf", + urls = [ + "https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data_grayscale_2020_05_24.zip", + ], +) + # Required for dependency @com_github_grpc_grpc load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") diff --git a/configure.py b/configure.py index 945c3036a8d..0a5b87172c0 100644 --- a/configure.py +++ b/configure.py @@ -1368,8 +1368,13 @@ def main(): # environment variables. environ_cp = dict(os.environ) - current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION, - _TF_MAX_BAZEL_VERSION) + try: + current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION, + _TF_MAX_BAZEL_VERSION) + except subprocess.CalledProcessError as e: + print("Error checking bazel version: ", e.output.decode('UTF-8').strip()) + raise e + _TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version) reset_tf_configure_bazelrc() @@ -1387,7 +1392,6 @@ def main(): # Windows. environ_cp['TF_DOWNLOAD_CLANG'] = '0' environ_cp['TF_NEED_MPI'] = '0' - environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0' if is_macos(): environ_cp['TF_NEED_TENSORRT'] = '0' diff --git a/tensorflow/BUILD b/tensorflow/BUILD index ab4316d5ed0..efbdf89ecea 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -524,7 +524,10 @@ package_group( ], ) -package_group(name = "ndarray_tensor_allow_list") +package_group( + name = "ndarray_tensor_allow_list", + packages = ["//learning/pathways/..."], +) # Packages that use composite tensors or dispatch. # TODO(b/154762408) Remove this package group once it's no longer needed. diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 05d5f9a3ed2..12021a294e8 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -216,6 +216,7 @@ tf_cuda_library( ], visibility = [ "//tensorflow/c:__subpackages__", + "//tensorflow/compiler/mlir/tensorflow/c:__subpackages__", ], deps = select({ "//tensorflow:android": [ @@ -394,8 +395,14 @@ tf_cuda_library( deps = [ ":tf_status", ":tf_status_internal", - "//tensorflow/core:lib", - ], + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs + ], + "//conditions:default": [ + "//tensorflow/core:lib", + ], + }), ) tf_cc_test( diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index e623f30b98c..e9e6d470c68 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -325,205 +325,6 @@ TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) { return ret; } -TFE_Context* TFE_CreateContextFromSession(TF_Session* session, - TF_Status* status) { - auto* opts = TFE_NewContextOptions(); - - // Reduce GPU memory allocation, and set appropriate config options for TFE - // context. - auto* config = TF_CreateConfig( - /*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ - 10); - TFE_ContextOptionsSetConfig(opts, config->data, config->length, status); - if (!status->status.ok()) { - CHECK(!config); - TFE_DeleteContextOptions(opts); - return nullptr; - } - - auto* ctx = TFE_NewContextFromSession(opts, session, status); - TF_DeleteBuffer(config); - TFE_DeleteContextOptions(opts); - return ctx; -} - -// TODO: retrieve the device string via TFE_ContextListDevices() -static const char DEFAULT_CPU_DEVICE[] = - "/job:localhost/replica:0/task:0/device:CPU:0"; - -static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType, - int tensor_id, TF_Status* status) { - std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp( - TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp); - TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status); - if (!status->status.ok()) return nullptr; - // TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler. - TFE_OpSetAttrInt(queueOp.get(), "capacity", 1); - TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1); - auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id); - TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(), - shared_name.size()); - TFE_OpSetAttrString(queueOp.get(), "container", "", 0); - - // TODO: consider making this an unknown shape. - const int64_t* dims_ptr = nullptr; - int num_dims = 0; - TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims, - /*num_values*/ 0, status); - if (!status->status.ok()) return nullptr; - - int num_retvals = 1; - TFE_TensorHandle* queue = nullptr; - TFE_Execute(queueOp.get(), &queue, &num_retvals, status); - if (!status->status.ok()) return nullptr; - CHECK_EQ(num_retvals, 1); - - return queue; -} - -static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType, - TFE_TensorHandle* queue, TFE_TensorHandle* tensor, - TF_Status* status) { - TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status); - if (!status->status.ok()) return; - std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp); - TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); - if (!status->status.ok()) return; - TFE_OpAddInput(op, queue, status); - if (!status->status.ok()) return; - TFE_OpAddInput(op, tensor, status); - if (!status->status.ok()) return; - TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1); - TFE_OpSetAttrInt(op, "timeout_ms", -1); - - int num_retvals = 0; - TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status); - if (!status->status.ok()) return; - CHECK_EQ(num_retvals, 0); -} - -static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx, - TF_DataType inputType, - TFE_TensorHandle* queue, - TF_Status* status) { - TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status); - if (!status->status.ok()) return nullptr; - std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp); - TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); - if (!status->status.ok()) return nullptr; - - TFE_OpAddInput(op, queue, status); - if (!status->status.ok()) return nullptr; - TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1); - TFE_OpSetAttrInt(op, "timeout_ms", -1); - TFE_TensorHandle* ret; - int num_retvals = 1; - TFE_Execute(op, &ret, &num_retvals, status); - if (!status->status.ok()) return nullptr; - CHECK_EQ(num_retvals, 1); - return ret; -} - -TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id, - TF_DataType inputType, - TF_Status* status) { - assert(session); - VLOG(1) << "Dequeuing data tensor with id " << tensor_id; - - auto ctx = TFE_CreateContextFromSession(session, status); - if (!status->status.ok()) return nullptr; - std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter( - ctx, TFE_DeleteContext); - - TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); - if (!status->status.ok()) return nullptr; - std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> - queue_deleter(queue, TFE_DeleteTensorHandle); - - auto* ret = createTFEDequeue(ctx, inputType, queue, status); - return ret; -} - -TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, - TF_DataType inputType, - TF_Status* status) { - TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); - if (!status->status.ok()) return nullptr; - std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> - queue_deleter(queue, TFE_DeleteTensorHandle); - - auto* ret = createTFEDequeue(ctx, inputType, queue, status); - - return ret; -} - -void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id, - TFE_TensorHandle* tensor, TF_Status* status) { - assert(session); - VLOG(1) << "Enqueuing data tensor with id " << tensor_id; - - auto ctx = TFE_CreateContextFromSession(session, status); - if (!status->status.ok()) return; - std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter( - ctx, TFE_DeleteContext); - - TF_DataType inputType = TFE_TensorHandleDataType(tensor); - TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); - if (!status->status.ok()) return; - std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> - queue_deleter(queue, TFE_DeleteTensorHandle); - - createTFEEnqueue(ctx, inputType, queue, tensor, status); -} - -void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, - TFE_TensorHandle* tensor, - TF_Status* status) { - VLOG(1) << "Enqueuing data tensor with id " << tensor_id; - - TF_DataType inputType = TFE_TensorHandleDataType(tensor); - TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); - if (!status->status.ok()) return; - std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> - queue_deleter(queue, TFE_DeleteTensorHandle); - - createTFEEnqueue(ctx, inputType, queue, tensor, status); -} - -void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id, - TFE_TensorHandle* tensor, TF_Status* status) { - VLOG(1) << "Enqueuing variant tensor with id " << tensor_id; - - auto ctx = TFE_CreateContextFromSession(session, status); - if (!status->status.ok()) return; - std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter( - ctx, TFE_DeleteContext); - - TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); - if (!status->status.ok()) return; - std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> - queue_deleter(queue, TFE_DeleteTensorHandle); - - createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status); -} - -TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id, - TF_Status* status) { - VLOG(1) << "Dequeuing variant tensor with id " << tensor_id; - - auto ctx = TFE_CreateContextFromSession(session, status); - if (!status->status.ok()) return nullptr; - std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter( - ctx, TFE_DeleteContext); - - TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); - if (!status->status.ok()) return nullptr; - std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)> - queue_deleter(queue, TFE_DeleteTensorHandle); - - return createTFEDequeue(ctx, TF_VARIANT, queue, status); -} - void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) { status->status = tensorflow::errors::Internal(errMsg); } @@ -622,10 +423,9 @@ void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name, void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name, const TF_DataType* values, int num_values) { auto iter = builder->attr_names.insert(attr_name).first; - builder->Set( - (*iter).c_str(), - tensorflow::gtl::ArraySlice<const tensorflow::DataType>( - reinterpret_cast<const tensorflow::DataType*>(values), num_values)); + builder->Set(*iter, tensorflow::gtl::ArraySlice<const tensorflow::DataType>( + reinterpret_cast<const tensorflow::DataType*>(values), + num_values)); } void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder, diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 551a45d92c4..d0ffbf125fb 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -146,48 +146,6 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, // Create a serialized tensorflow.ServerDef proto. TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status); -// TODO: remove this API in favor of the next one. -TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession( - const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status); - -// Creates from `session` a new eager context to run a graph function or -// sends/recvs, so that these concurrent TFE executions can share (via -// `session` and its associated device mgr) the same set of fifo queue resource -// ops, used for host<->TF tensor transfers. This way the sends/recvs calls and -// graph function execution can access the same fifo queue resource handles -// (associated with devices managed by the device manager, which can be obtained -// from `session`). -// -// TODO: Remove this function once we migrate away from using session. -TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession( - TF_Session* session, TF_Status* status); - -// TODO: Retire this API in favor of the next one. -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor( - TF_Session* session, int tensor_id, TF_DataType inputType, - TF_Status* status); - -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx( - TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status); - -TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session, - int tensor_id, - TFE_TensorHandle* tensor, - TF_Status* status); - -TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx( - TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor, - TF_Status* status); - -// TODO: consider folding the 2 APIs below into the ones above. -TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session, - int tensor_id, - TFE_TensorHandle* tensor, - TF_Status* status); - -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( - TF_Session* session, int tensor_id, TF_Status* status); - TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg); diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index fe4d5ac6ffe..b8429646960 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -144,6 +144,24 @@ cc_library( ], ) +cc_library( + name = "c_api_unified_internal", + hdrs = [ + "c_api_unified_experimental_internal.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":c_api", + ":c_api_experimental", + "//tensorflow/c:c_api_internal", + "//tensorflow/c:tf_status", + "//tensorflow/core/platform:casts", + "//tensorflow/core/platform:types", + ], +) + cc_library( name = "tensor_handle_interface", hdrs = ["tensor_handle_interface.h"], @@ -319,6 +337,7 @@ tf_cuda_cc_test( tags = [ "noguitar", # TODO(b/155445984): flaky #"guitar", + "notap", # TODO(b/156981931): flaky "multi_gpu", ], deps = [ @@ -349,7 +368,10 @@ tf_cuda_cc_test( # TODO(b/136478427): Figure out how to correctly shut the server down args = ["--heap_check=local"], extra_copts = tfe_xla_copts(), - tags = ["noasan"], # leaks gRPC server instances + tags = [ + "noasan", # leaks gRPC server instances + "notsan", # b/157098283 + ], deps = [ ":c_api", ":c_api_experimental", @@ -357,10 +379,13 @@ tf_cuda_cc_test( ":c_api_test_util", ":tfe_tensorhandle_internal", "//tensorflow/c:c_test_util", + "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime:function_optimization_registry", "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "@com_google_absl//absl/strings", @@ -507,6 +532,7 @@ tf_cuda_cc_test( "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", "//tensorflow/cc/profiler", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 5c01ccb82bb..5a39c17e1d9 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -102,6 +102,15 @@ string DeviceName(const tensorflow::Device* d) { } #if !defined(IS_MOBILE_PLATFORM) +bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context, + const tensorflow::ServerDef& server_def) { + if (server_def.job_name() != context->HostCPU()->parsed_name().job) { + return false; + } + return server_def.default_session_config().SerializeAsString() == + context->session_options().config.SerializeAsString(); +} + tensorflow::Status AddRemoteDevicesToMgr( const std::vector<string>& added_remote_workers, tensorflow::WorkerCacheInterface* worker_cache, @@ -469,10 +478,15 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); tensorflow::GrpcServer* grpc_server; if (reset_context) { - LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server)); + const tensorflow::DeviceMgr* device_mgr = + AreLocalDevicesCompatible(context, server_def) + ? context->local_device_mgr() + : nullptr; + LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions( + server_def, {device_mgr}, &new_server)); grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get()); LOG_AND_RETURN_IF_ERROR( - ListRemoteWorkers(grpc_server, worker_name, &remote_workers)); + ListRemoteWorkers(new_server.get(), worker_name, &remote_workers)); } else { LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name, &curr_remote_workers)); @@ -727,24 +741,6 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { tensorflow::GetDefaultCustomKernelCreator())); } -TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, - TF_Session* sess, TF_Status* status) { - const tensorflow::DeviceMgr* device_mgr = nullptr; - status->status = sess->session->LocalDeviceManager(&device_mgr); - if (!status->status.ok()) return nullptr; - tensorflow::Rendezvous* r = - new tensorflow::IntraProcessRendezvous(device_mgr); - - return tensorflow::wrap(new tensorflow::EagerContext( - opts->session_options.options, - static_cast<tensorflow::ContextDevicePlacementPolicy>( - opts->device_placement_policy), - static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy), - opts->async, opts->lazy_remote_inputs_copy, device_mgr, - /*device_mgr_owned*/ false, r, - tensorflow::GetDefaultCustomKernelCreator())); -} - void TFE_DeleteContext(TFE_Context* ctx) { if (ctx == nullptr) { return; @@ -899,9 +895,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, #if defined(IS_MOBILE_PLATFORM) status->status = tensorflow::Status::OK(); #else // !defined(IS_MOBILE_PLATFORM) - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - status->status = context->SyncExecutors(); + status->status = tensorflow::unwrap(ctx)->AsyncWait(); #endif // !IS_MOBILE_PLATFORM } diff --git a/tensorflow/c/eager/c_api_cluster_test.cc b/tensorflow/c/eager/c_api_cluster_test.cc index 252a0408758..f8c702d592a 100644 --- a/tensorflow/c/eager/c_api_cluster_test.cc +++ b/tensorflow/c/eager/c_api_cluster_test.cc @@ -41,7 +41,7 @@ tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) { for (int i = 0; i < num_tasks; i++) { int port = tensorflow::testing::PickUnusedPortOrDie(); job_def->mutable_tasks()->insert( - {i, tensorflow::strings::StrCat("localhost:", port)}); + {i, tensorflow::strings::StrCat("localhost", ":", port)}); } return server_def; } @@ -430,4 +430,70 @@ TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) { TestRemoteExecuteUpdateServerDefWithFailures(true); } +void TestConnectToCluster(bool keep_localhost_for_first_connect) { + // Fail fast on GetStatus requests so we can get errors instead of timeout + // when updating cluster with non-exsitent worker + tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1); + + const string first_name = + keep_localhost_for_first_connect ? "localhost" : "abc"; + tensorflow::ServerDef server_def = GetServerDef(first_name, 1); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + const string dev0_name = "/job:localhost/replica:0/task:0/device:CPU:0"; + TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name); + EXPECT_NE(var_handle0, nullptr); + + tensorflow::Status status2; + EXPECT_EQ(tensorflow::unwrap(var_handle0)->DeviceName(&status2), dev0_name); + + // Rename local device + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const string dev1_name = + absl::StrCat("/job:", first_name, "/replica:0/task:0/device:CPU:0"); + TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name); + EXPECT_NE(var_handle1, nullptr); + EXPECT_EQ(tensorflow::unwrap(var_handle1)->DeviceName(&status2), dev1_name); + + // Another renaming of local device + const string second_name = "def"; + server_def.set_job_name(second_name); + server_def.mutable_cluster()->mutable_job(0)->set_name(second_name); + (*server_def.mutable_cluster()->mutable_job(0)->mutable_tasks())[0] = + absl::StrCat(second_name, ":", + tensorflow::testing::PickUnusedPortOrDie()); + + serialized = server_def.SerializeAsString(); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const string dev2_name = "/job:def/replica:0/task:0/device:CPU:0"; + TFE_TensorHandle* var_handle2 = TestVariable(ctx, 2.0, dev2_name); + EXPECT_NE(var_handle2, nullptr); + EXPECT_EQ(tensorflow::unwrap(var_handle2)->DeviceName(&status2), dev2_name); + + TFE_DeleteTensorHandle(var_handle0); + TFE_DeleteTensorHandle(var_handle1); + TFE_DeleteTensorHandle(var_handle2); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); + + tensorflow::unsetenv("GRPC_FAIL_FAST"); +} + +TEST(CAPI, ConnectToClusterLocalhostFirst) { TestConnectToCluster(false); } + +TEST(CAPI, ConnectToClusterRenameFirst) { TestConnectToCluster(true); } + } // namespace diff --git a/tensorflow/c/eager/c_api_remote_test.cc b/tensorflow/c/eager/c_api_remote_test.cc index d04e4ef4212..93d830d2c90 100644 --- a/tensorflow/c/eager/c_api_remote_test.cc +++ b/tensorflow/c/eager/c_api_remote_test.cc @@ -19,11 +19,16 @@ limitations under the License. #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/common_runtime/function_optimization_registry.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/cluster.pb.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" namespace { @@ -574,6 +579,181 @@ TEST(CAPI, TestRemoteFunctionWithPackedInput) { TestFunctionWithPackedInput(/*remote=*/true); } +string VariableAddFunction() { + tensorflow::FunctionDef def; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + " signature {" + " name: 'VariableAddFunction'" + " input_arg {" + " name: 'var0'" + " type: DT_RESOURCE" + " }" + " output_arg {" + " name: 'var0_value'" + " type: DT_FLOAT" + " }" + " }" + " node_def {" + " name: 'read0'" + " op: 'ReadVariableOp'" + " input: 'var0'" + " attr {" + " key: 'dtype'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'add'" + " op: 'Add'" + " input: 'read0:value:0'" + " input: 'read0:value:0'" + " device: '/job:localhost/task:1/device:CPU:0'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'identity'" + " op: 'Identity'" + " input: 'add:z:0'" + " device: '/job:localhost/task:0/device:CPU:0'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " ret {" + " key: 'var0_value'" + " value: 'identity:output:0'" + " }", + &def)); + return def.SerializeAsString(); +} + +class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass { + public: + FunctionErrorInjectionPass(string error_node, string error_device) + : error_node_(error_node), error_device_(error_device) {} + tensorflow::Status Run(const tensorflow::DeviceSet& device_set, + const tensorflow::ConfigProto& config_proto, + std::unique_ptr<tensorflow::Graph>* graph, + tensorflow::FunctionLibraryDefinition* flib_def, + std::vector<std::string>* control_ret_node_names, + bool* control_rets_updated) override { + // Inject failure to function instantiation if finding a node that contains + // the given node name (error_node_) and requested device (error_device_). + for (const auto node : graph->get()->nodes()) { + if (node->name().find(error_node_) != string::npos && + node->requested_device() == error_device_) { + return tensorflow::errors::Internal("Injected graph pass error."); + } + } + return tensorflow::Status::OK(); + } + + private: + const string error_node_; + const string error_device_; +}; + +void TestDistributedFunctionCancellation(bool inject_error) { + tensorflow::ServerDef server_def = GetServerDef(3); + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + std::unique_ptr<tensorflow::GrpcServer> worker_server1; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server1) + .ok()); + ASSERT_TRUE(worker_server1->Start().ok()); + server_def.set_task_index(2); + std::unique_ptr<tensorflow::GrpcServer> worker_server2; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server2) + .ok()); + ASSERT_TRUE(worker_server2->Start().ok()); + const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; + + if (inject_error) { + // Inject a function optimization pass failure when it sees the 'read0' op + // having a requested device `dev2_name`. During execution: + // * task:0 processes the main function `VariableAddFunction` and places + // the read0 op on task:2 + // * task:0 partitions the main function with a subgraph containing read0 + // sent to task:2 + // * task:2 graph pass reports an error when it sees read0 with dev2_name + tensorflow::function_optimization_registration:: + FunctionOptimizationPassRegistration register_test_pass( + std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name)); + } + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name); + EXPECT_NE(var_handle, nullptr); + + const string function_def = VariableAddFunction(); + TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), + status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_OpAddInput(func, var_handle, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(func, &retvals[0], &num_retvals, status); + + if (inject_error) { + ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status); + } else { + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retvals[0]); + float sum = 0; + ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t)); + memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + ASSERT_EQ(sum, 4.0); + } + + TFE_DeleteOp(func); + TFE_DeleteTensorHandle(var_handle); + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); + + // TODO(b/136478427): Figure out how to correctly shut the server down. + worker_server1.release(); + worker_server2.release(); +} + +TEST(CAPI, DistributedFunctionNoError) { + TestDistributedFunctionCancellation(false); +} + +TEST(CAPI, DistributedFunctionCancelledOnError) { + TestDistributedFunctionCancellation(true); +} + void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) { tensorflow::ServerDef server_def = GetServerDef(2); diff --git a/tensorflow/c/eager/c_api_unified_experimental_internal.h b/tensorflow/c/eager/c_api_unified_experimental_internal.h index 49212a230ee..8fc696f0f2f 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_internal.h +++ b/tensorflow/c/eager/c_api_unified_experimental_internal.h @@ -58,7 +58,7 @@ T* dyncast(S source) { // GraphContext and vice-versa). class AbstractTensor { protected: - enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor }; + enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor }; explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {} public: @@ -101,7 +101,7 @@ class AbstractFunction { // on a given context, with the same or different input tensors. class AbstractOp { protected: - enum AbstractOpKind { kGraphOp, kEagerOp }; + enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp }; explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {} public: @@ -129,7 +129,7 @@ class AbstractOp { // eager implementation or to a graph implementation. struct ExecutionContext { protected: - enum ExecutionContextKind { kGraphContext, kEagerContext }; + enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext }; explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {} public: diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index 9776b4d13ed..24d170f2f99 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -477,7 +477,8 @@ TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) { TF_DeleteExecutionContext(eager_execution_ctx); } -INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef")); +INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, + ::testing::Values("graphdef", "mlir")); } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/context_interface.h b/tensorflow/c/eager/context_interface.h index d21ab45e579..2861fa43b66 100644 --- a/tensorflow/c/eager/context_interface.h +++ b/tensorflow/c/eager/context_interface.h @@ -101,6 +101,9 @@ class AbstractContextInterface { // Destroy the step resource container for a training step. virtual void EndStep() = 0; + // Block until all pending nodes are finished. + virtual Status AsyncWait() = 0; + protected: virtual ~AbstractContextInterface() {} }; diff --git a/tensorflow/c/experimental/network.cc b/tensorflow/c/experimental/network.cc index 94375cf9983..97e63ec6259 100644 --- a/tensorflow/c/experimental/network.cc +++ b/tensorflow/c/experimental/network.cc @@ -108,7 +108,7 @@ class CServerFactory : public ServerFactory { delete_function_(delete_function), rendezvous_builder_(rendezvous_builder) {} - Status NewServer(const ServerDef& server_def, + Status NewServer(const ServerDef& server_def, const Options& options, std::unique_ptr<ServerInterface>* out_server) override { TF_RETURN_IF_ERROR(CGrpcServer::Create( server_def, init_function_, start_function_, stop_function_, diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 5c51e26f925..2ded784882b 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -155,6 +155,7 @@ cc_library( "saved_model_api_type.h", ], deps = [ + "//tensorflow/c:conversion_macros", "//tensorflow/c/experimental/saved_model/core:saved_model_api", ], ) diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc index 629610dbe29..9614e507646 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc @@ -41,7 +41,7 @@ TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx, if (!status->status.ok()) { return nullptr; } - return new TF_SavedModel{std::move(result)}; + return tensorflow::wrap(result.release()); } TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx, @@ -60,17 +60,19 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx, if (!status->status.ok()) { return nullptr; } - return new TF_SavedModel{std::move(result)}; + return tensorflow::wrap(result.release()); } -void TF_DeleteSavedModel(TF_SavedModel* model) { delete model; } +void TF_DeleteSavedModel(TF_SavedModel* model) { + delete tensorflow::unwrap(model); +} TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model, const char* function_path, TF_Status* status) { tensorflow::ConcreteFunction* result = nullptr; tensorflow::Status get_function_status = - model->saved_model->GetFunction(function_path, &result); + tensorflow::unwrap(model)->GetFunction(function_path, &result); status->status.Update(get_function_status); if (!get_function_status.ok()) { return nullptr; @@ -82,7 +84,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( TF_SavedModel* model, const char* signature_def_key, TF_Status* status) { tensorflow::ConcreteFunction* result = nullptr; tensorflow::Status get_function_status = - model->saved_model->GetSignatureDefFunction(signature_def_key, &result); + tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key, + &result); status->status.Update(get_function_status); if (!get_function_status.ok()) { return nullptr; @@ -91,7 +94,8 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( } TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) { - return new TF_ConcreteFunctionList{model->saved_model->ListFunctions()}; + return new TF_ConcreteFunctionList{ + tensorflow::unwrap(model)->ListFunctions()}; } } // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h b/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h index 9e2d1117463..380c3703426 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h @@ -18,13 +18,18 @@ limitations under the License. #include <memory> +#include "tensorflow/c/conversion_macros.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" // Internal structures used by the SavedModel C API. These are likely to change // and should not be depended on. -struct TF_SavedModel { - std::unique_ptr<tensorflow::SavedModelAPI> saved_model; -}; +typedef struct TF_SavedModel TF_SavedModel; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::SavedModelAPI, TF_SavedModel) + +} // namespace tensorflow #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_ diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index abccefbcdbb..f2b28e70ff1 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -20,7 +20,7 @@ load( "tf_cc_test", "tf_copts", ) -load("//tensorflow:tensorflow.bzl", "tfcompile_extra_flags") +load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu") def tf_library( name, @@ -42,7 +42,8 @@ def tf_library( mlir_components = "None", deps = None, tags = []): - """Runs tfcompile to compile a TensorFlow graph into executable code. + """Runs tfcompile to compile a TensorFlow graph into executable code with fast + math enabled on cpu. Given an invocation of tf_library(name="foo", ...), generates the following build targets: @@ -187,7 +188,9 @@ def tf_library( # `find` on such an object. need_xla_data_proto = flags and flags.find("--gen_program_shape") != -1 - flags = tfcompile_extra_flags() + flags + target_cpu = tfcompile_target_cpu() + extra_flags = "--target_cpu=" + target_cpu + " " if target_cpu else " " + flags = extra_flags + flags if enable_xla_hlo_profiling: profiling_flag = "--xla_hlo_profile" @@ -207,6 +210,15 @@ def tf_library( srcs.append(debug_info) debug_info_flag = " --debug_info=$(location " + debug_info + ")" + default_fast_math_xla_flags = ("XLA_FLAGS='" + + "--xla_cpu_enable_fast_math=true " + + "--xla_cpu_fast_math_honor_nans=false " + + "--xla_cpu_fast_math_honor_infs=false " + + "--xla_cpu_fast_math_honor_functions=false " + + "--xla_cpu_fast_math_honor_division=false " + + "--xla_cpu_enable_fast_min_max=true " + + "$${XLA_FLAGS:-}' ") + native.genrule( name = ("gen_" + name), srcs = srcs, @@ -216,6 +228,7 @@ def tf_library( function_object_file, ], cmd = ( + default_fast_math_xla_flags + "CUDA_VISIBLE_DEVICES='' " + "$(location " + tfcompile_tool + ")" + " --graph=$(location " + tfcompile_graph + ")" + @@ -256,6 +269,7 @@ def tf_library( session_module_pb, ], cmd = ( + default_fast_math_xla_flags + "CUDA_VISIBLE_DEVICES='' " + "$(location " + tfcompile_tool + ")" + " --graph=$(location " + tfcompile_graph + ")" + diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index f0cf8f2ded9..846947454bb 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -67,6 +67,8 @@ int main(int argc, char** argv) { flags.entry_point = "entry"; flags.debug_info_path_begin_marker = ""; + // Note that tfcompile.bzl's tf_library macro sets fast math flags as that is + // generally the preferred case. std::vector<tensorflow::Flag> flag_list; AppendMainFlags(&flag_list, &flags); xla::AppendDebugOptionsFlags(&flag_list); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index bc8fac0e88f..5ec0575ed77 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -505,6 +505,7 @@ cc_library( name = "shape_inference", srcs = ["shape_inference.cc"], hdrs = ["shape_inference.h"], + visibility = [":friends"], deps = [ ":shape_inference_helpers", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 174250f18bd..9f5723f4fa4 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -2034,6 +2034,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() { "TensorArraySplitV3", "TensorArrayV3", "TensorArrayWriteV3", + "TensorListConcatV2", "TensorListElementShape", "TensorListFromTensor", "TensorListGather", @@ -2043,6 +2044,7 @@ absl::flat_hash_set<string> GetKnownXLAWhitelistOp() { "TensorListPushBack", "TensorListReserve", "TensorListSetItem", + "TensorListSplit", "TensorListStack", "TensorScatterAdd", "TensorScatterSub", diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index c0066ecda03..c4472e1185c 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -104,6 +104,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", ], alwayslink = 1, diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 9b5b0c209e5..6eff7dbd084 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -260,6 +260,41 @@ cc_library( ], ) +cc_library( + name = "tftext_utils", + srcs = [ + "utils/tftext_utils.cc", + ], + hdrs = [ + "utils/tftext_utils.h", + ], + copts = ["-std=c++14"], + deps = [ + ":tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "tftext_utils_test", + size = "small", + srcs = ["utils/lstm_utils_test.cc"], + deps = [ + ":lstm_utils", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "stateful_ops_utils", srcs = [ @@ -320,6 +355,7 @@ cc_library( ":lstm_utils", ":stateful_ops_utils", ":tensorflow_lite", + ":tftext_utils", ":validators", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 6a631b1433d..df84b028f63 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -799,11 +799,6 @@ Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions( Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions( const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { - std::string node_def_str; - if (!node_def.SerializeToString(&node_def_str)) { - return emitError(loc, "failed to serialize tensorflow node_def"), - llvm::None; - } auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc); return builder_.CreateVector(flex_builder->GetBuffer()); } @@ -813,9 +808,13 @@ Translator::CreateFlexBuilderWithNodeAttrs( const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { auto flex_builder = absl::make_unique<flexbuffers::Builder>(); size_t map_start = flex_builder->StartMap(); - for (const auto& pair : node_def.attr()) { + using Item = std::pair<std::string, ::tensorflow::AttrValue>; + std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end()); + std::sort(attrs.begin(), attrs.end(), + [](Item& p1, Item& p2) -> bool { return p1.first < p2.first; }); + for (const Item& pair : attrs) { const char* key = pair.first.c_str(); - const auto& attr = pair.second; + const ::tensorflow::AttrValue& attr = pair.second; switch (attr.value_case()) { case ::tensorflow::AttrValue::kS: flex_builder->String(key, attr.s()); diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index 0e6a3db1f1b..c7a1504c3b7 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -27,7 +27,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index a585b8e1520..923efdbaf9d 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -254,6 +254,14 @@ class TFL_TFOperandTypesWithSameBits<int i, int j, int num> : Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">, CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; +class TFL_OperandIsNoneOrHasRank<int n, int m> : + PredOpTrait<"operand " # n # " is " # m # "-D", + Or<[ + CPred<"$_op.getOperand(" # n # ").getType().isa<NoneType>()">, + TFL_OperandIsUnrankedPred<n>, + CPred<"$_op.getOperand(" # n # + ").getType().cast<ShapedType>().getRank() == " # m>]>>; + class TFL_OperandIsNoneOrHasRankAtMost<int n, int m> : PredOpTrait<"operand " # n # " is at most " # m # "-D", Or<[ @@ -285,14 +293,18 @@ def TFL_FloatNonNegative : AttrConstraint< CPred<"!$_self.cast<FloatAttr>().getValue().isNegative()">, "whose value is non-negative">; -def TFL_BoolTrue: AttrConstraint< +def TFL_BoolTrue : AttrConstraint< CPred<"$_self.cast<BoolAttr>().getValue()">, "whose value is true">; -def TFL_BoolFalse: AttrConstraint< +def TFL_BoolFalse : AttrConstraint< CPred<"!$_self.cast<BoolAttr>().getValue()">, "whose value is false">; +class TFL_StringEqualsTo<string value> : AttrConstraint< + CPred<"$_self.cast<StringAttr>().getValue() == \"" # value # "\"">, + "whose value equals to '" # value # "'">; + // This is a quantization-aware version of TCresVTEtIsSameAsOp class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[ TCOpResIsShapedTypePred<i, j>, @@ -1892,7 +1904,10 @@ def TFL_OneHotOp : TFL_Op<"one_hot", [NoSideEffect]> { let hasOptions = 1; } -def TFL_RoundOp: TFL_Op<"round", [NoSideEffect, SameOperandsAndResultType]> { +def TFL_RoundOp: TFL_Op<"round", [ + NoSideEffect, + SameOperandsAndResultShape, + SameOperandsAndResultType]> { let summary = "Round operator"; let description = [{ @@ -1909,7 +1924,14 @@ Rounds the values of a tensor to the nearest integer, element-wise. } def TFL_SliceOp : TFL_Op<"slice", [ - NoSideEffect, SameOperandsAndResultsScale, TFL_GpuTargetOp]> { + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + SameOperandsAndResultsScale, + TFL_OperandHasRankAtMost<0, 4>, + TFL_OperandHasRankAtMost<1, 1>, + TFL_OperandHasRankAtMost<2, 1>, + TFL_GpuTargetOp]> { let summary = "Return a slice from 'input'."; let description = [{ @@ -1927,13 +1949,13 @@ equivalent to setting: }]; let arguments = (ins - AnyTensor:$input, + TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$input, TFL_I32OrI64Tensor:$begin, TFL_I32OrI64Tensor:$size ); let results = (outs - AnyTensor:$output + TFL_TensorOf<[F32, I32, I64, I8, UI8, I1, TFL_Str, QI8, QUI8, TFL_Quint8]>:$output ); let verifier = [{ return Verify(*this); }]; @@ -1961,7 +1983,10 @@ def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> { } def TFL_ReduceMinOp: TFL_Op<"reduce_min", [ - NoSideEffect, SameOperandsAndResultsScale]> { + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + SameOperandsAndResultsScale]> { let summary = "Min-reduction operator"; let description = [{ @@ -1969,19 +1994,23 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [ }]; let arguments = (ins - AnyTensor:$input, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, TFL_I32Tensor:$axes, BoolAttr:$keep_dims ); - let results = (outs AnyTensor); + let results = (outs + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; } def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [ - NoSideEffect, SameOperandsAndResultsScale]> { + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + SameOperandsAndResultsScale]> { let summary = "Max-reduction operator"; let description = [{ @@ -1989,18 +2018,22 @@ def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [ }]; let arguments = (ins - AnyTensor:$input, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, TFL_I32Tensor:$axes, BoolAttr:$keep_dims ); - let results = (outs AnyTensor); + let results = (outs + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; } -def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> { +def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect]> { let summary = "Prod-reduction operator"; let description = [{ @@ -2008,12 +2041,13 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> { }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I32, I64]>:$input, + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, TFL_I32Tensor:$axes, BoolAttr:$keep_dims ); - let results = (outs AnyTensor); + let results = (outs + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; @@ -2308,10 +2342,13 @@ def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> { let hasFolder = 1; } -def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, - SameOperandsAndResultShape, - SameOperandsAndResultsScale, - TFL_GpuTargetOp]> { +def TFL_ReluOp: TFL_Op<"relu", [ + PredOpTrait<"x and y must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + SameOperandsAndResultShape, + SameOperandsAndResultsScale, + TFL_GpuTargetOp]> { let summary = "Relu operator"; let description = [{ @@ -2319,9 +2356,9 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, x -> max(0, x) }]; - let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x); + let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x); - let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y); // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. Currently, it is used by the @@ -2335,10 +2372,13 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, ]; } -def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, - SameOperandsAndResultShape, - SameOperandsAndResultsScale, - TFL_GpuTargetOp]> { +def TFL_Relu6Op: TFL_Op<"relu6", [ + PredOpTrait<"x and y must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + SameOperandsAndResultShape, + SameOperandsAndResultsScale, + TFL_GpuTargetOp]> { let summary = "Relu6 operator"; let description = [{ @@ -2346,9 +2386,9 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, x -> max(0, min(6, x)) }]; - let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x); + let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x); - let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y); // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. Currently, it is used by the @@ -2362,9 +2402,12 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, ]; } -def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect, - SameOperandsAndResultShape, - SameOperandsAndResultsScale]> { +def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [ + PredOpTrait<"x and y must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + SameOperandsAndResultShape, + SameOperandsAndResultsScale]> { let summary = "Relu1 operator"; let description = [{ @@ -2372,9 +2415,9 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [NoSideEffect, x -> max(-1, min(1, x)) }]; - let arguments = (ins TFL_TensorOf<[F32, QUI8, I8]>:$x); + let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x); - let results = (outs TFL_TensorOf<[F32, QUI8, I8]>:$y); + let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y); // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. Currently, it is used by the @@ -2406,7 +2449,11 @@ def TFL_ReshapeOp: TFL_Op<"reshape", [ let hasFolder = 1; } -def TFL_ReverseSequenceOp : TFL_Op<"reverse_sequence", [NoSideEffect]> { +def TFL_ReverseSequenceOp : TFL_Op<"reverse_sequence", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + TFL_OperandHasRank<1, 1>]> { let summary = "Reverses variable length slices."; let description = [{ @@ -2423,15 +2470,15 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension }]; let arguments = (ins - TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$input, + TFL_TensorOf<[F32, I32, I64, QI16, QUI8, TFL_Quint8]>:$input, TFL_I32OrI64Tensor:$seq_lengths, - I32Attr:$seq_dim, - I32Attr:$batch_dim + Confined<I32Attr, [IntNonNegative]>:$seq_dim, + Confined<I32Attr, [IntNonNegative]>:$batch_dim ); let results = (outs - TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$output + TFL_TensorOf<[F32, I32, I64, QI16, QUI8, TFL_Quint8]>:$output ); let hasOptions = 1; @@ -2439,6 +2486,7 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType, + SameOperandsAndResultShape, NoQuantizableResult, TFL_GpuTargetOp]> { let summary = "Reciprocal of square root operator"; @@ -2463,7 +2511,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> { let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[I32, I64]>:$output); DerivedTypeAttr out_type = DerivedTypeAttr<[{ return getResult().getType().cast<TensorType>().getElementType(); @@ -2472,9 +2520,11 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> { let hasOptions = 1; } -// TODO(jpienaar): Flesh this out. -def TFL_RangeOp: TFL_Op<"range", [NoSideEffect, TFL_OperandHasRank<0, 0>, - TFL_OperandHasRank<1, 0>, TFL_OperandHasRank<2, 0>, +def TFL_RangeOp: TFL_Op<"range", [ + NoSideEffect, + TFL_OperandHasRank<0, 0>, + TFL_OperandHasRank<1, 0>, + TFL_OperandHasRank<2, 0>, PredOpTrait<"operands and output must have same element type", And<[TCresVTEtIsSameAsOp<0, 0>, TCresVTEtIsSameAsOp<0, 1>, TCresVTEtIsSameAsOp<0, 2>]>>]> { @@ -2486,17 +2536,20 @@ def TFL_RangeOp: TFL_Op<"range", [NoSideEffect, TFL_OperandHasRank<0, 0>, }]; let arguments = (ins - AnyTensor:$start, - AnyTensor:$limit, - AnyTensor:$delta); + TFL_TensorOf<[I32, F32]>:$start, + TFL_TensorOf<[I32, F32]>:$limit, + TFL_TensorOf<[I32, F32]>:$delta); - let results = (outs AnyTensor:$result); + let results = (outs TFL_TensorOf<[I32, F32]>:$result); let hasFolder = 1; } -def TFL_ReverseV2Op: TFL_Op<"reverse_v2", - [NoSideEffect, TFL_OperandHasRank<1,1>]> { +def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [ + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoSideEffect, + TFL_OperandHasRank<1, 1>]> { let summary = "ReverseV2 Operator"; let description = [{ @@ -2518,18 +2571,18 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", let arguments = ( ins - TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$input, - TFL_TensorOf<[I32, I64]>:$axis + TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$input, + TFL_I32Tensor:$axis ); let results = (outs - TFL_TensorOf<[F32, I16, I32, I64, TFL_Uint8, I1]>:$output - ); + TFL_TensorOf<[F32, UI8, I16, I32, I64, QI16, QUI8, TFL_Quint8, I1]>:$output); } // Select has many instances in TF models where one or more of its operands // are unranked. Therefore, we skip adding shape constraints here. -def TFL_SelectOp : TFL_Op<"select", [NoSideEffect, +def TFL_SelectOp : TFL_Op<"select", [ + NoSideEffect, PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>, PredOpTrait<"operands and result have same element type", TCresVTEtIsSameAsOp<0, 1>>]> { @@ -2545,9 +2598,11 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect, let arguments = (ins TFL_BoolTensor:$condition, - TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x, - TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y); - let results = (outs AnyTensor:$output); + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$x, + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$y); + + let results = (outs + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output); // TODO(jpienaar): autogenerate this. let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, " @@ -2561,7 +2616,12 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect, let hasOptions = 1; } -def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> { +def TFL_SelectV2Op : TFL_Op<"select_v2", [ + NoSideEffect, + TFL_BinaryOperandsHaveSameShapesOrBroadcastableShape<1, 2, 4>, + PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>, + PredOpTrait<"operands and result have same element type", + TCresVTEtIsSameAsOp<0, 1>>]> { let summary = "SelectV2 operator"; let description = [{ @@ -2574,9 +2634,11 @@ def TFL_SelectV2Op : TFL_Op<"select_v2", [NoSideEffect]> { let arguments = (ins TFL_BoolTensor:$condition, - TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x, - TFL_TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y); - let results = (outs AnyTensor:$output); + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$x, + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$y); + + let results = (outs + TFL_TensorOf<[F32, I1, I8, I16, I32, I64, QI8, QUI8, QI16, TFL_Quint8]>:$output); let builders = [OpBuilder<"OpBuilder &builder, OperationState &result, " "Value cond, Value x, Value y", @@ -2605,9 +2667,11 @@ def TFL_SinOp: TFL_Op<"sin", [ let hasFolder = 1; } -// TODO(b/130643170): Adds some constraint for the input/output element types. def TFL_SoftmaxOp : TFL_Op<"softmax", [ NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_OperandHasRankRange<0, 1, 4>, SameOperandsAndResultShape, // zero_point = 0 // scale = 1. / (max_value + 1) @@ -2623,11 +2687,11 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [ }]; let arguments = ( - ins AnyTensor:$input, + ins TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input, F32Attr:$beta ); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output); let hasOptions = 1; } @@ -2914,6 +2978,7 @@ def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [ def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [ NoSideEffect, SameOperandsAndResultsScale, + TFL_OperandHasRankRange<0, 3, 4>, PredOpTrait<"input and output must have same element type", TCresVTEtIsSameAsOp<0, 0>> ]> { @@ -2924,13 +2989,13 @@ def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, - TFL_TensorOf<[I32]>:$block_shape, - TFL_TensorOf<[I32]>:$paddings + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$input, + TFL_I32Tensor:$block_shape, + TFL_I32Tensor:$paddings ); let results = (outs - TFL_TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$output + TFL_TensorOf<[F32, I32, I64, QI8, QUI8, TFL_Quint8]>:$output ); } @@ -3045,7 +3110,12 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale] } def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [ - NoSideEffect, SameOperandsAndResultsScale]> { + NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_OperandHasRank<0, 4>, + TFL_OperandHasRank<1, 1>, + SameOperandsAndResultsScale]> { let summary = "ResizeBilinear Op"; let description = [{ @@ -3053,23 +3123,26 @@ def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [ }]; let arguments = (ins - // TODO(ycling): Support quantized types. - TFL_TensorOf<[F32, I32, QI8, QUI8]>:$input, - TFL_TensorOf<[I32]>:$size, + TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$input, + TFL_I32Tensor:$size, BoolAttr:$align_corners, DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers ); let results = (outs - TFL_TensorOf<[F32, QI8, QUI8]>:$output + TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output ); let hasOptions = 1; } -def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", - [NoSideEffect, - SameOperandsAndResultsScale]> { +def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", [ + NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + TFL_OperandHasRank<0, 4>, + TFL_OperandHasRank<1, 1>, + SameOperandsAndResultsScale]> { let summary = "ResizeNearestNeighbor Op"; let description = [{ @@ -3077,14 +3150,14 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", }]; let arguments = (ins - TFL_TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$input, - TFL_TensorOf<[I32]>:$size, + TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$input, + TFL_I32Tensor:$size, BoolAttr:$align_corners, DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers ); let results = (outs - TFL_TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$output + TFL_TensorOf<[F32, TFL_Quint8, QUI8, QI8]>:$output ); let hasOptions = 1; @@ -3349,7 +3422,9 @@ def TFL_SparseQConstOp : Op<TFL_Dialect, "pseudo_sparse_qconst", [ } def TFL_QuantizeOp: TFL_Op<"quantize", [ - FirstAttrDerivedResultType, NoQuantizableResult]> { + FirstAttrDerivedResultType, + SameOperandsAndResultShape, + NoQuantizableResult]> { let summary = "Quantize operator"; let description = [{ @@ -3358,11 +3433,11 @@ def TFL_QuantizeOp: TFL_Op<"quantize", [ }]; let arguments = ( - ins AnyTensor:$input, + ins TFL_TensorOf<[F32, QI8, QUI8, QI16, TFL_Quint8]>:$input, TensorTypeAttr:$qtype ); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[QI8, QUI8, QI16, TFL_Quint8]>:$output); } def TFL_DensifyOp: TFL_Op<"densify", [ @@ -3472,6 +3547,19 @@ def TFL_LSTMOp : LstmOptionalPeepholeWeightConstraint, LstmProjectionWeightBiasConstraint, LstmResultConstraint, + TFL_OperandHasRank<2, 2>, // input_to_forget_weights + TFL_OperandHasRank<3, 2>, // input_to_cell_weights + TFL_OperandIsNoneOrHasRank<5, 2>, // recurrent_to_input_weights + TFL_OperandHasRank<6, 2>, // recurrent_to_forget_weights + TFL_OperandHasRank<7, 2>, // recurrent_to_cell_weights + TFL_OperandIsNoneOrHasRank<9, 1>, // cell_to_input_weights + TFL_OperandIsNoneOrHasRank<10, 1>, // cell_to_forget_weights + TFL_OperandIsNoneOrHasRank<11, 1>, // cell_to_output_weights + TFL_OperandHasRank<13, 1>, // forget_gate_bias + TFL_OperandHasRank<14, 1>, // cell_gate_bias + TFL_OperandHasRank<15, 1>, // output_gate_bias + TFL_OperandIsNoneOrHasRank<16, 2>, // projection_weights + TFL_OperandIsNoneOrHasRank<17, 1>, // projection_bias TFL_StatefulOp]> { let summary = "The full lstm operator"; @@ -3498,23 +3586,23 @@ Ba et al. 'Layer Normalization' ins TFL_TensorOf<[F32, QI8]>:$input, // Weights - TFL_TensorOfOrNone<[F32, I8, QI8]>:$input_to_input_weights, - TFL_TensorOf<[F32, I8, QI8]>:$input_to_forget_weights, - TFL_TensorOf<[F32, I8, QI8]>:$input_to_cell_weights, - TFL_TensorOf<[F32, I8, QI8]>:$input_to_output_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$input_to_input_weights, + TFL_TensorOf<[F32, QI8]>:$input_to_forget_weights, + TFL_TensorOf<[F32, QI8]>:$input_to_cell_weights, + TFL_TensorOf<[F32, QI8]>:$input_to_output_weights, // Recurrent weights - TFL_TensorOfOrNone<[F32, I8, QI8]>:$recurrent_to_input_weights, - TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_forget_weights, - TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_cell_weights, - TFL_TensorOf<[F32, I8, QI8]>:$recurrent_to_output_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$recurrent_to_input_weights, + TFL_TensorOf<[F32, QI8]>:$recurrent_to_forget_weights, + TFL_TensorOf<[F32, QI8]>:$recurrent_to_cell_weights, + TFL_TensorOf<[F32, QI8]>:$recurrent_to_output_weights, // Cell weights - TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_input_weights, + TFL_TensorOfOrNone<[F32, QI8, QI16]>:$cell_to_input_weights, // Optional input - TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_forget_weights, + TFL_TensorOfOrNone<[F32, QI8, QI16]>:$cell_to_forget_weights, // Optional input - TFL_TensorOfOrNone<[F32, I8, QI16]>:$cell_to_output_weights, + TFL_TensorOfOrNone<[F32, QI8, QI16]>:$cell_to_output_weights, // Bias TFL_TensorOfOrNone<[F32, QI32]>:$input_gate_bias, @@ -3523,7 +3611,7 @@ Ba et al. 'Layer Normalization' TFL_TensorOf<[F32, QI32]>:$output_gate_bias, // Projection weight and bias - TFL_TensorOfOrNone<[F32, I8, QI8]>:$projection_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$projection_weights, // Optional input TFL_TensorOfOrNone<[F32, QI32]>:$projection_bias, @@ -3539,8 +3627,8 @@ Ba et al. 'Layer Normalization' // Attributes TFL_AFAttr:$fused_activation_function, - DefaultValuedAttr<F32Attr, "0.0f">:$cell_clip, - DefaultValuedAttr<F32Attr, "0.0f">:$proj_clip, + Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$cell_clip, + Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$proj_clip, // Since this op is the FULL kernel only, constrain it. Confined< DefaultValuedAttr<TFL_LSTMKernelTypeAttr, "FULL">, @@ -3580,6 +3668,24 @@ def TFL_UnidirectionalSequenceLSTMOp : LstmOptionalPeepholeWeightConstraint, LstmProjectionWeightBiasConstraint, LstmResultConstraint, + TFL_OperandHasRankAtLeast<0, 2>, // input + TFL_OperandIsNoneOrHasRank<1, 2>, // input_to_input_weights + TFL_OperandHasRank<2, 2>, // input_to_forget_weights + TFL_OperandHasRank<3, 2>, // input_to_cell_weights + TFL_OperandHasRank<4, 2>, // input_to_output_weights + TFL_OperandIsNoneOrHasRank<5, 2>, // recurrent_to_input_weights + TFL_OperandHasRank<6, 2>, // recurrent_to_forget_weights + TFL_OperandHasRank<7, 2>, // recurrent_to_cell_weights + TFL_OperandHasRank<8, 2>, // recurrent_to_output_weights + TFL_OperandIsNoneOrHasRank<9, 1>, // cell_to_input_weights + TFL_OperandIsNoneOrHasRank<10, 1>, // cell_to_forget_weights + TFL_OperandIsNoneOrHasRank<11, 1>, // cell_to_output_weights + TFL_OperandIsNoneOrHasRank<12, 1>, // input_gate_bias + TFL_OperandHasRank<13, 1>, // forget_gate_bias + TFL_OperandHasRank<14, 1>, // cell_gate_bias + TFL_OperandHasRank<15, 1>, // output_gate_bias + TFL_OperandIsNoneOrHasRank<16, 2>, // projection_weights + TFL_OperandIsNoneOrHasRank<17, 2>, // projection_bias TFL_StatefulOp]> { let summary = "Unidirectional sequence lstm operator"; @@ -3595,35 +3701,35 @@ def TFL_UnidirectionalSequenceLSTMOp : }]; let arguments = ( - ins TFL_TensorOf<[F32, I8]>:$input, + ins TFL_FpTensor:$input, // Weights - TFL_TensorOfOrNone<[F32, I8]>:$input_to_input_weights, - TFL_TensorOf<[F32, I8]>:$input_to_forget_weights, - TFL_TensorOf<[F32, I8]>:$input_to_cell_weights, - TFL_TensorOf<[F32, I8]>:$input_to_output_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$input_to_input_weights, + TFL_TensorOf<[F32, QI8]>:$input_to_forget_weights, + TFL_TensorOf<[F32, QI8]>:$input_to_cell_weights, + TFL_TensorOf<[F32, QI8]>:$input_to_output_weights, // Recurrent weights - TFL_TensorOfOrNone<[F32, I8]>:$recurrent_to_input_weights, - TFL_TensorOf<[F32, I8]>:$recurrent_to_forget_weights, - TFL_TensorOf<[F32, I8]>:$recurrent_to_cell_weights, - TFL_TensorOf<[F32, I8]>:$recurrent_to_output_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$recurrent_to_input_weights, + TFL_TensorOf<[F32, QI8]>:$recurrent_to_forget_weights, + TFL_TensorOf<[F32, QI8]>:$recurrent_to_cell_weights, + TFL_TensorOf<[F32, QI8]>:$recurrent_to_output_weights, // Cell weights - TFL_TensorOfOrNone<[F32, I8]>:$cell_to_input_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$cell_to_input_weights, // Optional input - TFL_TensorOfOrNone<[F32, I8]>:$cell_to_forget_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$cell_to_forget_weights, // Optional input - TFL_TensorOfOrNone<[F32, I8]>:$cell_to_output_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$cell_to_output_weights, // Bias TFL_TensorOfOrNone<[F32]>:$input_gate_bias, - TFL_TensorOf<[F32]>:$forget_gate_bias, - TFL_TensorOf<[F32]>:$cell_bias, - TFL_TensorOf<[F32]>:$output_gate_bias, + TFL_FpTensor:$forget_gate_bias, + TFL_FpTensor:$cell_bias, + TFL_FpTensor:$output_gate_bias, // Projection weight and bias - TFL_TensorOfOrNone<[F32, I8]>:$projection_weights, + TFL_TensorOfOrNone<[F32, QI8]>:$projection_weights, // Optional input TFL_TensorOfOrNone<[F32]>:$projection_bias, @@ -3632,19 +3738,19 @@ def TFL_UnidirectionalSequenceLSTMOp : TFL_StatefulTensor:$input_cell_state, // Layer norm coefficients - TFL_TensorOfOrNone<[F32, I8]>:$input_layer_norm_coefficients, - TFL_TensorOfOrNone<[F32, I8]>:$forget_layer_norm_coefficients, - TFL_TensorOfOrNone<[F32, I8]>:$cell_layer_norm_coefficients, - TFL_TensorOfOrNone<[F32, I8]>:$output_layer_norm_coefficients, + TFL_TensorOfOrNone<[F32, QI8]>:$input_layer_norm_coefficients, + TFL_TensorOfOrNone<[F32, QI8]>:$forget_layer_norm_coefficients, + TFL_TensorOfOrNone<[F32, QI8]>:$cell_layer_norm_coefficients, + TFL_TensorOfOrNone<[F32, QI8]>:$output_layer_norm_coefficients, // Attributes TFL_AFAttr:$fused_activation_function, - DefaultValuedAttr<F32Attr, "0.0f">:$cell_clip, - DefaultValuedAttr<F32Attr, "0.0f">:$proj_clip, + Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$cell_clip, + Confined<DefaultValuedAttr<F32Attr, "0.0f">, [TFL_FloatNonNegative]>:$proj_clip, BoolAttr:$time_major ); - let results = (outs AnyTensor:$output); + let results = (outs TFL_TensorOf<[F32, QI8]>:$output); let hasOptions = 1; @@ -3841,15 +3947,14 @@ def TFL_BidirectionalSequenceLSTMOp : }]; } -def RnnResultConstraint : PredOpTrait< - "the input and result tensor elemental types must be same", - TCresVTEtIsSameAsOp<0, 0>>; - // UnidirectionalSequenceRNN op. -def TFL_UnidirectionalSequenceRNNOp : - TFL_Op<"unidirectional_sequence_rnn", - [RnnResultConstraint, TFL_StatefulOp]> { - +def TFL_UnidirectionalSequenceRNNOp : TFL_Op<"unidirectional_sequence_rnn", [ + TFL_OperandHasRank<4, 2>, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + PredOpTrait<"input and constant value operands must have same element type", + TFL_TCopVTEtAreSameAt<1, 2>>, + TFL_StatefulOp]> { let summary = "Unidirectional sequence rnn operator"; let description = [{ @@ -3866,16 +3971,16 @@ def TFL_UnidirectionalSequenceRNNOp : }]; let arguments = ( - ins TFL_TensorOf<[F32, I8]>:$input, + ins TFL_FpTensor:$input, // Weights - TFL_TensorOf<[F32, I8]>:$input_to_input_weights, + TFL_TensorOf<[F32, QI8]>:$input_to_input_weights, // Recurrent weights - TFL_TensorOf<[F32, I8]>:$recurrent_to_input_weights, + TFL_TensorOf<[F32, QI8]>:$recurrent_to_input_weights, // Bias - TFL_TensorOf<[F32]>:$input_gate_bias, + TFL_FpTensor:$input_gate_bias, // Hidden state. TFL_StatefulTensor:$hidden_state, @@ -3885,7 +3990,7 @@ def TFL_UnidirectionalSequenceRNNOp : TFL_AFAttr:$fused_activation_function ); - let results = (outs TFL_TensorOf<[F32, I8]>:$output); + let results = (outs TFL_FpTensor:$output); let hasOptions = 1; @@ -3941,14 +4046,12 @@ def TFL_NumericVerifyOp : Op<TFL_Dialect, "NumericVerify", [ let results = (outs); } -def SVDFResultConstraint: PredOpTrait< - "the input and result tensor elemental types must be same", - TCresVTEtIsSameAsOp<0, 0>>; - // SVDF op. def TFL_SVDFOp : - TFL_Op<"svdf", - [SVDFResultConstraint, TFL_StatefulOp]> { + TFL_Op<"svdf", [ + PredOpTrait<"the input and result tensor elemental types must be same", + TCresVTEtIsSameAsOp<0, 0>>, + TFL_StatefulOp]> { let summary = "Single value decomposition filter operator"; @@ -3960,13 +4063,13 @@ def TFL_SVDFOp : }]; let arguments = ( - ins TFL_TensorOf<[F32, I8]>:$input, + ins TFL_TensorOf<[F32, QI8]>:$input, // Feature Weights. - TFL_TensorOf<[F32, I8]>:$feature_weights, + TFL_TensorOf<[F32, QI8, QUI8]>:$feature_weights, // Time weights - TFL_TensorOf<[F32, I8]>:$time_weights, + TFL_TensorOf<[F32, QI8]>:$time_weights, // Bias TFL_TensorOfOrNone<[F32]>:$input_gate_bias, @@ -3975,11 +4078,11 @@ def TFL_SVDFOp : TFL_StatefulTensor:$activation_state, // Attributes - I32Attr:$rank, + Confined<I32Attr, [IntPositive]>:$rank, TFL_AFAttr:$fused_activation_function ); - let results = (outs TFL_TensorOf<[F32, I8]>:$output); + let results = (outs TFL_TensorOf<[F32, QI8]>:$output); let hasOptions = 1; @@ -3991,7 +4094,10 @@ def TFL_SVDFOp : }]; } -def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> { +def TFL_SegmentSumOp: TFL_Op<"segment_sum", [ + NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>]> { let summary = "SegmentSum operator"; let description = [{ @@ -3999,7 +4105,7 @@ def TFL_SegmentSumOp: TFL_Op<"segment_sum", [NoSideEffect]> { }]; let arguments = (ins - TFL_TensorOf<[F32, I32]>:$data, + TFL_TensorOf<[F32, I32]>:$input, TFL_I32Tensor:$segment_ids ); let results = (outs TFL_TensorOf<[F32, I32]>:$output); @@ -4030,7 +4136,7 @@ def TFL_WhileOp : Op<TFL_Dialect, "while", [ input: A list of input tensors whose types are T. output: A list of output tensors whose types are T. - cond: A region takes 'input' and returns a boolean scalar tensor. + cond: A region that takes 'input' and returns a boolean scalar tensor. body: A region that takes a list of tensors and returns another list of tensors. Both lists have the same types. }]; diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index 6dd44e666fb..a1401323e89 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -121,6 +121,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) { return DT_STRING; case toco::IODataType::BOOL: return DT_BOOL; + case toco::IODataType::COMPLEX64: + return DT_COMPLEX64; default: return DT_INVALID; } @@ -252,7 +254,7 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) { std::string error_message; auto output = mlir::openOutputFile(filename, &error_message); if (!error_message.empty()) { - return errors::InvalidArgument("Failed to open file in %s.", filename); + return errors::InvalidArgument("Failed to open file in ", filename); } mlir::PassManager pm(module.getContext()); pm.addPass(mlir::createPrintOpGraphPass(output->os())); diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 27ccc7d2b22..d4512509f6b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -22,6 +22,7 @@ limitations under the License. #include <unordered_map> #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project @@ -35,6 +36,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" namespace mlir { @@ -363,6 +365,54 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> { } }; +// Fold Extra Requantize ops if the preceding ops has free scale requirement. +template <typename RQ> +struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> { + explicit FoldTrivalRequantizeOp(MLIRContext* context) + : OpRewritePattern<RQ>(context, 1) {} + + LogicalResult matchAndRewrite(RQ op, + PatternRewriter& rewriter) const override { + Value pre_quantized = op.input(); + auto pre_quantized_type = + quant::QuantizedType::getQuantizedElementType(pre_quantized.getType()); + if (!pre_quantized_type) return failure(); + + Operation* def = pre_quantized.getDefiningOp(); + if (!def) return failure(); + if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() || + def->hasTrait<OpTrait::quant::NoQuantizableResult>()) { + return failure(); + } + + op.emitWarning("Remove trivial `rescale` op. Please fix the source graph."); + + llvm::SmallVector<Type, 4> new_output_types; + for (auto result : def->getResults()) { + result.getUsers().begin()->dump(); + op.dump(); + if (result.hasOneUse() && *result.getUsers().begin() == op) { + new_output_types.push_back(op.qtype()); + } else { + new_output_types.push_back(result.getType()); + } + } + + // Remove this rescale op. + rewriter.replaceOp(op, {pre_quantized}); + + // Replace the output scale of the preceding op. + rewriter.setInsertionPointAfter(def); + OperationState new_state(def->getLoc(), def->getName().getStringRef(), + def->getOperands(), new_output_types, + def->getAttrs()); + Operation* new_op = rewriter.createOperation(new_state); + + rewriter.replaceOp(def, new_op->getResults()); + return success(); + } +}; + // Given a quantized type `input`, magnifying its scales by the factor stored in // `factor`. If `input` isn't a quantized type or the `factor` doesn't match the // dimension size of `input` or isn't floating-point, nullptr will be returned. diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir index 1f067aae685..5c69130c939 100644 --- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir @@ -11,9 +11,9 @@ func @reshape_removeAdjacent(tensor<4x4x4xf32>) -> tensor<64xf32> { return %1 : tensor<64xf32> // CHECK-LABEL: func @reshape_removeAdjacent -// CHECK: %cst = constant dense<64> : tensor<1xi32> -// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> -// CHECK: return +// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32> +// CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> +// CHECK: return %[[RESHAPE]] } // Checks that tfl.reshape should be removed if its output has more than one @@ -29,11 +29,11 @@ func @reshape_removeAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> tensor<64xf32> return %3 : tensor<64xf32> // CHECK-LABEL: func @reshape_removeAdjacentWithMultipleUse -// CHECK: %cst = constant dense<64> : tensor<1xi32> -// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> -// CHECK: %1 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> -// CHECK: %2 = addf %0, %1 -// CHECK: return %2 +// CHECK: %[[CST:.*]] = constant dense<64> : tensor<1xi32> +// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> +// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> +// CHECK: %[[RESULT:.*]] = addf %[[RESHAPE_1]], %[[RESHAPE_2]] +// CHECK: return %[[RESULT]] } // Checks that tfl.reshape should be kept if its output has more than one @@ -47,11 +47,11 @@ func @reshape_keepAdjacentWithMultipleUse(tensor<4x4x4xf32>) -> (tensor<16x4xf32 return %0, %1 : tensor<16x4xf32>, tensor<64xf32> // CHECK-LABEL: func @reshape_keepAdjacentWithMultipleUse -// CHECK: %cst = constant dense<[16, 4]> : tensor<2xi32> -// CHECK: %cst_0 = constant dense<64> : tensor<1xi32> -// CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32> -// CHECK: %1 = "tfl.reshape"(%arg0, %cst_0) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> -// CHECK: return %0, %1 +// CHECK: %[[CST:.*]] = constant dense<[16, 4]> : tensor<2xi32> +// CHECK: %[[CST_0:.*]] = constant dense<64> : tensor<1xi32> +// CHECK: %[[RESHAPE_1:.*]] = "tfl.reshape"(%arg0, %[[CST]]) : (tensor<4x4x4xf32>, tensor<2xi32>) -> tensor<16x4xf32> +// CHECK: %[[RESHAPE_2:.*]] = "tfl.reshape"(%arg0, %[[CST_0]]) : (tensor<4x4x4xf32>, tensor<1xi32>) -> tensor<64xf32> +// CHECK: return %[[RESHAPE_1]], %[[RESHAPE_2]] } // Checks that tfl.reshape should be removed if its output type is the same diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index 4b8993e2b26..a8463d51c7e 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -8,13 +8,13 @@ func @add_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, %2 = constant dense< 3.5> : tensor<4xf32> %3 = constant dense<-0.5> : tensor<4xf32> - // CHECK: %cst = constant dense<3.500000e+00> : tensor<4xf32> - // CHECK: %cst_0 = constant dense<-5.000000e-01> : tensor<4xf32> - // CHECK: %cst_1 = constant dense<6.000000e+00> : tensor<f32> - // CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32> - // CHECK: %cst_3 = constant dense<5.000000e+00> : tensor<4xf32> - // CHECK: %cst_4 = constant dense<3.000000e+00> : tensor<4xf32> - // CHECK: %0 = tfl.add %cst, %cst_0 {fused_activation_function = "SIGN_BIT"} : tensor<4xf32> + // CHECK: %[[CST:.*]] = constant dense<3.500000e+00> : tensor<4xf32> + // CHECK: %[[CST_0:.*]] = constant dense<-5.000000e-01> : tensor<4xf32> + // CHECK: %[[CST_1:.*]] = constant dense<6.000000e+00> : tensor<f32> + // CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32> + // CHECK: %[[CST_3:.*]] = constant dense<5.000000e+00> : tensor<4xf32> + // CHECK: %[[CST_4:.*]] = constant dense<3.000000e+00> : tensor<4xf32> + // CHECK: %0 = tfl.add %[[CST]], %[[CST_0]] {fused_activation_function = "SIGN_BIT"} : tensor<4xf32> %5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32> %6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32> @@ -33,10 +33,10 @@ func @add_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) { %2 = constant dense< 4> : tensor<4xi32> %3 = constant dense<-2> : tensor<4xi32> - // CHECK: %cst = constant dense<9> : tensor<i32> - // CHECK: %cst_0 = constant dense<6> : tensor<4xi32> - // CHECK: %cst_1 = constant dense<5> : tensor<4xi32> - // CHECK: %cst_2 = constant dense<2> : tensor<4xi32> + // CHECK: %[[CST:.*]] = constant dense<9> : tensor<i32> + // CHECK: %[[CST_0:.*]] = constant dense<6> : tensor<4xi32> + // CHECK: %[[CST_1:.*]] = constant dense<5> : tensor<4xi32> + // CHECK: %[[CST_2:.*]] = constant dense<2> : tensor<4xi32> %5 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32> %6 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32> @@ -54,10 +54,10 @@ func @sub_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) %2 = constant dense< 3.5> : tensor<4xf32> %3 = constant dense<-0.5> : tensor<4xf32> - // CHECK: %cst = constant dense<3.000000e+00> : tensor<f32> - // CHECK: %cst_0 = constant dense<5.000000e+00> : tensor<4xf32> - // CHECK: %cst_1 = constant dense<2.000000e+00> : tensor<4xf32> - // CHECK: %cst_2 = constant dense<4.000000e+00> : tensor<4xf32> + // CHECK: %[[CST:.*]] = constant dense<3.000000e+00> : tensor<f32> + // CHECK: %[[CST_0:.*]] = constant dense<5.000000e+00> : tensor<4xf32> + // CHECK: %[[CST_1:.*]] = constant dense<2.000000e+00> : tensor<4xf32> + // CHECK: %[[CST_2:.*]] = constant dense<4.000000e+00> : tensor<4xf32> %5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32> %6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32> @@ -75,10 +75,10 @@ func @sub_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) { %2 = constant dense< 4> : tensor<4xi32> %3 = constant dense<-2> : tensor<4xi32> - // CHECK: %cst = constant dense<7> : tensor<i32> - // CHECK: %cst_0 = constant dense<10> : tensor<4xi32> - // CHECK: %cst_1 = constant dense<3> : tensor<4xi32> - // CHECK: %cst_2 = constant dense<6> : tensor<4xi32> + // CHECK: %[[CST:.*]] = constant dense<7> : tensor<i32> + // CHECK: %[[CST_0:.*]] = constant dense<10> : tensor<4xi32> + // CHECK: %[[CST_1:.*]] = constant dense<3> : tensor<4xi32> + // CHECK: %[[CST_2:.*]] = constant dense<6> : tensor<4xi32> %5 = "tfl.sub"(%0, %1) {fused_activation_function = "NONE"} : (tensor< i32>, tensor< i32>) -> tensor< i32> %6 = "tfl.sub"(%0, %3) {fused_activation_function = "NONE"} : (tensor< i32>, tensor<4xi32>) -> tensor<4xi32> @@ -96,10 +96,10 @@ func @mul_float() -> (tensor<f32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) %2 = constant dense< 3.5> : tensor<4xf32> %3 = constant dense<-0.5> : tensor<4xf32> - // CHECK: %cst = constant dense<6.750000e+00> : tensor<f32> - // CHECK: %cst_0 = constant dense<-2.250000e+00> : tensor<4xf32> - // CHECK: %cst_1 = constant dense<5.250000e+00> : tensor<4xf32> - // CHECK: %cst_2 = constant dense<-1.750000e+00> : tensor<4xf32> + // CHECK: %[[CST:.*]] = constant dense<6.750000e+00> : tensor<f32> + // CHECK: %[[CST_0:.*]] = constant dense<-2.250000e+00> : tensor<4xf32> + // CHECK: %[[CST_1:.*]] = constant dense<5.250000e+00> : tensor<4xf32> + // CHECK: %[[CST_2:.*]] = constant dense<-1.750000e+00> : tensor<4xf32> %5 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor< f32>, tensor< f32>) -> tensor< f32> %6 = "tfl.mul"(%0, %3) {fused_activation_function = "NONE"} : (tensor< f32>, tensor<4xf32>) -> tensor<4xf32> @@ -170,8 +170,8 @@ func @add_dense_splat_int() -> tensor<4xi32> { return %2 : tensor<4xi32> -// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_splat_dense_int @@ -183,8 +183,8 @@ func @add_splat_dense_int() -> tensor<4xi32> { return %2 : tensor<4xi32> -// CHECK: %cst = constant dense<[-5, 4, 47, 105]> : tensor<4xi32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[-5, 4, 47, 105]> : tensor<4xi32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_dense_int_same_shape @@ -196,8 +196,8 @@ func @add_dense_dense_int_same_shape() -> tensor<4xi32> { return %2 : tensor<4xi32> -// CHECK: %cst = constant dense<[5, 22, -2, 98]> : tensor<4xi32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[5, 22, -2, 98]> : tensor<4xi32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_dense_int_trailing_dim @@ -212,10 +212,10 @@ func @add_dense_dense_int_trailing_dim() -> (tensor<2x2xi32>, tensor<2x2x2xi32>, return %0, %1, %2 : tensor<2x2xi32>, tensor<2x2x2xi32>, tensor<2x2x2xi32> -// CHECK: %cst = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32> -// CHECK: %cst_0 = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32> -// CHECK: %cst_1 = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32> -// CHECK: return %cst, %cst_0, %cst_1 +// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}11, 22], [13, 24]]> : tensor<2x2xi32> +// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}2, 3], [5, 6]], {{\[\[}}4, 5], [7, 8]]]> : tensor<2x2x2xi32> +// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}11, 21], [12, 22]], {{\[\[}}13, 23], [14, 24]]]> : tensor<2x2x2xi32> +// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]] } // CHECK-LABEL: @add_dense_dense_int_mixing_1_n @@ -226,8 +226,8 @@ func @add_dense_dense_int_mixing_1_n() -> tensor<2x2xi32> { %0 = "tfl.add"(%cst_0, %cst_1) {fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> -// CHECK: %cst = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}4, 5], [5, 6]]> : tensor<2x2xi32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_splat_float @@ -239,8 +239,8 @@ func @add_dense_splat_float() -> tensor<4xf32> { return %2 : tensor<4xf32> -// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_splat_dense_float @@ -252,8 +252,8 @@ func @add_splat_dense_float() -> tensor<4xf32> { return %2 : tensor<4xf32> -// CHECK: %cst = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[-6.500000e+00, 2.000000e+00, 4.550000e+01, 1.075000e+01]> : tensor<4xf32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_dense_float_same_shape @@ -265,8 +265,8 @@ func @add_dense_dense_float_same_shape() -> (tensor<4xf32>) { return %2 : tensor<4xf32> -// CHECK: %cst = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[-8.89999961, 1.000000e+00, 3.800000e+01, 9.800000e+01]> : tensor<4xf32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_dense_float_trailing_dim @@ -281,10 +281,10 @@ func @add_dense_dense_float_trailing_dim() -> (tensor<2x2xf32>, tensor<2x2x2xf32 return %0, %1, %2 : tensor<2x2xf32>, tensor<2x2x2xf32>, tensor<2x2x2xf32> -// CHECK: %cst = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32> -// CHECK: %cst_0 = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32> -// CHECK: %cst_1 = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32> -// CHECK: return %cst, %cst_0, %cst_1 +// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-4.500000e+00, -2.500000e+00], [8.500000e+00, -8.500000e+00]]> : tensor<2x2xf32> +// CHECK: %[[CST_0:.*]] = constant dense<{{\[\[\[}}-4.500000e+00, 2.500000e+00], [9.500000e+00, -2.500000e+00]], {{\[\[}}-2.500000e+00, 4.500000e+00], [1.150000e+01, -5.000000e-01]]]> : tensor<2x2x2xf32> +// CHECK: %[[CST_1:.*]] = constant dense<{{\[\[\[}}2.000000e+00, -3.000000e+00], [3.000000e+00, -2.000000e+00]], {{\[\[}}4.000000e+00, -1.000000e+00], [5.000000e+00, 0.000000e+00]]]> : tensor<2x2x2xf32> +// CHECK: return %[[CST]], %[[CST_0]], %[[CST_1]] } // CHECK-LABEL: @add_dense_dense_float_mixfng_1_n @@ -296,24 +296,24 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> { return %0 : tensor<2x2xf32> -// CHECK: %cst = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-1.500000e+00, -5.500000e+00], [5.500000e+00, 1.500000e+00]]> : tensor<2x2xf32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @rank func @rank() -> tensor<1xi32> { %cst = constant dense<[[1], [2]]> : tensor<2x1xi32> - // CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32> + // CHECK: return %[[CST]] %0 = "tfl.rank"(%cst) : (tensor<2x1xi32>) -> tensor<1xi32> return %0 : tensor<1xi32> } // CHECK-LABEL: @rank_input_known_rank func @rank_input_known_rank(%arg0 : tensor<2x1xi32>) -> tensor<1xi32> { - // CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<2> : tensor<1xi32> + // CHECK: return %[[CST]] %0 = "tfl.rank"(%arg0) : (tensor<2x1xi32>) -> tensor<1xi32> return %0 : tensor<1xi32> } @@ -323,8 +323,8 @@ func @reshape() -> tensor<4xi32> { %input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> %shape = constant dense<[4]> : tensor<1xi32> - // CHECK: [[cst:%.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<[1, 2, 3, 4]> : tensor<4xi32> + // CHECK: return %[[CST]] %0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<4xi32> return %0 : tensor<4xi32> } @@ -334,8 +334,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> { %input = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> %shape = constant dense<[4]> : tensor<1xi32> - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> tensor<?xi32> + // CHECK: return %[[CST]] %0 = "tfl.reshape"(%input, %shape) : (tensor<2x2xi32>, tensor<1xi32>) -> tensor<?xi32> return %0 : tensor<?xi32> } @@ -343,8 +343,8 @@ func @reshape_dynamic_output() -> tensor<?xi32> { // CHECK-LABEL: @pseudo_const func @pseudo_const() -> tensor<i32> { - // CHECK: [[cst:%.*]] = constant dense<1> : tensor<i32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<1> : tensor<i32> + // CHECK: return %[[CST]] %0 = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> return %0 : tensor<i32> } @@ -356,8 +356,8 @@ func @range_int() -> tensor<?xi32> { %cst_1 = constant dense<4> : tensor<i32> %cst_2 = constant dense<1> : tensor<i32> - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<?xi32> + // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32> return %0 : tensor<?xi32> } @@ -368,8 +368,8 @@ func @range_float() -> tensor<?xf32> { %cst_1 = constant dense<4.0> : tensor<f32> %cst_2 = constant dense<1.0> : tensor<f32> - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32> + // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32> return %0 : tensor<?xf32> } @@ -381,8 +381,8 @@ func @range_float_neg_delta() -> tensor<?xf32> { %cst_1 = constant dense<-4.0> : tensor<f32> %cst_2 = constant dense<-1.0> : tensor<f32> - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor<?xf32> + // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32> return %0 : tensor<?xf32> } @@ -393,8 +393,8 @@ func @range_float_nonzero_base() -> tensor<?xf32> { %cst_1 = constant dense<7.0> : tensor<f32> %cst_2 = constant dense<1.5> : tensor<f32> - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor<?xf32> + // CHECK: return %[[CST]] %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32> return %0 : tensor<?xf32> } @@ -414,8 +414,8 @@ func @transpose_1d() -> tensor<3xi32> { %cst = constant dense<[1, 2, 3]> : tensor<3xi32> %cst_perm = constant dense<0> : tensor<1xi32> - // CHECK: [[cst:%.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32> + // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32> return %0 : tensor<3xi32> } @@ -425,8 +425,8 @@ func @transpose_dynamic() -> tensor<?xi32> { %cst = constant dense<[1, 2, 3]> : tensor<3xi32> %cst_perm = constant dense<0> : tensor<1xi32> - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor<?xi32> + // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<?xi32> return %0 : tensor<?xi32> } @@ -436,8 +436,8 @@ func @transpose_2d() -> tensor<2x2xi32> { %cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> %cst_perm = constant dense<[1, 0]> : tensor<2xi32> - // CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32> + // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -447,8 +447,8 @@ func @transpose_2d_identity() -> tensor<2x2xi32> { %cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> %cst_perm = constant dense<[0, 1]> : tensor<2xi32> - // CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32> + // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -460,8 +460,8 @@ func @transpose_3d() -> tensor<4x2x3xi32> { %cst = constant dense<[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]> : tensor<2x3x4xi32> %cst_perm = constant dense<[2, 0, 1]> : tensor<3xi32> - // CHECK: [[cst:%.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32> + // CHECK: return %[[CST]] %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x2x3xi32> return %0 : tensor<4x2x3xi32> } @@ -473,8 +473,8 @@ func @ConstantFoldBinaryOpDynamicOutput() -> tensor<?xi32> { %87 = "tfl.sub"(%cst_0, %cst) {fused_activation_function = "NONE"} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32> return %87 : tensor<?xi32> - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor<?xi32> + // CHECK: return %[[CST]] } // CHECK-LABEL: @add_dense_dense_int_same_shape_dynamic @@ -486,8 +486,8 @@ func @add_dense_dense_int_same_shape_dynamic() -> tensor<?xi32> { return %2 : tensor<?xi32> - // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32> - // CHECK: return [[cst]] + // CHECK: %[[CST:.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor<?xi32> + // CHECK: return %[[CST]] } // CHECK-LABEL: @concat_2_tensors_1_empty @@ -497,8 +497,8 @@ func @concat_2_tensors_1_empty() -> tensor<2xi32> { %3 = "tfl.concatenation"(%1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<0xi32>) -> tensor<2xi32> return %3 : tensor<2xi32> - // CHECK: [[cst:%.*]] = constant dense<1> : tensor<2xi32> - // CHECK: return [[cst]] : tensor<2xi32> + // CHECK: %[[CST:.*]] = constant dense<1> : tensor<2xi32> + // CHECK: return %[[CST]] : tensor<2xi32> } // CHECK-LABEL: @concat_3_tensors_1_empty @@ -509,7 +509,7 @@ func @concat_3_tensors_1_empty() -> tensor<?xi32> { %3 = "tfl.concatenation"(%0, %1, %2) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>, tensor<0xi32>) -> tensor<?xi32> return %3 : tensor<?xi32> - // CHECK: %0 = "tfl.concatenation"(%cst, %cst) {axis = 0 : i32, fused_activation_function = "NONE"} + // CHECK: %0 = "tfl.concatenation"(%[[CST]], %[[CST]]) {axis = 0 : i32, fused_activation_function = "NONE"} // CHECK: return %0 : tensor<?xi32> } @@ -520,10 +520,10 @@ func @concatConstantTensorsFirstDim() -> tensor<2x2x3xi32> { %0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<2x2x3xi32> return %0 : tensor<2x2x3xi32> - // CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32> + // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0]], {{\[}}{{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<2x2x3xi32> // CHECK-NOT: constant-dense // CHECK-NOT: "tfl.concatenation" - // CHECK: return [[cst]] + // CHECK: return %[[CST]] } // CHECK-LABEL: @concatConstantTensorsMiddleDim @@ -533,10 +533,10 @@ func @concatConstantTensorsMiddleDim() -> tensor<1x4x3xi32> { %0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x4x3xi32> return %0 : tensor<1x4x3xi32> - // CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32> + // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0], {{\[}}0, 0, 0], {{\[}}1, 1, 1], {{\[}}1, 1, 1]]]> : tensor<1x4x3xi32> // CHECK-NOT: constant-dense // CHECK-NOT: "tfl.concatenation" - // CHECK: return [[cst]] + // CHECK: return %[[CST]] } // CHECK-LABEL: @concatConstantTensorsLastDim @@ -546,10 +546,10 @@ func @concatConstantTensorsLastDim() -> tensor<1x2x6xi32> { %0 = "tfl.concatenation"(%cst_0, %cst_1) {axis = 2 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xi32>, tensor<1x2x3xi32>) -> tensor<1x2x6xi32> return %0 : tensor<1x2x6xi32> - // CHECK: [[cst:%.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32> + // CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}0, 0, 0, 1, 1, 1], {{\[}}0, 0, 0, 1, 1, 1]]]> : tensor<1x2x6xi32> // CHECK-NOT: constant-dense // CHECK-NOT: "tfl.concatenation" - // CHECK: return [[cst]] + // CHECK: return %[[CST]] } // CHECK-LABEL: @div_dense_dense_float_mixfng_1_n @@ -561,8 +561,8 @@ func @div_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> { return %0 : tensor<2x2xf32> -// CHECK: %cst = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<{{\[\[}}-5.000000e-01, 0.833333313], [3.750000e-01, -6.250000e-01]]> : tensor<2x2xf32> +// CHECK: return %[[CST]] } // CHECK-LABEL: @div_dense_different_rank @@ -574,6 +574,6 @@ func @div_dense_different_rank() -> tensor<1x2x2xf32> { return %0 : tensor<1x2x2xf32> -// CHECK: %cst = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32> -// CHECK: return %cst +// CHECK: %[[CST:.*]] = constant dense<[{{\[}}{{\[}}5.000000e-01, 0.333333343], [1.000000e+00, 0.666666686]]]> : tensor<1x2x2xf32> +// CHECK: return %[[CST]] } diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt new file mode 100644 index 00000000000..096033e37cb --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/end2end/unroll_batch_matmul.pbtxt @@ -0,0 +1,101 @@ +# RUN: tf_tfl_translate -tf-input-arrays=Placeholder,Placeholder_1 -tf-input-shapes=2,5,3:3,7 -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-output-arrays=MatMul -output-mlir %s -o - 2>&1 | FileCheck %s + +node { + name: "Placeholder" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 2 + } + dim { + size: 5 + } + dim { + size: 3 + } + } + } + } +} +node { + name: "Placeholder_1" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 3 + } + dim { + size: 7 + } + } + } + } +} +node { + name: "MatMul" + op: "BatchMatMulV2" + input: "Placeholder" + input: "Placeholder_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "adj_x" + value { + b: false + } + } + attr { + key: "adj_y" + value { + b: false + } + } +} +versions { + producer: 175 +} + +# CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} { +# CHECK: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32> +# CHECK: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32> +# CHECK: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32> +# CHECK: %[[VAL_5:.*]] = constant unit +# CHECK: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32> +# CHECK: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32> +# CHECK: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32> +# CHECK: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32> +# CHECK: %[[VAL_10:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_8]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32> +# CHECK: %[[VAL_11:.*]] = "tfl.reshape"(%[[VAL_10]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32> +# CHECK: %[[VAL_12:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_6]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32> +# CHECK: %[[VAL_13:.*]] = "tfl.reshape"(%[[VAL_12]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32> +# CHECK: %[[VAL_14:.*]] = "tfl.reshape"(%[[VAL_1]], %[[VAL_9]]) : (tensor<3x7xf32>, tensor<3xi32>) -> tensor<1x3x7xf32> +# CHECK: %[[VAL_15:.*]] = "tfl.slice"(%[[VAL_14]], %[[VAL_8]], %[[VAL_9]]) : (tensor<1x3x7xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x3x7xf32> +# CHECK: %[[VAL_16:.*]] = "tfl.reshape"(%[[VAL_15]], %[[VAL_4]]) : (tensor<1x3x7xf32>, tensor<2xi32>) -> tensor<3x7xf32> +# CHECK: %[[VAL_17:.*]] = "tfl.transpose"(%[[VAL_16]], %[[VAL_2]]) : (tensor<3x7xf32>, tensor<2xi32>) -> tensor<7x3xf32> +# CHECK: %[[VAL_18:.*]] = "tfl.fully_connected"(%[[VAL_11]], %[[VAL_17]], %[[VAL_5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> +# CHECK: %[[VAL_19:.*]] = "tfl.fully_connected"(%[[VAL_13]], %[[VAL_17]], %[[VAL_5]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<5x3xf32>, tensor<7x3xf32>, none) -> tensor<5x7xf32> +# CHECK: %[[VAL_20:.*]] = "tfl.pack"(%[[VAL_18]], %[[VAL_19]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<2x5x7xf32> +# CHECK: return %[[VAL_20]] : tensor<2x5x7xf32> +# CHECK: } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir index 0dd8ddc4c91..d793ea2d62f 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir @@ -1,15 +1,15 @@ // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s // Ensure lstm roundtrip exactly -func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> { +func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> { %cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") %cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") - %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %24 : tensor<1x4xf32> // CHECK-LABEL: main // seperate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252 // CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( { -// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> +// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> // CHECK: return %[[RES0]] } diff --git a/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir new file mode 100644 index 00000000000..f08ac0e1027 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir @@ -0,0 +1,14 @@ +// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s -split-input-file | FileCheck %s --dump-input-on-failure +module { + + func @_whitespace_func(%arg0: tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._GrapplerSpecializedFunc = true, tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} { + %0 = "tf.op1"(%arg0) : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>) + %1 = "tf.Const"() {value = dense<-1> : tensor<i64>} : () -> tensor<?xi64> + %2:2 = "tf.op2"(%arg0, %1) : (tensor<1x!tf.string>, tensor<?xi64>) -> (tensor<?x!tf.string>, tensor<?xi64>) + return %2#0, %2#1 : tensor<?x!tf.string>, tensor<?xi64> + } + + // CHECK: func @_whitespace_func(%arg0: tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._GrapplerSpecializedFunc = true, tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} { + // CHECK: "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>) + // CHECK: return %0#0, %0#1 : tensor<?x!tf.string>, tensor<?xi64> +} diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 15b6bf56b7a..15c73d2db2c 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1048,6 +1048,15 @@ func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2 // CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32> } +func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { + %0 = "tf.Const"() { value = dense<-1> : tensor<i64> } : () -> tensor<i64> + %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i64>) -> tensor<2x3xi32> + return %1 : tensor<2x3xi32> + +// CHECK-LABEL: concatv2I64Axis +// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32> +} + func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> { %0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32> return %0 : tensor<?xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir index 1b46fa3d0e5..320f869ac4c 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/custom_op_with_tflite_op.mlir @@ -65,7 +65,7 @@ func @main(tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: opcode_index: 1, // CHECK-NEXT: inputs: [ 2, 1 ], // CHECK-NEXT: outputs: [ 3 ], -// CHECK-NEXT: custom_options: [ 105, 110, 116, 95, 97, 116, 116, 114, 0, 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 2, 33, 43, 2, 1, 2, 11, 2, 20, 4, 4, 36, 1 ] +// CHECK-NEXT: custom_options: [ 102, 117, 115, 101, 100, 95, 97, 99, 116, 105, 118, 97, 116, 105, 111, 110, 95, 102, 117, 110, 99, 116, 105, 111, 110, 0, 4, 82, 69, 76, 85, 0, 105, 110, 116, 95, 97, 116, 116, 114, 0, 2, 42, 11, 2, 1, 2, 20, 2, 20, 4, 4, 36, 1 ] // CHECK-NEXT: }, { // CHECK-NEXT: opcode_index: 2, // CHECK-NEXT: inputs: [ 3 ], diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir index e278572cd1e..ef78f993cc4 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir @@ -1,6 +1,6 @@ // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s -func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> { +func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { @@ -72,21 +72,21 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4, 4 ], +// CHECK-NEXT: shape: [ 4 ], // CHECK-NEXT: buffer: 10, // CHECK-NEXT: name: "arg9", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4, 4 ], +// CHECK-NEXT: shape: [ 4 ], // CHECK-NEXT: buffer: 11, // CHECK-NEXT: name: "arg10", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4, 4 ], +// CHECK-NEXT: shape: [ 4 ], // CHECK-NEXT: buffer: 12, // CHECK-NEXT: name: "arg11", // CHECK-NEXT: quantization: { @@ -100,21 +100,21 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 1, 4 ], +// CHECK-NEXT: shape: [ 4 ], // CHECK-NEXT: buffer: 14, // CHECK-NEXT: name: "arg13", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 1, 4 ], +// CHECK-NEXT: shape: [ 4 ], // CHECK-NEXT: buffer: 15, // CHECK-NEXT: name: "arg14", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 1, 4 ], +// CHECK-NEXT: shape: [ 4 ], // CHECK-NEXT: buffer: 16, // CHECK-NEXT: name: "arg15", // CHECK-NEXT: quantization: { @@ -128,7 +128,7 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 1, 4 ], +// CHECK-NEXT: shape: [ 4 ], // CHECK-NEXT: buffer: 18, // CHECK-NEXT: name: "arg17", // CHECK-NEXT: quantization: { @@ -261,9 +261,9 @@ func @main(tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, t // CHECK-EMPTY: -^bb0(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>): +^bb0(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>): %cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") %cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") - %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %24 : tensor<1x4xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir index 8e579421b0b..d9bba58b7d7 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir @@ -1,6 +1,6 @@ // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s -func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> { +func @main(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> { // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { @@ -9,63 +9,63 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t // CHECK-NEXT: } ], // CHECK-NEXT: subgraphs: [ { // CHECK-NEXT: tensors: [ { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 1, // CHECK-NEXT: name: "arg0", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 2, // CHECK-NEXT: name: "arg1", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 3, // CHECK-NEXT: name: "arg2", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 4, // CHECK-NEXT: name: "arg3", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 5, // CHECK-NEXT: name: "arg4", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 6, // CHECK-NEXT: name: "arg5", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 7, // CHECK-NEXT: name: "arg6", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 8, // CHECK-NEXT: name: "arg7", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 9, // CHECK-NEXT: name: "arg8", // CHECK-NEXT: quantization: { @@ -121,63 +121,63 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 17, // CHECK-NEXT: name: "arg16", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 18, // CHECK-NEXT: name: "arg17", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 19, // CHECK-NEXT: name: "arg18", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 20, // CHECK-NEXT: name: "arg19", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 21, // CHECK-NEXT: name: "arg20", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 22, // CHECK-NEXT: name: "arg21", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: name: "Const", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: }, // CHECK-NEXT: is_variable: true // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: name: "Const1", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: }, // CHECK-NEXT: is_variable: true // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: buffer: 25, // CHECK-NEXT: name: "tfl.unidirectional_sequence_lstm", // CHECK-NEXT: quantization: { @@ -244,9 +244,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t // CHECK-NEXT: }, { // CHECK-EMPTY: // CHECK-NEXT: }, { -// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: }, { -// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: }, { // CHECK-EMPTY: // CHECK-NEXT: }, { @@ -259,9 +259,9 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, t // CHECK-NEXT: } // CHECK-EMPTY: -^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>): - %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") - %1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") - %2 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %0, %1, %arg18, %arg19, %arg20, %arg21) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %2 : tensor<4xf32> +^bb0(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4x4xf32>, %arg18: tensor<4x4xf32>, %arg19: tensor<4x4xf32>, %arg20: tensor<4x4xf32>, %arg21: tensor<4x4xf32>): + %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const") + %1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const") + %2 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %0, %1, %arg18, %arg19, %arg20, %arg21) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %2 : tensor<4x4xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir index 7ba24bd5c51..f2b99bcd0df 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir @@ -37,7 +37,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK-EMPTY: // CHECK-NEXT: } // CHECK-NEXT: }, { -// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: shape: [ 4, 4 ], // CHECK-NEXT: name: "Const", // CHECK-NEXT: quantization: { // CHECK-EMPTY: @@ -76,7 +76,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK-NEXT: }, { // CHECK-EMPTY: // CHECK-NEXT: }, { -// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] // CHECK-NEXT: }, { // CHECK-EMPTY: // CHECK-NEXT: }, { @@ -90,7 +90,7 @@ func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) - // CHECK-EMPTY: ^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>): - %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") - %1 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %0) {fused_activation_function = "TANH", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const") + %1 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %0) {fused_activation_function = "TANH", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index f42e06350e5..3451f28380b 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -190,9 +190,9 @@ func @testSquare(tensor<? x f32>) -> tensor<? x f32> { return %0 : tensor<? x f32> } -func @testQuantizedResizeNearestNeighbor(tensor<? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> { -^bb0(%arg0: tensor<? x !quant.uniform<u8:f32, 0.1>>, %arg1: tensor<? x i32>): - %0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false, half_pixel_centers = false } : (tensor<? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> +func @testQuantizedResizeNearestNeighbor(tensor<? x ? x ? x ? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> { +^bb0(%arg0: tensor<? x ? x ? x ? x !quant.uniform<u8:f32, 0.1>>, %arg1: tensor<? x i32>): + %0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false, half_pixel_centers = false } : (tensor<? x ? x ? x ? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> return %0 : tensor<? x !quant.uniform<u8:f32, 0.1>> } @@ -581,36 +581,36 @@ func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> { // ----- // CHECK-LABEL: testUnidirectionalSequenceRnn -func @testUnidirectionalSequenceRnn(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>) -> tensor<? x f32> { - // CHECK: "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> - %0 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> +func @testUnidirectionalSequenceRnn(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x ? x f32>) -> tensor<? x f32> { + // CHECK: "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>) -> tensor<?xf32> + %0 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> } // ----- // CHECK-LABEL: testUnidirectionalSequenceLstmWithoutProjection -func @testUnidirectionalSequenceLstmWithoutProjection(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: none, %arg17: none, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> { - // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> - %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> +func @testUnidirectionalSequenceLstmWithoutProjection(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: none, %arg17: none, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> { + // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> + %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> } // ----- // CHECK-LABEL: testUnidirectionalSequenceLstm -func @testUnidirectionalSequenceLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> { - // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> - %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> +func @testUnidirectionalSequenceLstm(%arg0: tensor<? x ? x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x ? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> { + // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> + %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> } // ----- // CHECK-LABEL: testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr -func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> { - // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> - %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> +func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x ? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x ? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> { + // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> + %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?x?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> } @@ -663,10 +663,10 @@ func @testLstmQuantizedType(%arg0: tensor<1x528x!quant.uniform<i8:f32, 0.0372480 // ----- // CHECK-LABEL: testLstm -func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> { +func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> { // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) - // CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> - %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> + // CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> + %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> } @@ -689,10 +689,10 @@ func @testQuantizedBasicLstm(%arg0: tensor<1x384x!quant.uniform<u8:f32, 7.812500 // ----- // CHECK-LABEL: testLstmWithNoneTypeAndOverrideAttr -func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> { +func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> { // CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) - // CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> - %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> + // CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> + %0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> } @@ -707,11 +707,11 @@ func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32> // ----- -// test invalid input dimension, the first input operand for lstm op should be at least 2D tensor. +// test invalid input dimension, the third input operand for lstm op should be 2-D tensor. func @testLstmWithInvalidInputDimension(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg18: tensor<4 x f32>, %arg19: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>) -> tensor<4 x f32> { %cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") %cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") - // expected-error @+1 {{'tfl.lstm' op the first input operand should have more than 2 dimensions.}} + // expected-error @+1 {{'tfl.lstm' op failed to verify that operand 2 is 2-D}} %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %24 : tensor<4xf32> @@ -720,22 +720,22 @@ func @testLstmWithInvalidInputDimension(%arg0: tensor<4 x f32>, %arg1: tensor<4 // ----- // 'input_to_output_weights' input for lstm op has unmatched rank with `input`. -func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x2xf32>, %arg2: tensor<4x2xf32>, %arg3: tensor<4x2xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> { +func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x2xf32>, %arg2: tensor<4x2xf32>, %arg3: tensor<4x2xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<4xf32>, %arg19: tensor<4xf32>, %arg20: tensor<4xf32>, %arg21: tensor<4xf32>) -> tensor<1x4xf32> { %cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") %cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") // expected-error @+1 {{'tfl.lstm' op inputs don't match with the dimensions.}} - %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %24 : tensor<1x4xf32> } // ----- // Coefficient inputs of LSTM op don't match the dimension with input operand `input_to_output_weights`. -func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4x4xf32>, %arg10: tensor<4x4xf32>, %arg11: tensor<4x4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<1x4xf32>, %arg14: tensor<1x4xf32>, %arg15: tensor<1x4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<1x4xf32>, %arg18: tensor<3xf32>, %arg19: tensor<3xf32>, %arg20: tensor<3xf32>, %arg21: tensor<3xf32>) -> tensor<1x4xf32> { +func @testLstmWithInvalidInputsRankMatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>, %arg3: tensor<4x4xf32>, %arg4: tensor<4x4xf32>, %arg5: tensor<4x4xf32>, %arg6: tensor<4x4xf32>, %arg7: tensor<4x4xf32>, %arg8: tensor<4x4xf32>, %arg9: tensor<4xf32>, %arg10: tensor<4xf32>, %arg11: tensor<4xf32>, %arg12: tensor<1x4xf32>, %arg13: tensor<4xf32>, %arg14: tensor<4xf32>, %arg15: tensor<4xf32>, %arg16: tensor<4x4xf32>, %arg17: tensor<4xf32>, %arg18: tensor<3xf32>, %arg19: tensor<3xf32>, %arg20: tensor<3xf32>, %arg21: tensor<3xf32>) -> tensor<1x4xf32> { %cst0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") %cst1 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<1x4xf32>} : () -> tensor<1x4xf32> loc("Const") // expected-error @+1 {{'tfl.lstm' op coefficient inputs have more than 2 dimensions or don't match the dimension with input operand `input_to_output_weights`.}} - %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x4xf32> + %24 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %cst0, %cst1, %arg18, %arg19, %arg20, %arg21) ({}) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<1x4xf32> return %24 : tensor<1x4xf32> } @@ -1201,7 +1201,7 @@ func @testResizeBilinear(%arg0 : tensor<1x100x100x3xf32>, %arg1 : tensor<4xi32>) // ----- func @testResizeBilinearInvalidOutputType(%arg0 : tensor<1x100x100x3xf32>, %arg1 : tensor<4xi32>) -> tensor<?xi32> { - // expected-error @+1 {{'tfl.resize_bilinear' op result #0 must be tensor of 32-bit float or QI8 type or QUI8 type values}} + // expected-error @+1 {{'tfl.resize_bilinear' op failed to verify that input and output must have same element type}} %0 = "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xi32> return %0 : tensor<?xi32> } @@ -1499,8 +1499,8 @@ func @testWrongQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!q // CHECK-LABEL: testSvdf func @testSvdf(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>) -> tensor<? x f32> { - // CHECK: "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", rank = 2 : i32} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> - %0 = "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", rank = 2 : i32} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> + // CHECK: "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "RELU", rank = 2 : i32} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> + %0 = "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "RELU", rank = 2 : i32} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> return %0 : tensor<?xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 2815afd14b9..3f8257b54f0 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -439,6 +439,19 @@ func @NotReorderReshapeAddIfNotTailingDim(%arg0: tensor<40x40x1xf32>) -> tensor< // CHECK: return %[[rs2]] } +// CHECK-LABEL: @NotReorderReshapeAddIfHighDim +func @NotReorderReshapeAddIfHighDim(%arg0: tensor<1x1x1x1x30x96xf32>) -> tensor<1x30x96xf32> { + %cst = constant dense<2.0> : tensor<f32> + %shape = constant dense<[1, 30, 96]> : tensor<3xi32> + %1 = "tfl.reshape"(%arg0, %shape) : (tensor<1x1x1x1x30x96xf32>, tensor<3xi32>) -> tensor<1x30x96xf32> + %2 = "tfl.add"(%1, %cst) {fused_activation_function = "NONE"} : (tensor<1x30x96xf32>, tensor<f32>) -> tensor<1x30x96xf32> + return %2 : tensor<1x30x96xf32> + + // CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0 + // CHECK: %[[rs2:.*]] = "tfl.add"(%[[rs1]] + // CHECK: return %[[rs2]] +} + // CHECK-LABEL: @ReorderElementwiseValueOpAndMoveOp func @ReorderElementwiseValueOpAndMoveOp(%arg0: tensor<40x40x1xf32>) -> tensor<40x40xf32> { %shape = constant dense<[40, 40]> : tensor<2xi32> diff --git a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir index 5377c4fdb98..6573a2f1c36 100644 --- a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir @@ -19,6 +19,16 @@ func @RemoveUnused(%arg0: tensor<4xf32>, %arg1: tensor<i32>) -> (tensor<2xf32>,t // CHECK-NEXT: return %[[split]]#0, %[[split]]#1 } +// CHECK-LABEL: RemoveTrival +func @RemoveTrival(%arg0: tensor<384x512x!quant.uniform<i8:f32, 1.0:-128>>, %arg1: tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, %arg2: none) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>> { + %1 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<384x512x!quant.uniform<i8:f32, 1.0:-128>>, tensor<128x512x!quant.uniform<i8<-127:127>:f32, 1.0>>, none) -> tensor<384x128x!quant.uniform<i8:f32, 1.0>> + %2 = "tfl.quantize"(%1) {qtype = tensor<384x128x!quant.uniform<i8:f32, 2.0>>} : (tensor<384x128x!quant.uniform<i8:f32, 1.0>>) -> tensor<384x128x!quant.uniform<i8:f32, 2.0>> + return %2 : tensor<384x128x!quant.uniform<i8:f32, 2.0>> + +// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"{{.*}} -> tensor<384x128x!quant.uniform<i8:f32, 2.000000e+00>> +// CHECK-NEXT: return %[[fc]] +} + func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> { %cst = constant dense<[1, 1001]> : tensor<2xi32> %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform<u8:f32, 7.812500e-03:128>> diff --git a/tensorflow/compiler/mlir/lite/tests/split-merged-operands.mlir b/tensorflow/compiler/mlir/lite/tests/split-merged-operands.mlir index d2d0e43e0e9..c5c9ee645f4 100644 --- a/tensorflow/compiler/mlir/lite/tests/split-merged-operands.mlir +++ b/tensorflow/compiler/mlir/lite/tests/split-merged-operands.mlir @@ -1,27 +1,27 @@ // RUN: tf-opt -tfl-split-merged-operands %s | FileCheck %s -func @testSingleLstm(%arg0: tensor<4 x f32>) -> tensor<4xf32> { +func @testSingleLstm(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> { // CHECK-LABEL: testSingleLstm - // CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32> - // CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32> - // CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + // CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + // CHECK: %[[LSTM:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") - %1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %1 : tensor<4xf32> + %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const") + %1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> } -func @testMultipleLstms(%arg0: tensor<4 x f32>) -> tensor<4xf32> { +func @testMultipleLstms(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> { // CHECK-LABEL: testMultipleLstms - // CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32> - // CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32> - // CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - // CHECK: %[[CST_2:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32> - // CHECK: %[[CST_3:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32> - // CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: %[[CST_0:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + // CHECK: %[[CST_1:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + // CHECK: %[[LSTM_1:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_0]], %[[CST_1]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK: %[[CST_2:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + // CHECK: %[[CST_3:.*]] = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + // CHECK: %[[LSTM_2:[a-z0-9]*]] = "tfl.unidirectional_sequence_lstm"(%[[LSTM_1]], %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %[[CST_2]], %[[CST_3]], %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> - %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") - %1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %2 = "tfl.unidirectional_sequence_lstm"(%1, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %2 : tensor<4xf32> + %0 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> loc("Const") + %1 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + %2 = "tfl.unidirectional_sequence_lstm"(%1, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg1, %arg0, %arg0, %0, %0, %arg0, %arg0, %arg0, %arg0) {fused_activation_function = "NONE", time_major = true} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> + return %2 : tensor<4x4xf32> } diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index d3f1a430642..40420eee697 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -162,6 +162,10 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addPass( mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul)); pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); + if (pass_config.shape_inference) { + // Add a shape inference pass to optimize away the unnecessary casts. + pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); + } pass_manager->addPass( mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification)); pass_manager->addPass(mlir::TFL::CreateOptimizePass()); diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index ab4c4f5c4cf..bfcbc190638 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -37,6 +37,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" @@ -202,6 +203,26 @@ LogicalResult ConvertTFConcatOp::matchAndRewrite( return success(); } +// Converts any IntegerAttr to an IntegerAttr of an i32 type. +// The value won't change in the new attribute, but if the value is out of +// the bound of i32, the function returns a failure. +LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) { + if (attr.getType().isInteger(/*width=*/32)) { + *attr_i32 = attr; + return success(); + } + + int64_t value = attr.getInt(); + if (value > std::numeric_limits<int>::max() || + value < std::numeric_limits<int>::min()) { + return failure(); + } + + *attr_i32 = IntegerAttr::get( + IntegerType::get(/*width=*/32, attr.getContext()), value); + return success(); +} + LogicalResult ConvertTFConcatV2Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_concat_op = cast<TF::ConcatV2Op>(op); @@ -211,12 +232,16 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite( // Extract axis attribute from constant axis tensor ElementsAttr axis; if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure(); + IntegerAttr axis_int = ExtractSingleElementAsInteger(axis); + + // "axis" operand could be a i64 tensor. Resolve it here. + IntegerAttr axis_i32; + if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure(); StringAttr fused_activation_function = StringAttr::get("NONE", rewriter.getContext()); rewriter.replaceOpWithNewOp<ConcatenationOp>( - op, output_type, values, ExtractSingleElementAsInteger(axis), - fused_activation_function); + op, output_type, values, axis_i32, fused_activation_function); return success(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 49be29065fe..45b8c9e5fb2 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -838,7 +838,8 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( // TensorFlow operations that doesn't have operands and results of type // variant are legal. Here, we don't distinguish between variants encoding // TensorList or some other type as that information is not available here. - // This constraint should be relaxed to support other variant types in TFLite. + // Partial legalization is used below to still allow ops with variant types + // still. auto is_legal = [](Operation *op) { auto is_not_variant = [](Type ty) { return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>(); @@ -873,7 +874,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( ConvertTensorListPushBack, ConvertTensorListReserve, ConvertTensorListSetItem, ConvertTensorListStack, ConvertTensorListResize, ConvertWhile>(context); - return applyFullConversion(func, target, patterns); + return applyPartialConversion(func, target, patterns); } void LowerStaticTensorListPass::runOnOperation() { diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index a3244f31053..6ade6122fe4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -29,6 +29,10 @@ def ExtractSingleElementAsFloat : NativeCodeCall< // Checks if the value has only one user. def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>; +// Checks if the value has rank at most 'n'. +class HasRankAtMost<int n> : Constraint< + CPred<"$0.getType().cast<ShapedType>().getRank() <= " # n>>; + //===----------------------------------------------------------------------===// // Ternary ops patterns. //===----------------------------------------------------------------------===// @@ -347,7 +351,9 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in { // The result of the new "BinaryOp" will have the same shape as // `input`. In other words, the shape of the `Reshape` op are not // changed after the transformation. - (IsTailOfShape $rhs, $input)]>; + (IsTailOfShape $rhs, $input), + (HasRankAtMost<5> $input), + (HasRankAtMost<5> $rhs)]>; } foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp, diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 97b7d57dbf4..7954f72046a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -125,6 +125,7 @@ void PostQuantizePass::runOnFunction() { auto func = getFunction(); auto* ctx = func.getContext(); TFL::populateWithGenerated(ctx, &patterns); + patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx); applyPatternsAndFoldGreedily(func, patterns); if (!emit_quant_adaptor_ops_) { diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 6179eb2ce64..56af68f6bbe 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -41,15 +41,22 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +// The cmd line flag to turn on/off Tf.Text API fusion. // NOLINTNEXTLINE +static llvm::cl::opt<bool> fuse_tftext( + "tfl-fuse-tftext", llvm::cl::value_desc("bool"), + llvm::cl::desc("Fuse TF.Text API ops when it's true"), + llvm::cl::init(false)); namespace mlir { namespace TFL { namespace { constexpr char kTFAPIImplements[] = "tf.api_implements"; +constexpr char kTfTextAPIPRefix[] = "tftext:"; // Abstracts the conversion of the embedded lookup composite function. class ConvertEmbeddedLookupFunc { @@ -187,6 +194,10 @@ void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func, OpBuilder builder(func.getBody()); if (failed(ConvertKerasLSTMLayer(func, &builder))) return signalPassFailure(); + } else if (fuse_tftext && attr.getValue().startswith(kTfTextAPIPRefix)) { + if (failed(ConvertTFTextAPI(func, attr.getValue()))) { + return signalPassFailure(); + } } } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index a9e10a485bf..87cae3dd957 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -70,6 +70,7 @@ class PrepareQuantizePass : public PassWrapper<PrepareQuantizePass, FunctionPass> { public: // Constructor used by the PassRegistration and enforce uint8 quantization. + // This is only used by test. explicit PrepareQuantizePass() { if (quantize_signed) quant_specs_.inference_type = tensorflow::DT_QINT8; @@ -257,15 +258,16 @@ void PrepareQuantizePass::runOnFunction() { // convert all of them to signed. OwningRewritePatternList patterns; bool is_signed = quant_specs_.IsSignedInferenceType(); + int bit_width = quant_specs_.GetQuantizationTypeWidth(); if (is_signed) { patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx); // Convert quant stats to int8 quantization parameters. // Currently, only activation stats are imported, so narrow_range = false. - patterns.insert<PrepareQuantStats>(8, false, true, ctx); + patterns.insert<PrepareQuantStats>(bit_width, false, true, ctx); } else { // Convert quant stats to uint8 quantization parameters. // Currently, only activation stats are imported, so narrow_range = false. - patterns.insert<PrepareQuantStats>(8, false, false, ctx); + patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx); } applyPatternsAndFoldGreedily(func, patterns); diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc new file mode 100644 index 00000000000..12929152d1e --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc @@ -0,0 +1,127 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TFL { + +namespace { + +constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer"; +constexpr char kTFAPIImplements[] = "tf.api_implements"; + +inline OpaqueElementsAttr emptyCustomOption(OpBuilder* builder) { + std::string content = ""; + ShapedType type = RankedTensorType::get( + {static_cast<int64_t>(content.size())}, builder->getIntegerType(8)); + return OpaqueElementsAttr::get( + builder->getContext()->getRegisteredDialect("tfl"), type, content); +} + +inline RankedTensorType getInputType(mlir::FuncOp func, int idx) { + return func.getType() + .getInput(idx) + .dyn_cast_or_null<mlir::RankedTensorType>(); +} + +inline RankedTensorType getResultType(mlir::FuncOp func, int idx) { + return func.getType() + .getResult(idx) + .dyn_cast_or_null<mlir::RankedTensorType>(); +} + +LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) { + if (func.getNumResults() != 2) { + return failure(); + } + if (func.getNumArguments() != 1) { + return failure(); + } + auto input_type = getInputType(func, 0); + if (!input_type || input_type.getRank() != 1 || + !input_type.getElementType().isa<mlir::TF::StringType>()) { + return failure(); + } + auto value_type = getResultType(func, 0); + if (!value_type || value_type.getRank() != 1 || + !value_type.getElementType().isa<mlir::TF::StringType>()) { + return failure(); + } + auto offset_type = getResultType(func, 1); + if (offset_type.getRank() != 1 || + !offset_type.getElementType().isInteger(64)) { + return failure(); + } + return success(); +} + +LogicalResult ConvertWhitespaceTokenizer(mlir::FuncOp func, + llvm::StringRef api) { + func.eraseBody(); + func.addEntryBlock(); + func.setAttr(kTFAPIImplements, StringAttr::get(api, func.getContext())); + + Value text = func.getArgument(0); + auto output_type = func.getType().getResult(0); + auto offset_type = func.getType().getResult(1); + SmallVector<Type, 2> shape = {output_type, offset_type}; + ArrayRef<Type> output_types(shape); + + OpBuilder builder(func.getBody()); + + auto op = builder.create<mlir::TFL::CustomOp>(func.getLoc(), output_types, + ValueRange(text), api, + emptyCustomOption(&builder)); + + builder.create<mlir::ReturnOp>(func.getLoc(), op.getResults()); + return success(); +} +} // namespace + +LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api) { + if (api.str() == kWhitespaceTokenizer) { + if (succeeded(VerifyWhitespaceTokenizer(func))) { + return ConvertWhitespaceTokenizer(func, api); + } + } + return failure(); +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.h b/tensorflow/compiler/mlir/lite/utils/tftext_utils.h new file mode 100644 index 00000000000..283e57c179a --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TFLite transformation +// passes to work with op attributes. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" + +namespace mlir { +namespace TFL { + +LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api); + +} // end namespace TFL +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_ diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 11d3e7332db..b2225ec1c4f 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_os_ostream.h" +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -93,9 +94,10 @@ MlirOptimizationPassRegistry& MlirOptimizationPassRegistry::Global() { static void RegisterDialects() { static bool init_once = []() { mlir::registerDialect<mlir::StandardOpsDialect>(); + mlir::registerDialect<mlir::TF::TensorFlowDialect>(); + mlir::registerDialect<mlir::shape::ShapeDialect>(); mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>(); mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>(); - mlir::registerDialect<mlir::TF::TensorFlowDialect>(); return true; }(); (void)init_once; diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 54b560ed6ce..5110ea7fbf5 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -224,7 +224,10 @@ cc_library( hdrs = [ "ir/tf_attributes.h", ], - deps = ["@llvm-project//mlir:IR"], + deps = [ + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + ], ) cc_library( @@ -427,6 +430,7 @@ cc_library( "transforms/parallel_execute_to_islands.cc", "transforms/promote_resources_to_args.cc", "transforms/raise_control_flow.cc", + "transforms/readonly_references_to_resources.cc", "transforms/replicate_invariant_op_hoisting.cc", "transforms/replicate_to_island.cc", "transforms/resource_device_inference.cc", @@ -555,8 +559,7 @@ cc_library( srcs = ["ir/dialect_registration.cc"], deps = [ ":tensorflow", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:Shape", ], alwayslink = 1, ) @@ -785,6 +788,9 @@ cc_library( name = "convert_type", srcs = ["utils/convert_type.cc"], hdrs = ["utils/convert_type.h"], + visibility = [ + "//visibility:public", + ], deps = [ ":tensorflow_types", "//tensorflow/core:framework", @@ -1140,6 +1146,7 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/compiler/mlir/xla:type_to_shape", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla", + "//tensorflow/compiler/mlir/xla:xla_sink_constants_to_control_flow", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:framework", @@ -1278,6 +1285,7 @@ cc_library( "//tensorflow/stream_executor/lib", "@com_google_absl//absl/strings", "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) @@ -1292,6 +1300,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/protobuf/tpu:topology_proto_cc", "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/c/BUILD b/tensorflow/compiler/mlir/tensorflow/c/BUILD new file mode 100644 index 00000000000..3a503685fc6 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/c/BUILD @@ -0,0 +1,55 @@ +load( + "//tensorflow:tensorflow.bzl", + "tf_copts", + "tf_cuda_library", + "tfe_xla_copts", +) + +package( + default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + packages = ["//tensorflow/..."], +) + +tf_cuda_library( + name = "mlir_c_api", + srcs = [ + "c_api_unified_experimental_mlir.cc", + ], + copts = tf_copts() + tfe_xla_copts(), + deps = [ + "//tensorflow/c:c_api", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c:tf_status_internal", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_internal", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:casts", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "mlir_c_api_registration", + srcs = ["c_api_unified_experimental_mlir_registration.cc"], + deps = [ + ":mlir_c_api", + "//tensorflow/c/eager:c_api_unified_internal", + ], + alwayslink = 1, +) diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc new file mode 100644 index 00000000000..0e8b7fedd9b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -0,0 +1,493 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cstddef> +#include <memory> + +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" + +namespace mlir { +namespace TF { +using tensorflow::internal::AbstractFunction; +using tensorflow::internal::AbstractOp; +using tensorflow::internal::AbstractTensor; +using tensorflow::internal::dyncast; +using tensorflow::internal::ExecutionContext; +using tensorflow::internal::OutputList; + +namespace { + +static void RegisterDialects() { + static bool init_once = []() { + mlir::registerDialect<mlir::StandardOpsDialect>(); + mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>(); + mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>(); + mlir::registerDialect<mlir::TF::TensorFlowDialect>(); + return true; + }(); + (void)init_once; +} + +Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder, + Type* type) { + Status s = tensorflow::ConvertDataType(dtype, builder, type); + if (s.ok()) *type = UnrankedTensorType::get(*type); + return s; +} + +class MlirTensor : public AbstractTensor { + public: + explicit MlirTensor(Value value) : AbstractTensor(kKind), value_(value) {} + + Value getValue() { return value_; } + + static constexpr AbstractTensorKind kKind = kMlirTensor; + + private: + Value value_; +}; + +class MlirAbstractOp : public AbstractOp { + public: + explicit MlirAbstractOp(MLIRContext* context) + : AbstractOp(kKind), context_(context) {} + + void SetOpType(const char* op_type, TF_Status* s) override; + + void SetAttrType(const char* attr_name, TF_DataType dtype, + TF_Status* s) override; + + void SetOpName(const char* const op_name, TF_Status* s) override; + + MLIRContext* GetContext() { return context_; } + + Type AddRef(Type type, TF_Status* s); + + OperationState* Create(ArrayRef<Value> operands, TF_Status* s); + + static constexpr AbstractOpKind kKind = kMlirOp; + + private: + MLIRContext* context_; + llvm::StringMap<Attribute> attrs_; + std::unique_ptr<OperationState> state_; + const char* op_name_ = nullptr; +}; + +// MlirFunction is a thin wrapper over a FuncOp. +class MlirFunction : public AbstractFunction { + public: + explicit MlirFunction(std::unique_ptr<MLIRContext> context, + OwningModuleRef module, FuncOp func) + : AbstractFunction(kKind), + context_(std::move(context)), + module_(std::move(module)), + func_(func) {} + + TF_Function* GetTfFunction(TF_Status* s) override; + + static constexpr AbstractFunctionKind kKind = kGraphFunc; + + private: + std::unique_ptr<MLIRContext> context_; + OwningModuleRef module_; + FuncOp func_; +}; + +class MlirFunctionContext : public ExecutionContext { + public: + explicit MlirFunctionContext(const char* name) + : ExecutionContext(kKind), + context_(std::make_unique<MLIRContext>()), + builder_(context_.get()) { + // TODO(aminim) figure out the location story here + module_ = ModuleOp::create(builder_.getUnknownLoc()); + func_ = FuncOp::create(builder_.getUnknownLoc(), name, + builder_.getFunctionType(llvm::None, llvm::None)); + module_->push_back(func_); + builder_ = OpBuilder::atBlockBegin(func_.addEntryBlock()); + } + + AbstractOp* CreateOperation() override { + return new MlirAbstractOp(context_.get()); + } + + void ExecuteOperation(AbstractOp* abstract_op, int num_inputs, + AbstractTensor* const* inputs, OutputList* o, + TF_Status* s) override; + + AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override; + + AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override; + + void RegisterFunction(AbstractFunction* func, TF_Status* s) override { + s->status = tensorflow::errors::Unimplemented( + "Registering graph functions has not been implemented yet."); + } + + static constexpr ExecutionContextKind kKind = kMlirContext; + + private: + std::unique_ptr<MLIRContext> context_; + OpBuilder builder_; + FuncOp func_; + OwningModuleRef module_; +}; + +void MlirAbstractOp::SetOpType(const char* op_type, TF_Status* s) { + if (state_) { + s->status = tensorflow::errors::FailedPrecondition( + "SetOpType called on already built op."); + return; + } + std::string name = "tf."; + name += op_type; + // TODO(aminim) figure out the location story here + state_ = std::make_unique<OperationState>(UnknownLoc::get(context_), name); +} + +void MlirAbstractOp::SetAttrType(const char* attr_name, TF_DataType dtype, + TF_Status* s) { + if (!state_) { + s->status = tensorflow::errors::FailedPrecondition( + "op_type must be specified before specifying attrs."); + return; + } + Type mlir_type; + Builder builder(context_); + s->status = ConvertDataTypeToTensor(static_cast<tensorflow::DataType>(dtype), + builder, &mlir_type); + if (!s->status.ok()) return; + attrs_[attr_name] = TypeAttr::get(mlir_type); +} + +void MlirAbstractOp::SetOpName(const char* const op_name, TF_Status* s) { + // TODO(aminim): should we use a location? + if (op_name_) { + s->status = tensorflow::errors::FailedPrecondition( + "SetOpName called on already built op."); + return; + } + op_name_ = op_name; +} + +Type MlirAbstractOp::AddRef(Type type, TF_Status* s) { + Type elt_type = getElementTypeOrSelf(type); + if (elt_type.isa<mlir::TF::TensorFlowRefType>()) { + s->status = tensorflow::errors::InvalidArgument( + "Requested reference to a reference type"); + return nullptr; + } + elt_type = TensorFlowRefType::get(elt_type); + if (RankedTensorType tensor_type = type.dyn_cast<RankedTensorType>()) { + return RankedTensorType::get(tensor_type.getShape(), elt_type); + } + return UnrankedTensorType::get(elt_type); +} + +OperationState* MlirAbstractOp::Create(ArrayRef<Value> operands, TF_Status* s) { + state_->operands = llvm::to_vector<4>(operands); + const tensorflow::OpDef* op_def; + auto node_name = state_->name.getStringRef().drop_front( + TensorFlowDialect::getDialectNamespace().size() + 1); + s->status = + tensorflow::OpRegistry::Global()->LookUpOpDef(node_name.str(), &op_def); + if (!s->status.ok()) return nullptr; + Builder builder(context_); + // Process operands according to the op_def and infer derived attributes. + int current_operand = 0; + for (const tensorflow::OpDef::ArgDef& input_arg : op_def->input_arg()) { + if (!input_arg.number_attr().empty()) { + // TODO(b/156122856): we don't support variadic operands. + s->status = tensorflow::errors::Unimplemented( + "Unsupported 'number_attr' for '", input_arg.number_attr(), "'"); + return nullptr; + } else if (!input_arg.type_list_attr().empty()) { + s->status = tensorflow::errors::InvalidArgument( + "Unsupported 'type_list_attr' for '", input_arg.number_attr(), "'"); + return nullptr; + } + if (current_operand >= operands.size()) { + s->status = tensorflow::errors::InvalidArgument("Missing operand for '", + input_arg.name(), "'"); + return nullptr; + } + Type expected_type; + if (input_arg.type() != tensorflow::DT_INVALID) { + s->status = + ConvertDataTypeToTensor(input_arg.type(), builder, &expected_type); + if (!s->status.ok()) return nullptr; + if (input_arg.is_ref()) expected_type = AddRef(expected_type, s); + if (!s->status.ok()) return nullptr; + } else { + expected_type = operands[current_operand].getType(); + } + if (!input_arg.type_attr().empty()) { + attrs_[input_arg.type_attr()] = TypeAttr::get(expected_type); + } + ++current_operand; + } + + for (const tensorflow::OpDef::ArgDef& output_arg : op_def->output_arg()) { + int original_size = state_->types.size(); + if (!output_arg.number_attr().empty()) { + // Same type repeated "repeats" times. + Attribute repeats_attr = attrs_[output_arg.number_attr()]; + if (!repeats_attr) { + s->status = tensorflow::errors::InvalidArgument( + "Missing attribute '", output_arg.number_attr(), + "' required for output list '", output_arg.name(), "'"); + return nullptr; + } + if (!repeats_attr.isa<IntegerAttr>()) { + s->status = tensorflow::errors::InvalidArgument( + "Attribute '", output_arg.number_attr(), + "' required for output list '", output_arg.name(), + "' isn't an integer"); + return nullptr; + } + int64_t repeats = repeats_attr.cast<IntegerAttr>().getInt(); + + if (!output_arg.type_attr().empty()) { + // Same type repeated "repeats" times. + Attribute attr = attrs_[output_arg.type_attr()]; + if (!attr) { + s->status = tensorflow::errors::InvalidArgument( + "Missing attribute '", output_arg.type_attr(), + "' required for output '", output_arg.name(), "'"); + return nullptr; + } + TypeAttr type_attr = attr.dyn_cast<TypeAttr>(); + if (!type_attr) { + s->status = tensorflow::errors::InvalidArgument( + "Attribute '", output_arg.type_attr(), "' required for output '", + output_arg.name(), "' isn't a type attribute"); + return nullptr; + } + for (int i = 0; i < repeats; ++i) + state_->types.push_back(type_attr.getType()); + } else if (output_arg.type() != tensorflow::DT_INVALID) { + for (int i = 0; i < repeats; ++i) { + Type type; + s->status = + ConvertDataTypeToTensor(output_arg.type(), builder, &type); + if (!s->status.ok()) return nullptr; + state_->types.push_back(type); + } + } else { + s->status = tensorflow::errors::InvalidArgument( + "Missing type or type_attr field in ", + output_arg.ShortDebugString()); + return nullptr; + } + } else if (!output_arg.type_attr().empty()) { + Attribute attr = attrs_[output_arg.type_attr()]; + if (!attr) { + s->status = tensorflow::errors::InvalidArgument( + "Missing attribute '", output_arg.type_attr(), + "' required for output '", output_arg.name(), "'"); + return nullptr; + } + TypeAttr type_attr = attr.dyn_cast<TypeAttr>(); + if (!type_attr) { + s->status = tensorflow::errors::InvalidArgument( + "Attribute '", output_arg.type_attr(), "' required for output '", + output_arg.name(), "' isn't a type attribute"); + return nullptr; + } + state_->types.push_back(type_attr.getValue()); + } else if (!output_arg.type_list_attr().empty()) { + // This is pointing to an attribute which is an array of types. + Attribute attr = attrs_[output_arg.type_list_attr()]; + if (!attr) { + s->status = tensorflow::errors::InvalidArgument( + "Missing attribute '", output_arg.type_list_attr(), + "' required for output '", output_arg.name(), "'"); + return nullptr; + } + ArrayAttr array_attr = attr.dyn_cast<ArrayAttr>(); + if (!array_attr) { + s->status = tensorflow::errors::InvalidArgument( + "Attribute '", output_arg.type_list_attr(), + "' required for output '", output_arg.name(), + "' isn't an array attribute"); + return nullptr; + } + for (Attribute attr : array_attr) { + TypeAttr type_attr = attr.dyn_cast<TypeAttr>(); + if (!type_attr) { + s->status = tensorflow::errors::InvalidArgument( + "Array Attribute '", output_arg.type_list_attr(), + "' required for output '", output_arg.name(), + "' has a non-Type element"); + return nullptr; + } + state_->types.push_back(type_attr.getValue()); + } + } else if (output_arg.type() != tensorflow::DT_INVALID) { + Type type; + Builder builder(context_); + s->status = ConvertDataTypeToTensor(output_arg.type(), builder, &type); + if (!s->status.ok()) return nullptr; + state_->types.push_back(type); + } else { + s->status = tensorflow::errors::InvalidArgument( + "No type fields in ", output_arg.ShortDebugString()); + if (!s->status.ok()) return nullptr; + } + if (output_arg.is_ref()) { + // For all types that were added by this function call, make them refs. + for (Type& type : llvm::make_range(&state_->types[original_size], + state_->types.end())) { + type = AddRef(type, s); + if (!s->status.ok()) return nullptr; + } + } + } + return state_.get(); +} + +TF_Function* MlirFunction::GetTfFunction(TF_Status* s) { + PassManager pm(func_.getContext()); + pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass()); + pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass()); + + // In case of failure, the `diag_handler` converts MLIR errors emitted to + // the MLIRContext into a tensorflow::Status. + StatusScopedDiagnosticHandler diag_handler(func_.getContext()); + LogicalResult result = pm.run(func_.getParentOfType<ModuleOp>()); + (void)result; + s->status = diag_handler.ConsumeStatus(); + if (!s->status.ok()) return nullptr; + + tensorflow::GraphExportConfig configs; + std::unique_ptr<TF_Function> tf_function(new TF_Function); + s->status = ConvertMlirFunctionToFunctionLibraryDef(func_, configs, + &tf_function->fdef); + return tf_function.release(); +} + +void MlirFunctionContext::ExecuteOperation(AbstractOp* abstract_op, + int num_inputs, + AbstractTensor* const* inputs, + OutputList* o, TF_Status* s) { + auto* mlir_op = dyncast<MlirAbstractOp>(abstract_op); + if (mlir_op == nullptr) { + s->status = tensorflow::errors::InvalidArgument( + "Unable to cast AbstractOp to TF_GraphOp."); + return; + } + SmallVector<Value, 8> operands; + for (int i = 0; i < num_inputs; ++i) { + auto* operand = dyncast<MlirTensor>(inputs[i]); + if (!operand) { + s->status = tensorflow::errors::InvalidArgument( + "Capturing eager tensors is not supported yet."); + return; + } + if (operand->getValue().getContext() != context_.get()) { + s->status = tensorflow::errors::InvalidArgument( + "Capturing tensors from other context is not supported."); + return; + } + operands.push_back(operand->getValue()); + } + OperationState* state = mlir_op->Create(operands, s); + if (!s->status.ok() || !state) return; + Operation* op = builder_.createOperation(*state); + int num_results = op->getNumResults(); + o->outputs.clear(); + o->outputs.reserve(num_results); + for (Value result : op->getResults()) + o->outputs.push_back(new MlirTensor(result)); +} + +AbstractTensor* MlirFunctionContext::AddParameter(TF_DataType dtype, + TF_Status* s) { + Type type; + s->status = ConvertDataTypeToTensor(static_cast<tensorflow::DataType>(dtype), + builder_, &type); + if (!s->status.ok()) return nullptr; + return new MlirTensor(func_.getBody().front().addArgument(type)); +} + +AbstractFunction* MlirFunctionContext::Finalize(OutputList* outputs, + TF_Status* s) { + Block& body = func_.getBody().front(); + SmallVector<Value, 8> ret_operands; + for (AbstractTensor* output : outputs->outputs) { + auto* operand = dyncast<MlirTensor>(output); + if (!operand) { + s->status = tensorflow::errors::InvalidArgument( + "Capturing eager tensors is not supported yet."); + return nullptr; + } + if (operand->getValue().getContext() != context_.get()) { + s->status = tensorflow::errors::InvalidArgument( + "Capturing tensors from other context is not supported."); + return nullptr; + } + ret_operands.push_back(operand->getValue()); + } + builder_.create<ReturnOp>(func_.getLoc(), ret_operands); + + auto arg_types = llvm::to_vector<8>(body.getArgumentTypes()); + auto result_types = + llvm::to_vector<8>(body.getTerminator()->getOperandTypes()); + func_.setType(FunctionType::get(arg_types, result_types, func_.getContext())); + return new MlirFunction(std::move(context_), std::move(module_), func_); +} + +extern "C" { +ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s) { + RegisterDialects(); + return new MlirFunctionContext(fn_name); +} +} + +} // end anonymous namespace +} // end namespace TF +} // end namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir_registration.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir_registration.cc new file mode 100644 index 00000000000..778f4b777a3 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir_registration.cc @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" + +using tensorflow::internal::ExecutionContext; + +extern "C" { +ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s); +} + +namespace { +// Register the tracing implemented in this file as the default tracing engine. +static bool register_tracing = [] { + RegisterTracingEngineFactory("mlir", MlirTracingFactory); + return true; +}(); + +} // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h index 15a4ecfc537..39245425a5a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h @@ -26,7 +26,7 @@ limitations under the License. #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project namespace mlir { namespace TFControlFlow { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc index ac468d9810c..c95d7b7ca7c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -31,5 +32,6 @@ static DialectRegistration<tf_device::TensorFlowDeviceDialect> tf_device_dialect; static DialectRegistration<tf_saved_model::TensorFlowSavedModelDialect> tf_saved_model_dialect; +static DialectRegistration<mlir::shape::ShapeDialect> shape_dialect; } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc index 6797c04ebcf..dfad1fce26d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "mlir/IR/Attributes.h" // from @llvm-project + namespace mlir { namespace TF { @@ -45,6 +47,28 @@ struct ShapeAttrStorage : public AttributeStorage { bool unranked = false; }; +// The storage class for FuncAttr. +struct FuncAttrStorage : public AttributeStorage { + using KeyTy = std::pair<Attribute, Attribute>; + + explicit FuncAttrStorage(Attribute name, Attribute attrs) + : name(name), attrs(attrs) {} + + bool operator==(const KeyTy& key) const { return key == KeyTy(name, attrs); } + static unsigned hashKey(const KeyTy& key) { + return llvm::hash_combine(key.first, key.second); + } + + static FuncAttrStorage* construct(mlir::AttributeStorageAllocator& allocator, + const KeyTy& key) { + return new (allocator.allocate<FuncAttrStorage>()) + FuncAttrStorage(key.first, key.second); + } + + Attribute name; + Attribute attrs; +}; + } // namespace detail // Get or create a shape attribute. @@ -85,5 +109,24 @@ bool ShapeAttr::hasStaticShape() const { return true; } +FuncAttr FuncAttr::get(mlir::MLIRContext* context, llvm::StringRef name, + DictionaryAttr attr) { + auto symbol = SymbolRefAttr::get(name, context); + return Base::get(context, AttrKind::FUNC, symbol, attr); +} + +FuncAttr FuncAttr::get(mlir::MLIRContext* context, SymbolRefAttr symbol, + DictionaryAttr attr) { + return Base::get(context, AttrKind::FUNC, symbol, attr); +} + +SymbolRefAttr FuncAttr::GetName() const { + return getImpl()->name.cast<SymbolRefAttr>(); +} + +DictionaryAttr FuncAttr::GetAttrs() const { + return getImpl()->attrs.cast<DictionaryAttr>(); +} + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h index 4d85dd95cea..ba67d6cb671 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_ +#include "llvm/ADT/StringRef.h" #include "mlir/IR/Attributes.h" // from @llvm-project namespace mlir { @@ -30,6 +31,7 @@ namespace AttrKind { enum Kind { FIRST_USED_TENSORFLOW_ATTR = Attribute::FIRST_TENSORFLOW_ATTR, SHAPE = FIRST_USED_TENSORFLOW_ATTR, + FUNC, LAST_USED_TENSORFLOW_ATTR, }; @@ -38,6 +40,7 @@ enum Kind { namespace detail { struct ShapeAttrStorage; +struct FuncAttrStorage; } // namespace detail @@ -71,6 +74,33 @@ class ShapeAttr : public Attribute::AttrBase<ShapeAttr, Attribute, static bool kindof(unsigned kind) { return kind == AttrKind::SHAPE; } }; +// Custom attribute to model AttrValue.value.func (NameAttrList type attribute). +// This attribute holds a SymbolRefAttr, for the NameAttrList.name string and a +// DictionaryAttr for the NameAttrList.attr map<string, AttrValue>. It is +// currently printed and parsed for the following format: +// +// #tf.func<@symbol, {attr = "value"}> +// +// where the first element is the SymbolRefAttr and the second element is the +// DictionaryAttr. +class FuncAttr + : public Attribute::AttrBase<FuncAttr, Attribute, detail::FuncAttrStorage> { + public: + using Base::Base; + + static FuncAttr get(mlir::MLIRContext* context, llvm::StringRef name, + DictionaryAttr attr); + + static FuncAttr get(mlir::MLIRContext* context, SymbolRefAttr symbol, + DictionaryAttr attr); + + SymbolRefAttr GetName() const; + + DictionaryAttr GetAttrs() const; + + static bool kindof(unsigned kind) { return kind == AttrKind::FUNC; } +}; + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index d5ecbf3e292..9daebc22ba1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -47,37 +47,6 @@ limitations under the License. namespace mlir { namespace tf_executor { -namespace { - -// If the given tensor has elements of type with subtypes, then returns a new -// type after dropping subtypes info. Otherwise, returns the original type as -// is. -ShapedType DropTypeSubTypes(ShapedType ty) { - Type element_ty = ty.getElementType(); - auto subtype_ty = element_ty.dyn_cast<TF::TensorFlowTypeWithSubtype>(); - if (!subtype_ty) return ty; - - Type default_ty = GetDefaultTypeOf(subtype_ty); - if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty); - - return UnrankedTensorType::get(default_ty); -} - -// If the given tensor has elements of type ref, then returns a new type -// of the shape, but corresponding non-ref type as element type. Otherwise, -// returns the original type as is. -ShapedType DropRefType(ShapedType ty) { - Type element_ty = ty.getElementType(); - auto ref_ty = element_ty.dyn_cast<TF::TensorFlowRefType>(); - if (!ref_ty) return ty; - - Type default_ty = GetDefaultTypeOf(ref_ty); - if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty); - - return UnrankedTensorType::get(default_ty); -} - -} // namespace //===----------------------------------------------------------------------===// // TF Executor Dialect @@ -85,6 +54,9 @@ ShapedType DropRefType(ShapedType ty) { namespace { +using TF::DropRefType; +using TF::DropTypeSubTypes; + struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index aa1601c4032..1df8f7fd519 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -1195,6 +1195,46 @@ subsequent operation and then be optimized away, however.) }]; } +def TF_CaseOp : TF_Op<"Case", []> { + let summary = [{ +An n-way switch statement which calls a single branch function. + }]; + + let description = [{ +An n-way switch statement, implementing the following: + ``` + switch (branch_index) { + case 0: + output = branches[0](input); + break; + case 1: + output = branches[1](input); + break; + ... + case [[nbranches-1]]: + default: + output = branches[nbranches-1](input); + break; + } + ``` + }]; + + let arguments = (ins + I32Tensor:$branch_index, + Variadic<TF_Tensor>:$input, + + Confined<SymbolRefArrayAttr, [ArrayMinCount<1>]>:$branches, + DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes + ); + + let results = (outs + Variadic<TF_Tensor>:$output + ); + + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; + TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; +} + def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> { let summary = "Cast x of type SrcT to y of DstT."; @@ -6331,6 +6371,8 @@ If `x` and `y` are reals, this will return the floating-point division. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let hasCanonicalizer = 1; + + let hasFolder = 1; } def TF_ReciprocalOp : TF_Op<"Reciprocal", [NoSideEffect, SameOperandsAndResultType]> { @@ -7434,9 +7476,15 @@ select(condition, t, e) ==> [[1, 2], ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; + + let hasCanonicalizer = 1; + + let verifier = [{ + return Verify(*this); + }]; } -def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect]> { +def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect, ResultsBroadcastableShape]> { let summary = ""; let description = [{ @@ -8309,7 +8357,7 @@ def TF_StackV2Op : TF_Op<"StackV2", []> { ); } -def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, SameOperandsAndResultType]> { +def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> { let summary = "Stops gradient computation."; let description = [{ @@ -10586,6 +10634,27 @@ def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF__HostComputeMlirOp : TF_Op<"_HostComputeMlir", []> { + let summary = "A host-side computation called from a TPU device."; + + let description = [{ + }]; + + let arguments = (ins + Variadic<TF_Tensor>:$inputs, + + StrAttr:$key, + DefaultValuedAttr<I64Attr, "0">:$tpu_core + ); + + let results = (outs + Variadic<TF_Tensor>:$outputs + ); + + TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; +} + def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", []> { let summary = "An op that receives embeddng activations on the TPU."; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 2007824369c..389be0d3b2b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -58,6 +58,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/core/platform/logging.h" @@ -110,7 +111,6 @@ static inline bool HasRankAtMost(Value value, int64_t rank) { return !type || type.getRank() <= rank; } - static bool IsUnknownDimOrRank(int64_t dim_or_rank) { return dim_or_rank == -1; } @@ -252,6 +252,39 @@ static LogicalResult VerifyTypesCompatibility( return success(); } +// This is a helper for the Select to SelectV2 canonicalization. The `data` rank +// refers to the rank of `t`/`e` (these two inputs have equal rank; this is +// checked in the verifier). +// +// In most cases, the predicate for Select can be used directly as the predicate +// for SelectV2. However, there is one case that varies, which is when the +// predicate is a tensor and the data is multidimensional. In this case, Select +// op semantics dictate that the predicate tensor length must match the size of +// the first data dimension. This varies from normal broadcasting semantics +// (which are used in SelectV2), so we must reshape the tensor in this case to +// be compatible. +static Value ReshapeSelectPredIfNecessary(OpBuilder *builder, Location loc, + Value cond, int data_rank) { + auto cond_tensor = cond.getType().cast<RankedTensorType>(); + // Reshape is only needed in the case that the cond rank is 1 (i.e. it is + // a vector) AND t/e rank is > 1. + if (cond_tensor.getRank() != 1 || data_rank <= 1) { + // No reshape necessary. Leave cond as it is. + return cond; + } + + // This is the case where a reshape is needed. We want to construct the + // shape [x,1,...1], where x is the value in the pred tensor and the + // length of the shape is equal to data_rank. + SmallVector<int64_t, 8> shape(data_rank, 1); + shape[0] = cond_tensor.getShape().front(); + auto new_shape_type = + RankedTensorType::get({data_rank}, builder->getIntegerType(64)); + auto shape_attr = DenseIntElementsAttr::get(new_shape_type, shape); + auto new_shape = builder->create<ConstOp>(loc, shape_attr); + return builder->create<ReshapeOp>(loc, cond, new_shape); +} + //===----------------------------------------------------------------------===// // Helper functions detect device capabilities from RuntimeDevices. //===----------------------------------------------------------------------===// @@ -462,9 +495,10 @@ LogicalResult FoldOperandsPermutation( namespace { // Folder that returns LHS of an Arithmetic Op if the RHS is a constant // known to be Identity (e.g X+0) -template <typename OpT, - typename std::enable_if<llvm::is_one_of< - OpT, AddV2Op, SubOp, MulOp, DivOp>::value>::type * = nullptr> +template < + typename OpT, + typename std::enable_if<llvm::is_one_of< + OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr> OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, ArrayRef<Attribute> operands) { auto result_op_type = arithmetic_op.getResult().getType(); @@ -479,7 +513,8 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, // Mul and Div ops have identity value one while AddV2 and SubOp have identity // value zero. int identity = - (std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value); + (std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value || + std::is_same<OpT, RealDivOp>::value); Type element_ty = lhs_type.getElementType(); Attribute identity_attr; @@ -496,6 +531,12 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, return arithmetic_op.x(); } + auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>(); + // TODO(chhe): we could fold and add an identity to force the broadcast. + if (result_op_type != rhs_type) { + return {}; + } + bool is_symmetric = (std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value); if (auto attr = operands[0].dyn_cast_or_null<DenseElementsAttr>()) { @@ -1256,8 +1297,8 @@ static LogicalResult Verify(DataFormatVecPermuteOp op) { if (rank == 1) { int64_t dim0 = input_ty.getDimSize(0); - if (dim0 != ShapedType::kDynamicSize && dim0 != 4) - return op.emitOpError("requires 1D input of size 4"); + if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2) + return op.emitOpError("requires 1D input of size 4 or size 2"); } if (rank == 2) { @@ -1620,10 +1661,16 @@ void FillOp::build(OpBuilder &builder, OperationState &result, Value dims, OpFoldResult FillOp::fold(ArrayRef<Attribute> operands) { assert(operands.size() == 2 && "fill op has two operand"); + auto type = getType().cast<ShapedType>(); + // DenseElementsAttr that is used in this folder only supports int and float + // types. + // TODO(hinsu): Handle complex types once there is a attribute kind for + // complex. + if (!type.getElementType().isIntOrFloat()) return {}; + auto value = operands[1].dyn_cast_or_null<ElementsAttr>(); if (!value) return {}; - auto type = getType().cast<ShapedType>(); if (type.hasStaticShape()) return DenseElementsAttr::get(type, value.getValue({})); @@ -1774,75 +1821,125 @@ static LogicalResult Verify(GatherV2Op op) { static LogicalResult Verify(IfOp op) { auto module = op.getParentOfType<ModuleOp>(); - auto thenFn = module.lookupSymbol<FuncOp>(op.then_branch()); - if (!thenFn) + auto then_fn = module.lookupSymbol<FuncOp>(op.then_branch()); + if (!then_fn) return op.emitOpError("then_branch refers to an undefined function : ") << op.then_branch(); - auto elseFn = module.lookupSymbol<FuncOp>(op.else_branch()); - if (!elseFn) + auto else_fn = module.lookupSymbol<FuncOp>(op.else_branch()); + if (!else_fn) return op.emitOpError("else_branch refers to an undefined function : ") << op.else_branch(); - auto thenFuncType = thenFn.getType(); - auto elseFuncType = elseFn.getType(); + auto then_fn_type = then_fn.getType(); + auto else_fn_type = else_fn.getType(); // Non-conditional operands starting with the second operand are passed to // branches and should be pair-wise compatible with branches' inputs. - unsigned expectedNumInputs = op.getNumOperands() - 1; - if (thenFuncType.getNumInputs() != expectedNumInputs || - elseFuncType.getNumInputs() != expectedNumInputs) - return op.emitError("branches should have " + Twine(expectedNumInputs) + + unsigned expected_num_inputs = op.getNumOperands() - 1; + if (then_fn_type.getNumInputs() != expected_num_inputs || + else_fn_type.getNumInputs() != expected_num_inputs) + return op.emitError("branches should have " + Twine(expected_num_inputs) + " inputs"); - for (unsigned i = 0; i < expectedNumInputs; ++i) { - auto operandType = op.getOperand(i + 1).getType().cast<TensorType>(); - auto thenInputType = thenFuncType.getInput(i).cast<TensorType>(); - if (!AreCastCompatible({operandType, thenInputType})) + for (unsigned i = 0; i < expected_num_inputs; ++i) { + auto operand_type = op.getOperand(i + 1).getType().cast<TensorType>(); + auto then_input_type = then_fn_type.getInput(i).cast<TensorType>(); + if (!AreCastCompatible({operand_type, then_input_type})) return op.emitError( llvm::formatv("then branch input type {0} is incompatible with " "operand type {1} at index {2}", - thenInputType, operandType, i)); + then_input_type, operand_type, i)); - auto elseInputType = elseFuncType.getInput(i).cast<TensorType>(); - if (!AreCastCompatible({operandType, elseInputType})) + auto else_input_type = else_fn_type.getInput(i).cast<TensorType>(); + if (!AreCastCompatible({operand_type, else_input_type})) return op.emitError( llvm::formatv("else branch input type {0} is incompatible with " "operand type {1} at index {2}", - elseInputType, operandType, i)); + else_input_type, operand_type, i)); // If branches have incompatible input types that means that no tensor can // serve as input to both the functions. Hence, the op is invalid. - if (!AreCastCompatible({thenInputType, elseInputType})) + if (!AreCastCompatible({then_input_type, else_input_type})) return op.emitError(llvm::formatv( "branches inputs have incompatible types {0} and {1} at index {2}", - thenInputType, elseInputType, i)); + then_input_type, else_input_type, i)); } // Branches' results should be pair-wise compatible with the op results. - unsigned expectedNumResults = op.getNumResults(); - if (thenFuncType.getNumResults() != expectedNumResults || - elseFuncType.getNumResults() != expectedNumResults) - return op.emitError("branches should have " + Twine(expectedNumResults) + + unsigned expected_num_results = op.getNumResults(); + if (then_fn_type.getNumResults() != expected_num_results || + else_fn_type.getNumResults() != expected_num_results) + return op.emitError("branches should have " + Twine(expected_num_results) + " results"); - for (unsigned i = 0; i < expectedNumResults; ++i) { - auto resultType = op.getResult(i).getType().cast<TensorType>(); - auto thenResultType = thenFuncType.getResult(i).cast<TensorType>(); - if (!AreCastCompatible({thenResultType, resultType})) + for (unsigned i = 0; i < expected_num_results; ++i) { + auto result_type = op.getResult(i).getType().cast<TensorType>(); + auto then_result_type = then_fn_type.getResult(i).cast<TensorType>(); + if (!AreCastCompatible({then_result_type, result_type})) return op.emitError( llvm::formatv("then branch result type {0} is incompatible with op " "result type {1} at index {2}", - thenResultType, resultType, i)); + then_result_type, result_type, i)); - auto elseResultType = elseFuncType.getResult(i).cast<TensorType>(); - if (!AreCastCompatible({elseResultType, resultType})) + auto else_result_type = else_fn_type.getResult(i).cast<TensorType>(); + if (!AreCastCompatible({else_result_type, result_type})) return op.emitError( llvm::formatv("else branch result type {0} is incompatible with op " "result type {1} at index {2}", - elseResultType, resultType, i)); + else_result_type, result_type, i)); } return success(); } +//===----------------------------------------------------------------------===// +// YieldOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(YieldOp op) { + auto parent = op.getParentOp(); + // A YieldOp should be contained within an IfRegion op + // (and WhileRegion in future) + if (!isa<IfRegionOp>(parent)) + op.emitError() << " expects parent op " + << "'" << IfRegionOp::getOperationName() << "' but got '" + << parent->getName().getStringRef() << "'"; + return success(); +} + +//===----------------------------------------------------------------------===// +// IfRegionOp +//===----------------------------------------------------------------------===// + +LogicalResult VerifyRegionResults(Operation *op, Region ®ion, + StringRef region_name) { + auto op_name = op->getName().getStringRef(); + // verify that op outputs match yield inputs + YieldOp yield = cast<YieldOp>(region.front().getTerminator()); + unsigned expected_num_results = op->getNumResults(); + if (yield.getNumOperands() != expected_num_results) + return op->emitError(region_name + " region should have " + + Twine(expected_num_results) + " results"); + + for (int idx : llvm::seq<int>(0, expected_num_results)) { + auto op_result_type = op->getResult(idx).getType().cast<TensorType>(); + auto region_result_type = + yield.getOperand(idx).getType().cast<TensorType>(); + if (!AreCastCompatible({region_result_type, op_result_type})) + return op->emitError(llvm::formatv( + "{0} result type {1} is incompatible with {2} " + "result type {3} at index {4}", + region_name, region_result_type, op_name, op_result_type, idx)); + } + return success(); +} + +static LogicalResult Verify(IfRegionOp op) { + if (failed(VerifyRegionResults(op, op.then_branch(), "then"))) + return failure(); + if (failed(VerifyRegionResults(op, op.else_branch(), "else"))) + return failure(); + return success(); +} + //===----------------------------------------------------------------------===// // InvertOp //===----------------------------------------------------------------------===// @@ -2408,6 +2505,10 @@ void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert<RealDivWithSqrtDivisor>(context); } +OpFoldResult RealDivOp::fold(ArrayRef<Attribute> operands) { + return IdentityArithmeticOpFolder<RealDivOp>(*this, operands); +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// @@ -2539,6 +2640,81 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor, return unranked(); } +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert<SelectToSelectV2>(context); +} + +// Verifies a few extra requirements on SelectOp: +// (1) `then` and `else` must have same shape +// (2) At least one of the following must be true: +// (a) `cond` has the same rank as `then` and `else` +// (b) `cond` is a scalar +// (c) `cond` is a vector AND `then` and `else` are non-scalar with their +// first dimension equal to `cond`. +static LogicalResult Verify(SelectOp op) { + auto then_tensor = op.t().getType().cast<TensorType>(); + auto else_tensor = op.e().getType().cast<TensorType>(); + // Check (1). + if (!AreCastCompatible({then_tensor, else_tensor})) + return op.emitOpError() << "requires t and e have compatible shapes"; + + // Get data rank (if exists). + int data_rank; + // If data is unranked or data_rank is 0, this will remain -2. Otherwise + // refers to first dimension of then and/or else. + int data_first_dim = -2; + bool then_has_rank = then_tensor.hasRank(); + bool else_has_rank = else_tensor.hasRank(); + if (then_has_rank && else_has_rank) { + data_rank = then_tensor.getRank(); + if (then_tensor.getRank() > 0) + data_first_dim = then_tensor.getShape().front(); + if (else_tensor.getRank() > 0) + data_first_dim = std::max( + static_cast<int>(else_tensor.getShape().front()), data_first_dim); + } else if (then_has_rank) { + data_rank = then_tensor.getRank(); + if (then_tensor.getRank() > 0) + data_first_dim = then_tensor.getShape().front(); + } else if (else_has_rank) { + data_rank = else_tensor.getRank(); + if (else_tensor.getRank() > 0) + data_first_dim = else_tensor.getShape().front(); + } else { + // Neither has a rank. + return success(); + } + + auto cond_tensor = op.condition().getType().dyn_cast<RankedTensorType>(); + if (!cond_tensor) return success(); + auto cond_rank = cond_tensor.getRank(); + // Check (2a) and (2b). + if (cond_rank == 0 || cond_rank == data_rank) return success(); + // Check (2c). + if (cond_rank == 1) { + auto cond_shape = cond_tensor.getShape().front(); + if (data_rank == 0) { + return op.emitOpError() + << "requires that t and e are nonscalar when pred is a vector"; + } + // We know `data` tensor has a rank of at least 1. + if (data_first_dim != -1 && cond_shape != -1 && + data_first_dim != cond_shape) { + return op.emitOpError() << "requires that, when pred is a vector, the " + "shape matches the first dimension of t and e"; + } + return success(); + } + // None of (2a,b,c) were true; fail. + return op.emitOpError() << "requires that pred is a scalar OR has the same " + "rank as t and e OR is a vector"; +} + //===----------------------------------------------------------------------===// // SelectV2Op //===----------------------------------------------------------------------===// @@ -2598,9 +2774,12 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type, << variadic_idx_str << " to match rank of operand" << variadic_idx_str; } else if (result_ranked_type.hasStaticShape()) { - // The operand is an unranked tensor, verify that the result is dynamic. - return op->emitOpError("requires dynamic shape result") - << variadic_idx_str << " for unranked operand" << variadic_idx_str; + // The operand is an unranked tensor, print a warning if the result + // is static. + // Note: We do not handle this situation as an error, this would be too + // restrictive due to incompleteness of shape inference at this point. + op->emitWarning("has static shape result") + << variadic_idx_str << " for unranked operand" << variadic_idx_str; } Type element_type = result_ranked_type.getElementType(); @@ -3551,12 +3730,20 @@ OpFoldResult FoldIdentityTranspose(TransposeOp op) { if (!const_perm) return {}; auto const_value = const_perm.value(); - const auto &elements = const_value.getValues<APInt>(); + const auto elements = const_value.getValues<APInt>(); for (auto it : llvm::enumerate(elements)) { if (it.index() != it.value()) return {}; } + // TODO(jpienaar): Remove if/when we handle this more generally. + if (op.getType() != op.x().getType()) { + // If the types don't match then only fold if all the operands are in the TF + // dialect. + for (auto user : op.getOperation()->getUsers()) + if (user->getDialect() != op.getDialect()) return {}; + } + return op.x(); } @@ -3700,36 +3887,37 @@ OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) { static LogicalResult Verify(WhileOp op) { auto module = op.getParentOfType<ModuleOp>(); - auto condFn = module.lookupSymbol<FuncOp>(op.cond()); - auto bodyFn = module.lookupSymbol<FuncOp>(op.body()); - if (!condFn) { + auto cond_fn = module.lookupSymbol<FuncOp>(op.cond()); + auto body_fn = module.lookupSymbol<FuncOp>(op.body()); + if (!cond_fn) { return op.emitOpError("cond refers to an undefined function : ") << op.cond(); } - if (!bodyFn) { + if (!body_fn) { return op.emitOpError("body refers to an undefined function : ") << op.body(); } - auto condFuncType = condFn.getType(); - auto bodyFuncType = bodyFn.getType(); + auto cond_fn_type = cond_fn.getType(); + auto body_fn_type = body_fn.getType(); // Verify that the cond function has exactly one result. - if (condFuncType.getNumResults() != 1) + if (cond_fn_type.getNumResults() != 1) return op.emitOpError("requires cond function to have exactly one result"); SmallVector<Type, 4> operands(op.getOperandTypes()); // Collect all the type lists for the op so that different pairs of type lists // can be compared for the compatibility. - int numTypeLists = 5; - std::pair<std::string, ArrayRef<Type>> typeLists[] = { - {"operand", operands}, - {"body function result", bodyFuncType.getResults()}, - {"result", op.getResultTypes()}, - {"cond function input", condFuncType.getInputs()}, - {"body function input", bodyFuncType.getInputs()}, - }; + constexpr int kNumTypeLists = 5; + const std::array<std::pair<std::string, ArrayRef<Type>>, kNumTypeLists> + type_lists = {{ + {"operand", operands}, + {"body function result", body_fn_type.getResults()}, + {"result", op.getResultTypes()}, + {"cond function input", cond_fn_type.getInputs()}, + {"body function input", body_fn_type.getInputs()}, + }}; // A pair of type lists should be cast compatible with each other if one is // converted to the another for a function call or assignment or there is a @@ -3753,28 +3941,28 @@ static LogicalResult Verify(WhileOp op) { // never converted from one to the another nor there is a common source // tensors. Compatibility requirement is not transitive. - for (int i = 0; i < numTypeLists; ++i) { + for (int i = 0; i < kNumTypeLists; ++i) { // Skip the first pair as the While op operands and body function results // does not need to be compatible with each other. - for (int j = std::max(2, i + 1); j < numTypeLists; ++j) { - auto &a = typeLists[i]; - auto &b = typeLists[j]; + for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) { + auto &a = type_lists[i]; + auto &b = type_lists[j]; - int aSize = a.second.size(); - if (aSize != b.second.size()) + int a_size = a.second.size(); + if (a_size != b.second.size()) return op.emitOpError( llvm::formatv("requires the number of {0}s to be equal to the " "number of {1}s. Found {2} and {3}, respectively", - a.first, b.first, aSize, b.second.size())); + a.first, b.first, a_size, b.second.size())); - for (int idx = 0; idx < aSize; ++idx) { - auto aType = a.second[idx]; - auto bType = b.second[idx]; + for (int idx = 0; idx < a_size; ++idx) { + auto a_type = a.second[idx]; + auto b_type = b.second[idx]; - if (!AreCastCompatible({aType, bType})) + if (!AreCastCompatible({a_type, b_type})) return op.emitError(llvm::formatv( "{0} type {1} is incompatible with {2} type {3} at index {4}", - a.first, aType, b.first, bType, idx)); + a.first, a_type, b.first, b_type, idx)); } } } @@ -3856,7 +4044,7 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context) #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" >(); addInterfaces<TFInlinerInterface>(); - addAttributes<ShapeAttr>(); + addAttributes<ShapeAttr, FuncAttr>(); // Support unknown operations because not all TensorFlow operations are // registered. @@ -3911,6 +4099,49 @@ void PrintShapeAttr(ShapeAttr attr, DialectAsmPrinter &os) { // NOLINT os << ">"; } +// Parses a #tf.func attribute of the following format: +// +// #tf.func<@symbol, {attr = "value"}> +// +// where the first element is a SymbolRefAttr and the second element is a +// DictionaryAttr. +FuncAttr ParseFuncAttr(MLIRContext *context, StringRef spec, Location loc) { + auto emit_error = [&, spec]() { + emitError(loc, "invalid TensorFlow func attribute: ") << spec; + return nullptr; + }; + + if (!spec.consume_front("func<")) return emit_error(); + + size_t func_name_num_read = 0; + Attribute func_name_attr = + mlir::parseAttribute(spec, context, func_name_num_read); + if (!func_name_attr || !func_name_attr.isa<SymbolRefAttr>()) + return emit_error(); + spec = spec.drop_front(func_name_num_read); + + if (!spec.consume_front(", ")) return emit_error(); + + size_t func_attrs_num_read = 0; + Attribute func_attrs_attr = + mlir::parseAttribute(spec, context, func_attrs_num_read); + if (!func_attrs_attr || !func_attrs_attr.isa<DictionaryAttr>()) + return emit_error(); + spec = spec.drop_front(func_attrs_num_read); + + if (!spec.consume_front(">")) return emit_error(); + + return mlir::TF::FuncAttr::get(context, func_name_attr.cast<SymbolRefAttr>(), + func_attrs_attr.cast<DictionaryAttr>()); +} + +// Prints a #tf.func attribute of the following format: +// +// #tf.func<@symbol, {attr = "value"}> +void PrintFuncAttr(FuncAttr attr, DialectAsmPrinter &os) { + os << "func<" << attr.GetName() << ", " << attr.GetAttrs() << ">"; +} + } // namespace Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser, @@ -3920,6 +4151,8 @@ Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser, if (spec.startswith("shape")) return ParseShapeAttr(getContext(), spec, loc); + if (spec.startswith("func")) return ParseFuncAttr(getContext(), spec, loc); + return (emitError(loc, "unknown TensorFlow attribute: " + spec), nullptr); } @@ -3929,6 +4162,9 @@ void TensorFlowDialect::printAttribute(Attribute attr, case AttrKind::SHAPE: PrintShapeAttr(attr.cast<ShapeAttr>(), os); break; + case AttrKind::FUNC: + PrintFuncAttr(attr.cast<FuncAttr>(), os); + break; default: llvm_unreachable("unexpected tensorflow attribute kind"); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index 979f506b3b1..88307267ab4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -31,7 +31,7 @@ limitations under the License. #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 94b0c5f5e19..1b8f9eb4bb6 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -207,6 +207,70 @@ else_branch: A function that takes 'inputs' and returns a list of }]; } +def TF_YieldOp : TF_Op<"Yield", [Terminator]> { + let summary = "Yield operation"; + let description = [{ + The "yield" operation represents a return operation within the conditional + and body of structured control flow (e.g., if and while). The operation + takes a variable number of operands and produces no results. The number and + types of inputs must match the signature of the operation that contains the + region. + }]; + + let arguments = (ins Variadic<AnyType>:$operands); + + let verifier = [{ + return Verify(*this); + }]; +} + +def TF_IfRegionOp : TF_Op<"IfRegion", + [SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = "output = cond ? then_branch output : else_branch output"; + + let description = [{ +"output = cond ? then_branch output : else_branch output" + +cond: A Tensor. If the tensor is a scalar of non-boolean type, the + scalar is converted to a boolean according to the + following rule: if the scalar is a numerical value, non-zero means + True and zero means False; if the scalar is a string, non-empty + means True and empty means False. If the tensor is not a scalar, + being empty means False and being non-empty means True. +input: A list of input tensors. +then_branch: A region that computes the outputs of the op if cond = true. + It returns a list of tensors using tf.yield (as the terminator). The + types of these returned tensors is same as that of the else_branch +else_branch: A region that computes the outputs of the op if cond = false. + It returns a list of tensors using tf.yield (as the terminator). The + types of these returned tensors is same as that of the then_branch + }]; + + let arguments = (ins + TF_Tensor:$cond, + Variadic<TF_Tensor>:$input, + + DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes, + + // Used to map StatelessIf and If op defined in TensorFlow to a common op. + BoolAttr:$is_stateless + ); + + let results = (outs + Variadic<TF_Tensor>:$output + ); + + TF_DerivedOperandTypeAttr Tcond = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; + TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; + + let regions = (region SizedRegion<1>:$then_branch, SizedRegion<1>:$else_branch); + + let verifier = [{ + return Verify(*this); + }]; +} + def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { let summary = "Computes the mean of elements across dimensions of a tensor."; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc index d312e5e409b..994378ea1cf 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc @@ -366,5 +366,27 @@ bool AreCastCompatible(ArrayRef<Type> types) { return true; } +ShapedType DropTypeSubTypes(ShapedType ty) { + Type element_ty = ty.getElementType(); + auto subtype_ty = element_ty.dyn_cast<TF::TensorFlowTypeWithSubtype>(); + if (!subtype_ty) return ty; + + Type default_ty = GetDefaultTypeOf(subtype_ty); + if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty); + + return UnrankedTensorType::get(default_ty); +} + +ShapedType DropRefType(ShapedType ty) { + Type element_ty = ty.getElementType(); + TF::TensorFlowRefType ref_ty = element_ty.dyn_cast<TF::TensorFlowRefType>(); + if (!ref_ty) return ty; + + Type default_ty = TF::GetDefaultTypeOf(ref_ty); + if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty); + + return UnrankedTensorType::get(default_ty); +} + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h index 4c99aae4706..5f108e834a9 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -319,6 +319,16 @@ bool HasCompatibleElementTypes(Type lhs, Type rhs, // compatible. bool AreCastCompatible(ArrayRef<Type> types); +// If the given tensor has elements of type with subtypes, then returns a new +// type after dropping subtypes info. Otherwise, returns the original type as +// is. +ShapedType DropTypeSubTypes(ShapedType ty); + +// If the given tensor has elements of type ref, then returns a new type +// of the shape, but corresponding non-ref type as element type. Otherwise, +// returns the original type as is. +ShapedType DropRefType(ShapedType ty); + } // end namespace TF } // end namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index e05894dc266..a77aa5b8346 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -258,6 +258,59 @@ func @testDoubleReciprocal(%arg0: tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi // CHECK: return %arg0 } +// CHECK-LABEL: testSelectScalarPred +func @testSelectScalarPred(%arg0: tensor<i1>, %arg1: tensor<4x2xf16>, %arg2: tensor<4x2xf16>) -> tensor<4x2xf16> { + // CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16> + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16> + return %0: tensor<4x2xf16> +} + +// CHECK-LABEL: testSelectVectorPred +func @testSelectVectorPred(%arg0: tensor<2xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { + // CHECK-NEXT: %[[SHAPE:.*]] = "tf.Const" + // CHECK-NEXT: %[[PRED:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2xi1>, tensor<2xi64>) -> tensor<2x1xi1> + // CHECK-NEXT: "tf.SelectV2"(%[[PRED]], %arg1, %arg2) : (tensor<2x1xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + return %0: tensor<2x3xf16> +} + +// CHECK-LABEL: testSelectAllSameShape +func @testSelectAllSameShape(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { + // CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + return %0: tensor<2x3xf16> +} + +// If we don't have guarantees on input shapes, we can't support canonicalizing +// to SelectV2. Test these cases. +// CHECK-LABEL: testSelectInvalid +func @testSelectInvalid(%arg0: tensor<?xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<?xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + return %0: tensor<2x3xf16> +} + +// CHECK-LABEL: testSelectInvalidUnranked +func @testSelectInvalidUnranked(%arg0: tensor<6x7xi1>, %arg1: tensor<*xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<6x7xi1>, tensor<*xf16>, tensor<*xf16>) -> tensor<*xf16> + return %0: tensor<*xf16> +} + +// CHECK-LABEL: testSelectThenUnranked +func @testSelectThenUnranked(%arg0: tensor<3xi1>, %arg1: tensor<*xf16>, %arg2: tensor<3x2xf16>) -> tensor<*xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<*xf16>, tensor<3x2xf16>) -> tensor<*xf16> + return %0: tensor<*xf16> +} + +// CHECK-LABEL: testSelectElseUnranked +func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<*xf16>) -> tensor<*xf16> + return %0: tensor<*xf16> +} + // CHECK-LABEL: testLogicalNotOfEqual func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> { %0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1> @@ -473,12 +526,20 @@ func @testRankOfRankedTensor(%arg0 : tensor<4x3x2xf32>) -> tensor<i32> { } // CHECK-LABEL: @foldFill -func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>) { +func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) { %0 = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32> %1 = "tf.Const"() {value = dense<23.0> : tensor<f32>} : () -> tensor<f32> // CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>} %2 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor<f32>) -> tensor<3x2x1xf32> // CHECK: "tf.Const"() {value = dense<2.300000e+01> : tensor<3x2x1xf32>} %3 = "tf.Fill"(%0, %1) : (tensor<3xi32>, tensor<f32>) -> tensor<*xf32> - return %2, %3 : tensor<3x2x1xf32>, tensor<*xf32> + + %complex_cst = "tf.Const"() {value = dense<(0.000000e+00,1.000000e+00)> : tensor<complex<f32>>} : () -> tensor<complex<f32>> + // Here, custom folder doesn't handle complex dtypes and it is folded through + // the constant folding hook. + // TODO(hinsu): Handle complex dtypes in the custom folder for FillOp. + // CHECK: "tf.Const"() {value = dense<(0.000000e+00,1.000000e+00)> : tensor<3x2x1xcomplex<f32>>} : () -> tensor<*xcomplex<f32>> + %4 = "tf.Fill"(%0, %complex_cst) : (tensor<3xi32>, tensor<complex<f32>>) -> tensor<*xcomplex<f32>> + + return %2, %3, %4 : tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index bccb8923134..3ae6023400c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -302,15 +302,13 @@ func @testTensorListElementShape(%arg0: tensor<!tf.variant<tensor<2x4xf32>>>) -> return %0: tensor<2xi32> } -func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { +func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { %cst = constant dense<0.0> : tensor<2x2xf32> - %0 = "tf.Add"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %1 = "tf.Add"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %1 : tensor<2x2xf32> + %0 = "tf.Add"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> // CHECK-LABEL: RemoveTrivialAdd - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> } func @RemoveTrivialAddBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { @@ -331,26 +329,22 @@ func @RemoveTrivialAddBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { // CHECK-NEXT: return %arg0 : tensor<2x2xbf16> } -func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { +func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { %cst = constant dense<0.0> : tensor<2x2xf32> - %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %1 = "tf.AddV2"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %1 : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> // CHECK-LABEL: RemoveTrivialAddV2 - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> } -func @RemoveTrivialSub(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { +func @RemoveTrivialSub(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { %cst = constant dense<0.0> : tensor<2x2xf32> - %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %1 = "tf.Sub"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %1 : tensor<2x2xf32> + %0 = "tf.Sub"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> // CHECK-LABEL: RemoveTrivialSub - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> } func @RemoveTrivialSubInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> { @@ -362,26 +356,31 @@ func @RemoveTrivialSubInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> { // CHECK-NEXT: return %arg0 : tensor<2x2xi8> } -func @RemoveTrivialMul(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { +func @RemoveTrivialMul(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { %cst = constant dense<1.0> : tensor<2x2xf32> - %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %1 = "tf.Mul"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %1 : tensor<2x2xf32> + %0 = "tf.Mul"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> // CHECK-LABEL: RemoveTrivialMul - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> } -func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { +func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { %cst = constant dense<1.0> : tensor<2x2xf32> - %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - %1 = "tf.Div"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %1 : tensor<2x2xf32> + %0 = "tf.Div"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> // CHECK-LABEL: RemoveTrivialDiv - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> +} + +func @RemoveTrivialRealDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<1.0> : tensor<2x2xf32> + %0 = "tf.RealDiv"(%arg0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialRealDiv + // CHECK-NEXT: return %arg0 : tensor<2x2xf32> } func @RemoveTrivialDivBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { @@ -411,28 +410,35 @@ func @DivBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { // CHECK: tf.Div } -func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> { +func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> { %cst = constant dense<0.0> : tensor<2x2xf32> - %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> - %1 = "tf.AddV2"(%0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - return %1 : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> // CHECK-LABEL: DontRemoveTrivialAdd // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32> - // CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // CHECK: return %[[RESULT]] : tensor<2x2xf32> } -func @DontRemoveTrivialAdd2(%arg0: tensor<?x?xf32>, %arg1: tensor<2x2xf32>) -> tensor<?x?xf32> { +func @DontRemoveTrivialAdd2(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> { %cst = constant dense<0.0> : tensor<2x2xf32> - %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32> - %1 = "tf.AddV2"(%0, %cst) : (tensor<?x?xf32> , tensor<2x2xf32>) -> tensor<?x?xf32> - return %1 :tensor<?x?xf32> + %0 = "tf.AddV2"(%arg0, %cst) : (tensor<?x?xf32> , tensor<2x2xf32>) -> tensor<?x?xf32> + return %0 :tensor<?x?xf32> // CHECK-LABEL: DontRemoveTrivialAdd2 // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32> - // CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32> - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32> + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %[[CONST]]) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32> // CHECK: return %[[RESULT]] : tensor<?x?xf32> } + +// Test no fold because of the broadcast. +func @DontRemoveTrivialMul(%arg0: tensor<1x6x8x1xf32>) -> tensor<1x6x8x1xf32> { + %0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32> + %1 = "tf.Mul"(%arg0, %0) : (tensor<1x6x8x1xf32>, tensor<f32>) -> tensor<1x6x8x1xf32> + return %1 : tensor<1x6x8x1xf32> + // CHECK-LABEL: DontRemoveTrivialMul + // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32> + // CHECK: %[[RESULT:.*]] = "tf.Mul"(%arg0, %[[CONST]]) : (tensor<1x6x8x1xf32>, tensor<f32>) -> tensor<1x6x8x1xf32> + // CHECK: return %[[RESULT]] : tensor<1x6x8x1xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/func-attr-invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/func-attr-invalid.mlir new file mode 100644 index 00000000000..cd3b8b55032 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/func-attr-invalid.mlir @@ -0,0 +1,50 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics + +// Tests invalid #tf.func attributes. + +// expected-error@+1 {{invalid TensorFlow func attribute: func}} +func @main() attributes {tf._implements = #tf.func} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<>}} +func @main() attributes {tf._implements = #tf.func<>} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<@symbol>}} +func @main() attributes {tf._implements = #tf.func<@symbol>} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<{}>}} +func @main() attributes {tf._implements = #tf.func<{}>} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<"test", {}>}} +func @main() attributes {tf._implements = #tf.func<"test", {}>} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<@symbol, "">}} +func @main() attributes {tf._implements = #tf.func<@symbol, "">} { + return +} + +// ----- + +// expected-error@+1 {{invalid TensorFlow func attribute: func<@symbol, {}, "">}} +func @main() attributes {tf._implements = #tf.func<@symbol, {}, "">} { + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/func-attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/func-attr.mlir new file mode 100644 index 00000000000..de17778c105 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/func-attr.mlir @@ -0,0 +1,13 @@ +// RUN: tf-opt %s | tf-opt | FileCheck %s --dump-input=fail + +// CHECK-LABEL: func @func_attr +// CHECK-SAME: tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random"}> +func @func_attr() attributes {tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random"}>} { + return +} + +// CHECK-LABEL: func @nested_func_attr +// CHECK-SAME: tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random", nested = #tf.func<@symbol_b, {attr2 = true, attr3 = 8.000000e+00 : f32}>}> +func @nested_func_attr() attributes {tf._implements = #tf.func<@symbol_a, {attr0 = 1 : i32, attr1 = "random", nested = #tf.func<@symbol_b, {attr2 = true, attr3 = 8.0 : f32}>}>} { + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/function-func-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/function-func-attr.pbtxt new file mode 100644 index 00000000000..9f044c62736 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/function-func-attr.pbtxt @@ -0,0 +1,53 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -o - | FileCheck %s --dump-input-on-failure + +node { + name: "custom_relu_func_call" + op: "custom_relu" +} +node { + name: "custom_embedding_matmul_func_call" + op: "custom_embedding_matmul" +} +library { + function { + signature { + name: "custom_relu" + } + attr { + key: "_implements" + value { + func { + name: "tensorflow.relu" + } + } + } + } + function { + signature { + name: "custom_embedding_matmul" + } + attr { + key: "_implements" + value { + func { + name: "tensorflow.embedding_matmul" + attr { + key: "key1" + value { + i: 2 + } + } + attr { + key: "key2" + value { + b: false + } + } + } + } + } + } +} + +# CHECK-DAG: func @custom_relu{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.relu, {}>} +# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 10cb4f8019d..abc12b2d89c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -2,17 +2,17 @@ func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32> return %0 : tensor<?x?x?x?xi32> } @@ -23,12 +23,12 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> return %0 : tensor<4x4x4x4xi32> } @@ -38,7 +38,7 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -48,7 +48,7 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @div_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32> + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32> return %0 : tensor<?x?xi32> } @@ -68,7 +68,7 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -78,7 +78,7 @@ func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -88,7 +88,7 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "xla_chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -98,7 +98,7 @@ func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { - %0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> + %0 = "xla_chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> return %0 : tensor<2x4xi32> } @@ -108,12 +108,12 @@ func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { } func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @and_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> { - %0 = "xla_hlo.and"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1> + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1> return %0 : tensor<?xi1> } @@ -123,12 +123,12 @@ func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { } func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @or_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> { - %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1> + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1> return %0 : tensor<?xi1> } @@ -138,12 +138,12 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } func @bitwise_or_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> { - %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32> + %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32> return %0 : tensor<?xi32> } @@ -153,12 +153,12 @@ func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } func @bitwise_and_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> { - %0 = "xla_hlo.and"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32> + %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32> return %0 : tensor<?xi32> } @@ -174,19 +174,19 @@ func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { %0 = xla_hlo.constant dense<0> : tensor<2x3xi32> - %1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + %1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> %2 = xla_hlo.constant dense<0> : tensor<3xi32> - %3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> - %4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> - %5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> + %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %6 = "xla_hlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32> %7 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> %8 = xla_hlo.constant dense<1> : tensor<3xi32> %9 = xla_hlo.subtract %7, %8 : tensor<3xi32> - %10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> - %13 = "xla_hlo.divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %14 : tensor<2x3xi32> } @@ -195,14 +195,14 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 %0 = xla_hlo.constant dense<0> : tensor<3xi32> %1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> %2 = xla_hlo.constant dense<0> : tensor<2x3xi32> - %3 = "xla_hlo.compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> - %4 = "xla_hlo.compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> - %5 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> + %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %6 = "xla_hlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> %7 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %8 = xla_hlo.constant dense<1> : tensor<2x3xi32> %9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32> - %10 = "xla_hlo.add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %13 = xla_hlo.divide %11, %12 : tensor<2x3xi32> @@ -218,8 +218,8 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { } func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> - %1 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> %2 = "xla_hlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16> return %2 : tensor<2x3xf16> } @@ -230,22 +230,22 @@ func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1> return %0 : tensor<?xi1> } func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1> return %0 : tensor<?xi1> } @@ -255,17 +255,17 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @notequal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1> return %0 : tensor<?xi1> } @@ -275,7 +275,7 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -285,7 +285,7 @@ func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -295,7 +295,7 @@ func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -305,7 +305,7 @@ func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -326,35 +326,35 @@ func @const() -> tensor<2xi32> { func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = xla_hlo.constant dense<0> : tensor<i32> - %1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32> + %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32> return %1 : tensor<1xi32> } func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> { %0 = xla_hlo.constant dense<0> : tensor<i32> - %1 = "xla_hlo.maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32> + %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32> return %1 : tensor<?xi32> } func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = xla_hlo.constant dense<0> : tensor<i32> %1 = xla_hlo.constant dense<6> : tensor<i32> - %2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32> - %3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32> + %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32> + %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32> return %3 : tensor<1xi32> } func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> { %0 = xla_hlo.constant dense<0> : tensor<i32> %1 = xla_hlo.constant dense<6> : tensor<i32> - %2 = "xla_hlo.minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32> - %3 = "xla_hlo.maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32> + %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32> + %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32> return %3 : tensor<?xi32> } func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor<?x?xf32>) -> tensor<4x8xf32> { %0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32> - %1 = "xla_hlo.compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1> + %1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1> %2 = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32> %3 = "xla_hlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> return %3 : tensor<4x8xf32> @@ -682,6 +682,11 @@ func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> { return %0 : tensor<2xf32> } +func @convert_slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> { + %0 = "xla_hlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32> + return %0 : tensor<1x519xf32> +} + // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py // CHECK-LABEL: func @biasAdd_NHWC( @@ -1493,3 +1498,12 @@ func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> { // CHECK: [[VAL_371:%.*]] = "tf.Cast"([[VAL_370]]) {Truncate = false} : (tensor<2xi32>) -> tensor<2xf32> // CHECK: return [[VAL_371]] : tensor<2xf32> // CHECK: } + +// CHECK-LABEL: func @convert_slice( +// CHECK-SAME: [[VAL_372:%.*]]: tensor<1x4672xf32>) -> tensor<1x519xf32> { +// CHECK: [[VAL_373:%.*]] = "tf.Const"() {value = dense<[0, 4153]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: [[VAL_374:%.*]] = "tf.Const"() {value = dense<[1, 519]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK: [[VAL_375:%.*]] = "tf.Slice"([[VAL_372]], [[VAL_373]], [[VAL_374]]) : (tensor<1x4672xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x519xf32> +// CHECK: return [[VAL_375]] : tensor<1x519xf32> +// CHECK: } + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/readonly_references_to_resources.mlir b/tensorflow/compiler/mlir/tensorflow/tests/readonly_references_to_resources.mlir new file mode 100644 index 00000000000..2970e31c3c9 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/readonly_references_to_resources.mlir @@ -0,0 +1,85 @@ +// RUN: tf-opt -verify-diagnostics -readonly-references-to-resources -split-input-file %s | FileCheck %s --dump-input=fail + +// Test case: Basic converting. + +func @f() { + // CHECK: "tf.VarHandleOp" + // CHECK: "tf.ReadVariableOp" + %val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref> + %val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32> + return +} + +// ----- + +// Test case: Two ReadVariable ops. + +func @f() { + // CHECK: "tf.VarHandleOp" + + // During lowering to resource variables, this pass will preserve the + // locations of the ReadVariableOps as Identity ops to keep the original graph + // composition and order. + + // CHECK: "tf.ReadVariableOp" + // CHECK: "tf.ReadVariableOp" + %val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref> + %val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32> + %val2 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32> + return +} + +// ----- + +// Test case: No follow-up ReadVariable case. + +func @f() { + // CHECK-NOT: "tf.VariableV2" + // CHECK-NOT: "tf.VarHandleOp" + %val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref> + return +} + +// ----- + +// Test case: No converting when there is another use case. + +func @f() { + // expected-error @+1 {{'tf.VariableV2' op expects all users to be 'tf.Identity', but got user tf.CustomIdentity}} + %val0 = "tf.VariableV2"() {_class = ["loc:@v"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref> + %val1 = "tf.CustomIdentity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32> + return +} + +// ----- + +// Test case: No class attribute on VariableV2 op. + +func @f() { + // expected-error @+1 {{'tf.VariableV2' op has no '_class' attribute}} + %val0 = "tf.VariableV2"() {container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref> + %val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32> + return +} + +// ----- + +// Test case: No named location found on VariableV2 op. + +func @f() { + // expected-error @+1 {{'tf.VariableV2' op expects variable name in '_class' attribute, but got ["unrelated_class"]}} + %val0 = "tf.VariableV2"() {_class = ["unrelated_class"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref> + %val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32> + return +} + +// ----- + +// Test case: Invalid multiple location information in a class attribute on VariableV2 op. + +func @f() { + // expected-error @+1 {{'tf.VariableV2' op expects only one named location in '_class' attribute, but got ["loc:@v1", "loc:@v2"]}} + %val0 = "tf.VariableV2"() {_class = ["loc:@v1", "loc:@v2"], container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref> + %val1 = "tf.Identity"(%val0) : (tensor<96x!tf.f32ref>) -> tensor<96xf32> + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 160bba94cfc..3cdade8da59 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -1,10 +1,11 @@ -// RUN: tf-opt %s -tf-shape-inference -verify-diagnostics | FileCheck %s -dump-input=fail +// RUN: tf-opt %s -tf-shape-inference=propagate-caller-callee-constants=false -verify-diagnostics | FileCheck %s -dump-input=fail +// RUN: tf-opt %s -tf-shape-inference=propagate-caller-callee-constants -verify-diagnostics | FileCheck %s -dump-input=fail module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} { // CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> { - // CHECK-NOT: tf.Cast - // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: %[[RESULT:.*]] = "tf.AddV2" + // CHECK-SAME: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: return %[[RESULT]] : tensor<1xi32> %0 = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<*xi32> %1 = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<*xi32> @@ -60,8 +61,8 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> { // CHECK-LABEL: func @simple_folding func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor<?x?x?x?xf32> { -// CHECK: %[[CST:.*]] = "tf.Const"{{.*}} {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32> -// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[CST]] +// CHECK: %[[SHAPE:.*]] = "tf.Shape" +// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[SHAPE]] // CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> // CHECK: return %[[CONV]] : tensor<1x1x1x1xf32> %0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32> @@ -300,13 +301,6 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> { return %0 : tensor<*xi32> } - // CHECK-LABEL: func @fold_cast - func @fold_cast(%arg0: tensor<*xf32>) -> tensor<*xf32> { - // CHECK-NOT: Cast - %0 = "tf.Cast"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>) - return %0 : tensor<*xf32> - } - // CHECK-LABEL: func @while_variant // CHECK-SAME: -> tensor<!tf.variant<tensor<16x1xf32>>> func @while_variant(%arg0: tensor<!tf.variant<tensor<16x1xf32>>>) -> tensor<!tf.variant> { @@ -362,8 +356,6 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> { // CHECK-LABEL: func @partitioned_call_func_const func @partitioned_call_func_const(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK: %[[CONST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: return %[[CONST]] return %arg0 : tensor<2xi32> } @@ -410,4 +402,18 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> { %40 = "tf.Reshape"(%39, %19) {T = f32, Tshape = i32, device = ""} : (tensor<1x4x4x32xf32>, tensor<2xi32>) -> tensor<?x?xf32> return } + + // CHECK-LABEL: const_fold + func @const_fold() -> () { + // CHECK: tf.Const + // CHECK-SAME: () -> tensor<4xi32> + %0 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32> + // CHECK: tf.Const + // CHECK-SAME: () -> tensor<4xi32> + %1 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<*xi32> + // CHECK: tf.Add + // CHECK-SAME: (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %2 = "tf.Add"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + return + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index ffa287e0e53..c0d1a914788 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -854,6 +854,215 @@ func @testInvalidIfOp(tensor<i1>, tensor<*xf32>) -> tensor<2xf32> { // ----- +// Test invalid tf.Yield operation (parent should be IfRegion) +func @testInvalidYieldOp(%arg0: f32) -> () { + // expected-error @+1 {{expects parent op 'tf.IfRegion'}} + "tf.Yield"(%arg0) : (f32) -> () +} + +// ----- + +// Test valid tf.IfRegion operation +// CHECK-LABEL: func @testValidIfRegionOp +func @testValidIfRegionOp(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +// Test valid tf.IfRegion operation with multiple results +// CHECK-LABEL: func @testValidIfRegionOpWithMultipleResults +func @testValidIfRegionOpWithMultipleResults(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %0, %1, %2 = "tf.IfRegion"(%arg0, %arg1) ({ + %t0 = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + %t1 = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + %t2 = "tf.Acosh"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t0, %t1, %t2) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> () + }, { + %e0 = "tf.Neg"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + %e1 = "tf.Relu"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + %e2 = "tf.Sin"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e0, %e1, %e2) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) + + %3 = "tf.Add"(%0, %1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + %4 = "tf.Add"(%2, %3) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return %4 : tensor<2xf32> +} + +// ----- + +// Test invalid type for operand #0 for tf.IfRegion operation +func @testInvalidIfRegionOpType0(%arg0: f32, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{operand #0 must be tensor of tf.dtype values}} + %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (f32, tensor<2xf32>) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +// Test invalid type for operand #1 for tf.IfRegion operation +func @testInvalidIfRegionOpType1(%arg0: tensor<i1>, %arg1: f32) -> f32 { + // expected-error @+1 {{operand #1 must be tensor of tf.dtype values}} + %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %t = addf %arg1, %arg1 : f32 + "tf.Yield"(%t) : (f32) -> () + }, { + %e = mulf %arg1, %arg1 : f32 + "tf.Yield"(%e) : (f32) -> () + }) { is_stateless = false} : (tensor<i1>, f32) -> f32 + + return %0 : f32 +} + +// ----- + +// tf.IfRegion operation should have 2 regions +func @testInvalidIfRegionOp1Region(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{op expected 2 regions}} + %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +func @testInvalidIfRegionOpNoRegions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{op expected 2 regions}} + %0 = "tf.IfRegion"(%arg0, %arg1) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +func @testInvalidIfRegionOp3Regions(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{op expected 2 regions}} + %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + %te = "tf.Relu"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%te) : (tensor<2xf32>) -> () + }, { + %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +// tf.IfRegion regions should be terminated with a tf.Yield +func @testIfRegionThenTerminator(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+2 {{'tf.IfRegion' op expects regions to end with 'tf.Yield'}} + // expected-note @+1 {{in custom textual format, the absence of terminator implies 'tf.Yield'}} + %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + }, { + %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +func @testIfRegionElseTerminator(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+2 {{'tf.IfRegion' op expects regions to end with 'tf.Yield'}} + // expected-note @+1 {{in custom textual format, the absence of terminator implies 'tf.Yield'}} + %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + }) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +// tf.Region yield number of results should match op number of results +func @testIfRegionThenResultCount(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{then region should have 1 result}} + %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t, %t) : (tensor<2xf32>, tensor<2xf32>) -> () + }, { + %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +func @testIfRegionElseResultCount(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{else region should have 1 result}} + %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e, %e) : (tensor<2xf32>, tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +// tf.IfRegion yield types should match op result types +func @testIfRegionOpYieldMismatchThen(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{then result type tensor<i1> is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}} + %0 = "tf.IfRegion"(%arg0, %arg1) ({ + "tf.Yield"(%arg0) : (tensor<i1>) -> () + }, { + %e = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%e) : (tensor<2xf32>) -> () + }) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + +func @testIfRegionOpYieldMismatchElse(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{else result type tensor<i1> is incompatible with tf.IfRegion result type tensor<2xf32> at index 0}} + %0 = "tf.IfRegion"(%arg0, %arg1) ({ + %t = "tf.Acos"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> + "tf.Yield"(%t) : (tensor<2xf32>) -> () + }, { + "tf.Yield"(%arg0) : (tensor<i1>) -> () + }) { is_stateless = false} : (tensor<i1>, tensor<2xf32>) -> tensor<2xf32> + + return %0 : tensor<2xf32> +} + +// ----- + // Test valid tf.MatrixBandPart // CHECK-LABEL: func @testValidMatrixBandPartOp func @testValidMatrixBandPartOp(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> { @@ -1007,6 +1216,116 @@ func @pcall_func_2(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> { // ----- +//===--------------------------------------------------------------------===// +// tf.Select +//===--------------------------------------------------------------------===// + +// Test valid tf.Select +// CHECK-LABEL: func @testSelect +func @testSelect(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<3x2xf16>) -> tensor<3x2xf16> { + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<3x2xf16>) -> tensor<3x2xf16> + return %0: tensor<3x2xf16> +} + +// ----- + +func @testInvalidSelect(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { + // expected-error @+1 {{requires that, when pred is a vector, the shape matches the first dimension of t and e}} + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + return %0: tensor<2x3xf16> +} + +// ----- + +// Test invalid tf.Select - broadcasting then/else parameters is not supported +func @selectBroadcastThen(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { + // expected-error @+1 {{requires t and e have compatible shapes}} + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> + return %0: tensor<2x8x8xi32> +} + +// ----- + +func @invalidSelect(%arg0: tensor<2xi1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<2xi32> { + // expected-error @+1 {{requires that t and e are nonscalar when pred is a vector}} + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<i32>, tensor<i32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// ----- + +func @invalidSelect(%arg0: tensor<1x8xi1>, %arg1: tensor<1x8x8xi32>, %arg2: tensor<1x8x8xi32>) -> tensor<1x8x8xi32> { + // expected-error @+1 {{requires that pred is a scalar OR has the same rank as t and e OR is a vector}} + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<1x8xi1>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<1x8x8xi32> + return %0: tensor<1x8x8xi32> +} + +// ----- + +//===--------------------------------------------------------------------===// +// tf.SelectV2 +//===--------------------------------------------------------------------===// + +// Test valid tf.SelectV2 +// CHfaECK-LABEL: func @selectV2BroadcastThen +func @selectV2BroadcastThen(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> + return %0: tensor<2x8x8xi32> +} + +// ----- + +// Test valid tf.SelectV2 +// CHECK-LABEL: func @selectV2BroadcastElse +func @selectV2BroadcastElse(%arg0: tensor<i1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32> + return %0: tensor<2x8x8xi32> +} + +// ----- + +// Test valid tf.SelectV2 +// CHECK-LABEL: func @selectV2BroadcastPred +func @selectV2BroadcastPred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> + return %0: tensor<2x8x8xi32> +} + +// ----- + +// CHECK-LABEL: func @selectV2BroadcastAll +func @selectV2BroadcastAll(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32> + return %0: tensor<8x8x8xi32> +} + +// ----- + +// CHECK-LABEL: func @selectV2DynamicRanked +func @selectV2DynamicRanked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32> + return %0: tensor<2x?x8xi32> +} + +// ----- + +// CHECK-LABEL: func @selectV2Unranked +func @selectV2Unranked(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<*xi32>) -> tensor<*xi32> + return %0: tensor<*xi32> +} + +// ----- + +// Test invalid tf.SelectV2: this is an invalid broadcast for the predicate +func @testInvalidSelectV2(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<3x2xf16>) -> tensor<3x2xf16> { + // expected-error @+1 {{operands don't have broadcast-compatible shapes}} + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<3x2xf16>) -> tensor<3x2xf16> + return %0: tensor<3x2xf16> +} + +// ----- + //===--------------------------------------------------------------------===// // tf.Softmax //===--------------------------------------------------------------------===// @@ -1326,7 +1645,7 @@ func @testShapeMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> { func @testShapeWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{requires dynamic shape result for unranked operand}} + // expected-warning @+1 {{has static shape result for unranked operand}} %0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<2xi32> return %0 : tensor<2xi32> } @@ -1370,7 +1689,7 @@ func @testShapeNMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> { func @testShapeNWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{requires dynamic shape result #1 for unranked operand #1}} + // expected-warning @+1 {{has static shape result #1 for unranked operand #1}} %0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<?xi32>, tensor<2xi32>) return %0#1 : tensor<2xi32> } @@ -1428,7 +1747,7 @@ func @testVariableShapeMismatchDim(%arg0: tensor<*x!tf.resource<tensor<1x32x32x1 // ----- func @testVariableShapeWrongResultDimDynamic(%arg0: tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32> { - // expected-error @+1 {{requires dynamic shape result for unranked operand}} + // expected-warning @+1 {{has static shape result for unranked operand}} %0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32> return %0 : tensor<2xi32> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir index 38aa078358b..961039e7968 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir @@ -104,3 +104,9 @@ module attributes {tf_saved_model.semantics} { return } } + +// ----- + +// Test running the pass on a module that does not have +// tf_saved_model.semantics. +module {} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir index f985be16ab8..80d9a498253 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir @@ -136,3 +136,9 @@ module attributes {tf_saved_model.semantics} { } } + +// ----- + +// Test running the pass on a module that does not have +// tf_saved_model.semantics. +module {} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir index eb67bdcc914..9af75255202 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir @@ -2,80 +2,183 @@ // Tests extraction of a outside compiled ops at head of TPU computation. -func @single_head_outside_compilation(%arg0 : tensor<i32>) -> () { - // CHECK: tf_device.launch - // CHECK: "tf.A" - // CHECK-NEXT: tf_device.return - // - // CHECK: "tf_device.cluster" - // CHECK: "tf.C" - // CHECK-NEXT: tf_device.return - "tf_device.cluster"() ( { - "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> () - "tf.B"() : () -> () - "tf.C"() : () -> () - tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () - return -} +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + // CHECK-LABEL: func @single_head_outside_compilation + func @single_head_outside_compilation(%arg0 : tensor<i32>) -> () { + // CHECK: tf_device.launch + // CHECK: "tf.A" + // CHECK-NEXT: tf_device.return + // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: "tf_device.cluster" + // CHECK: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> () + "tf.B"() : () -> () + "tf.C"() : () -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } -// CHECK-LABEL: func @multiple_head_outside_compilation -func @multiple_head_outside_compilation(%arg0 : tensor<i32>) -> () { - // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() - // CHECK: %[[A_OUT:.*]] = "tf.A" - // CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]]) - // CHECK: "tf.C" - // CHECK-NEXT: tf_device.return %[[B_OUT]] - // - // CHECK: "tf_device.cluster" - // CHECK: "tf.D"(%[[LAUNCH_OUT]]) - // CHECK-NEXT: tf_device.return - "tf_device.cluster"() ( { - %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>) - %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>) - "tf.C"(%1, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> () - "tf.D"(%1) : (tensor<i32>) -> () - tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () - return -} + // CHECK-LABEL: func @ops_no_operands + func @ops_no_operands() -> () { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.B"(%[[LAUNCH_OUT]]) + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %0 = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<i32>) + %1 = "tf.B"(%0) {}: (tensor<i32>) -> (tensor<i32>) + "tf.C"(%1) : (tensor<i32>) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } -// CHECK-LABEL: func @test_do_not_outside_compiled_ops_in_middle -func @test_do_not_outside_compiled_ops_in_middle(%arg0 : tensor<i32>) -> () { - // CHECK-NOT: tf_device.launch - // CHECK: "tf_device.cluster" - // CHECK-NEXT: "tf.A" - // CHECK-NEXT: "tf.B" - // CHECK-NEXT: "tf.C" - // CHECK-NEXT: tf_device.return - "tf_device.cluster"() ( { - %0 = "tf.A"(%arg0) {} : (tensor<i32>) -> (tensor<i32>) - %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>) - "tf.C"(%1) : (tensor<i32>) -> () - tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () - return -} + // CHECK-LABEL: func @aliased_output + func @aliased_output() -> (tensor<i32>, tensor<i32>, tensor<i32>) { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[LAUNCH_OUT]]) + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C" + // CHECK-NEXT: tf_device.return %[[C_OUT]], %[[B_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK: return %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#0, %[[CLUSTER_OUT]]#1 + %0:3 = "tf_device.cluster"() ( { + %1 = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<i32>) + %2 = "tf.B"(%1) {}: (tensor<i32>) -> (tensor<i32>) + %3 = "tf.C"(%2) : (tensor<i32>) -> (tensor<i32>) + tf_device.return %1, %3, %2 : tensor<i32>, tensor<i32>, tensor<i32> + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>, tensor<i32>, tensor<i32>) + return %0#0, %0#1, %0#2 : tensor<i32>, tensor<i32>, tensor<i32> + } -// CHECK-LABEL: func @test_ops_with_tpu_operands_not_extracted -func @test_ops_with_tpu_operands_not_extracted(%arg0 : tensor<i32>) -> () { - // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() - // CHECK: %[[A_OUT:.*]] = "tf.A" - // CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]]) - // CHECK-NEXT: tf_device.return %[[D_OUT]] - // - // CHECK: "tf_device.cluster" - // CHECK: "tf.B" - // CHECK: "tf.C" - // CHECK: "tf.E" - // CHECK-NEXT: tf_device.return - "tf_device.cluster"() ( { - %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>) - %1 = "tf.B"() {} : () -> (tensor<i32>) - %2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>) - %3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>) - %4 = "tf.E"(%3) {} : (tensor<i32>) -> (tensor<i32>) - tf_device.return - }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () - return + // CHECK-LABEL: func @all_head_computation_ops + func @all_head_computation_ops(%arg0 : tensor<i32>) -> (tensor<i32>) { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK: %[[A_OUT:.*]] = "tf.A" + // CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]]) + // CHECK: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]], %arg0) + // CHECK-NEXT: tf_device.return %[[C_OUT]] + // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: "tf_device.cluster" + // CHECK-NEXT: tf_device.return + // + // CHECK: return %[[LAUNCH_OUT]] + %0 = "tf_device.cluster"() ( { + %1 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>) + %2 = "tf.B"(%1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>) + %3 = "tf.C"(%2, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>) + tf_device.return %3 : tensor<i32> + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>) + return %0 : tensor<i32> + } + + // CHECK-LABEL: func @multiple_head_outside_compilation + func @multiple_head_outside_compilation(%arg0 : tensor<i32>) -> () { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK: %[[A_OUT:.*]] = "tf.A" + // CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]]) + // CHECK: "tf.C" + // CHECK-NEXT: tf_device.return %[[B_OUT]] + // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: "tf_device.cluster" + // CHECK: "tf.D"(%[[LAUNCH_OUT]]) + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>) + %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>) + "tf.C"(%1, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> () + "tf.D"(%1) : (tensor<i32>) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @test_do_not_outside_compiled_ops_in_middle + func @test_do_not_outside_compiled_ops_in_middle(%arg0 : tensor<i32>) -> () { + // CHECK-NOT: tf_device.launch + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.A" + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %0 = "tf.A"(%arg0) {} : (tensor<i32>) -> (tensor<i32>) + %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>) + "tf.C"(%1) : (tensor<i32>) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @test_ops_with_tpu_operands_not_extracted + func @test_ops_with_tpu_operands_not_extracted(%arg0 : tensor<i32>) -> () { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK: %[[A_OUT:.*]] = "tf.A" + // CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]]) + // CHECK-NEXT: tf_device.return %[[D_OUT]] + // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: "tf_device.cluster" + // CHECK: "tf.B" + // CHECK: "tf.C" + // CHECK: "tf.E" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>) + %1 = "tf.B"() {} : () -> (tensor<i32>) + %2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>) + %3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>) + %4 = "tf.E"(%3) {} : (tensor<i32>) -> (tensor<i32>) + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @test_replicated_head_outside_compilation + func @test_replicated_head_outside_compilation(%arg0 : tensor<i32>) -> () { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK: %[[A_OUT:.*]] = "tf.A" + // CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]]) + // CHECK-NEXT: tf_device.return %[[D_OUT]] + // CHECK: device = "TPU_REPLICATED_HOST" + // + // CHECK: "tf_device.cluster" + // CHECK: "tf.B" + // CHECK: "tf.C" + // CHECK: "tf.E" + // CHECK-NEXT: tf_device.return + tf_device.replicate() {n = 2 : i32} { + "tf_device.cluster"() ( { + %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>) + %1 = "tf.B"() {} : () -> (tensor<i32>) + %2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>) + %3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>) + %4 = "tf.E"(%3) {} : (tensor<i32>) -> (tensor<i32>) + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + tf_device.return + } + return + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir index 3cb693ee571..9396e1fb88a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir @@ -141,4 +141,133 @@ func @multiple_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> te return %1 : tensor<?xf32> } +// Tests extraction of a single outside compiled cluster with single device->host input. + +// CHECK-LABEL: func @single_outside_compiled_input_single_outside_compilation +func @single_outside_compiled_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> { + %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32> + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf.B"(%[[RECV_OUTPUT]]) + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor<?xi32>) + "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> () + %4 = "tf.C"() : () -> tensor<?xi32> + tf_device.return %4 : tensor<?xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32> + tf_device.return %2 : tensor<?xi32> + } + + return %1 : tensor<?xi32> +} + +// Tests extraction of a single outside compiled cluster with arg input and single device->host input. + +// CHECK-LABEL: func @mixed_input_single_outside_compilation +func @mixed_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> { + %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32> + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]]) + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor<?xi32>) + "tf.B"(%arg0, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> () + %4 = "tf.C"() : () -> tensor<?xi32> + tf_device.return %4 : tensor<?xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32> + tf_device.return %2 : tensor<?xi32> + } + + return %1 : tensor<?xi32> +} + +// Tests extraction of a multiple outside compiled clusters with single device->host input. + +// CHECK-LABEL: func @single_outside_compiled_input_multiple_outside_compilation +func @single_outside_compiled_input_multiple_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> { + %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32> + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK: %[[STATUS_OUTPUT_2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_2:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_2]]) + // CHECK-SAME: key = "host_compute_channel_cluster2" + // CHECK: "tf.D"(%[[RECV_OUTPUT_2]]) + // CHECK: "tf_device.launch" + // CHECK: %[[STATUS_OUTPUT_1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_1:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_1]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf.B"(%[[RECV_OUTPUT_1]]) + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK: "tf._HostComputeMlir"(%[[C_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster2" + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor<?xi32>) + "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> () + %4 = "tf.C"() : () -> tensor<?xi32> + "tf.D"(%4) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> () + tf_device.return %4 : tensor<?xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32> + tf_device.return %2 : tensor<?xi32> + } + + return %1 : tensor<?xi32> +} + +// Tests extraction of a single outside compiled cluster with multiple device->host inputs. + +// CHECK-LABEL: func @multiple_outside_compiled_inputs_single_outside_compilation +func @multiple_outside_compiled_inputs_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> { + %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32> + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf.C"(%[[RECV_OUTPUT]]#0) + // CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0) + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1" + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor<?xi32>) + %4 = "tf.B"() : () -> (tensor<?xi32>) + "tf.C"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> () + "tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> () + %5 = "tf.E"() : () -> tensor<?xi32> + tf_device.return %5 : tensor<?xi32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32> + tf_device.return %2 : tensor<?xi32> + } + + return %1 : tensor<?xi32> +} + // TODO(b/154363171): Add test cases for when output of outside compilation is returned by parallel_execute. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index b8a48bbb379..5d65342b4a7 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -747,7 +747,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests simple case of `tf_device.cluster_func` on TPU with replication. +// Tests simple case of `tf_device.cluster_func` on TPU with replication. Under +// data parallelism replicated host devices are also added to the +// tf_device.replicate module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { // CHECK-LABEL: func @replicated_tpu_cluster_func @@ -758,7 +760,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate // CHECK-SAME: ([%[[A_OUTPUT]], %[[ARG_0]]] as %[[RI_0:[a-z0-9]*]]: tensor<?xi32>) - // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} + // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]} // CHECK-SAME: n = 2 %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} { // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]]) @@ -1222,7 +1224,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests simple case of `tf_device.cluster_func` on TPU with replication and parallel_execute. +// Tests simple case of `tf_device.cluster_func` on TPU with replication and +// parallel_execute. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { // CHECK-LABEL: func @replicated_parallel_tpu_cluster_func @@ -1231,17 +1234,26 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32> // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK: "tf._TPUCompileMlir" // CHECK: "tf.TPUCompileSucceededAssert" // CHECK: "tf_device.parallel_execute" + // CHECK-NOT:"tf._TPUCompileMlir" + // CHECK: "tf.D"(%[[COMPILE_OUTPUT]]#1 // CHECK: "tf.TPUExecute" + // CHECK-NOT:"tf._TPUCompileMlir" + // CHECK: "tf.E"(%[[COMPILE_OUTPUT]]#1 %3 = "tf_device.parallel_execute"() ( { - "tf.D"() : () -> () + %status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>) + "tf.D"(%program) : (tensor<!tf.string>) -> () tf_device.return }, { %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<?xi32>) -> tensor<?xi32> - tf_device.return %4 : tensor<?xi32> + }, { + %status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>) + "tf.E"(%program) : (tensor<!tf.string>) -> () + tf_device.return }) : () -> (tensor<?xi32>) tf_device.return %3 : tensor<?xi32> } @@ -1317,15 +1329,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01" // ----- -// Tests devices are set properly for replicated model parallelism. +// Tests devices are set properly for replicated model parallelism. No +// replicated host device should be present. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { // CHECK-LABEL: func @replicated_parallel_execute func @replicated_parallel_execute(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> (tensor<8xi32>, tensor<8xi32>) { // CHECK: tf_device.replicate - // CHECK-SAME: devices = - // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"] - // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"] + // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"], TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]} %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} { // CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"() @@ -1357,8 +1368,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // ----- -// Tests that inputs are inputs with maximal and replicate sharding are set properly -// for replicated model parallelism. +// Tests that inputs are inputs with maximal and replicate sharding are set +// properly for replicated model parallelism. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { // CHECK-LABEL: func @parallel_execute_with_input_with_sharding_configurations @@ -1392,8 +1403,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // ----- -// Tests devices are set properly for replicated model parallelism with -// outputs to TPU computation placed on logical device 0. +// Tests devices are set properly for replicated model parallelism with outputs +// to TPU computation placed on logical device 0. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { // CHECK-LABEL: func @parallel_execute_with_different_outputs @@ -1469,8 +1480,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // ----- -// Tests inputs are correctly split and fed into TPU computation for -// tiled input sharding. +// Tests inputs are correctly split and fed into TPU computation for tiled input +// sharding. // The following OpSharding is used for TPU computation inputs in below test: // Proto debug string: diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index ccc3e83a2a2..cf09f8d64fb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -152,6 +152,23 @@ def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)), def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)), (replaceWithValue $arg)>; +//===----------------------------------------------------------------------===// +// Select op patterns. +//===----------------------------------------------------------------------===// + +def ReshapeSelectPredIfNecessary : NativeCodeCall< + "ReshapeSelectPredIfNecessary(&($_builder), $0.getOwner()->getLoc(), $1, " + "$2.getType().cast<RankedTensorType>().getRank())">; + +// Select supports tensor `condition` where the shape is equal to the first +// dimension of t and e. SelectV2 op supports normal broadcasting, so in these +// cases the condition needs to be reshaped. +def SelectToSelectV2 : Pat< + (TF_SelectOp:$op StaticShapeTensorOf<[AnyType]>:$cond, + StaticShapeTensorOf<[AnyType]>:$t, + StaticShapeTensorOf<[AnyType]>:$e), + (TF_SelectV2Op (ReshapeSelectPredIfNecessary $op, $cond, $t), $t, $e)>; + //===----------------------------------------------------------------------===// // Square op patterns. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index be35c6caa16..55a0b5c3fd3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -17,7 +17,7 @@ limitations under the License. #include <algorithm> -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/tf_status.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc index d3b064f3efa..a0cf9c8eb9a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc @@ -48,6 +48,9 @@ struct FreezeGlobalTensorsPass void FreezeGlobalTensorsPass::runOnOperation() { auto module = getOperation(); + if (!tf_saved_model::HasTfSavedModelSemantics(module)) { + return; + } SymbolTable symbol_table(module); DenseSet<Operation*> frozen_global_tensors; @@ -66,7 +69,9 @@ void FreezeGlobalTensorsPass::runOnOperation() { // previous optimize global tensors pass). If not, this pass has to fail // since it cannot perform one of its goals. if (global_tensor.is_mutable()) { - global_tensor.emitError() << "is not immutable"; + global_tensor.emitError() << "is not immutable, try running " + "tf-saved-model-optimize-global-tensors " + "to prove tensors are immutable"; return signalPassFailure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index 50f77cd9c3d..524b3e4f4b7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -16,21 +16,61 @@ limitations under the License. // This file implements logic for legalizing HLO to TensorFlow. #include <memory> +#include <vector> +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/chlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" namespace mlir { namespace TF { namespace { +class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + xla_hlo::SliceOp slice_op, ArrayRef<Value> args, + ConversionPatternRewriter &rewriter) const final { + DenseIntElementsAttr strides = slice_op.strides(); + // Strides must be 1 otherwise we cannot legalize this `xla_hlo.slice` op. + if (!strides.isSplat() || + strides.getSplatValue().cast<IntegerAttr>().getInt() != 1) + return failure(); + + rewriter.setInsertionPointAfter(slice_op); + auto start_indices = slice_op.start_indices(); + auto limit_indices = slice_op.limit_indices(); + std::vector<int64_t> size_values; + for (auto pair : llvm::zip(start_indices.getValues<APInt>(), + limit_indices.getValues<APInt>())) { + size_values.emplace_back(std::get<1>(pair).getSExtValue() - + std::get<0>(pair).getSExtValue()); + } + + RankedTensorType ty = + RankedTensorType::get({static_cast<int64_t>(size_values.size())}, + rewriter.getIntegerType(64)); + auto start = rewriter.create<ConstOp>(slice_op.getLoc(), start_indices); + auto size = rewriter.create<ConstOp>( + slice_op.getLoc(), DenseIntElementsAttr::get(ty, size_values)); + rewriter.replaceOpWithNewOp<SliceOp>(slice_op, slice_op.getType(), + slice_op.operand(), start, size); + return success(); + }; +}; + class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> { public: LegalizeHloToTf() = default; @@ -63,6 +103,7 @@ void LegalizeHloToTf::runOnFunction() { // Add legalization patterns to the list. OwningRewritePatternList patterns; populateWithGenerated(&context, &patterns); + patterns.insert<ConvertSliceOp>(&context); ConversionTarget target(context); target.addLegalDialect<TensorFlowDialect>(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td index f3371989b73..6fd7556084d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td @@ -18,12 +18,16 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Dialect/StandardOps/IR/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "tensorflow/compiler/mlir/xla/ir/chlo_ops.td" include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>; //===----------------------------------------------------------------------===// // Binary op patterns. +// Note that these are legalized from chlo.broadcast_* ops, since those are +// semantically compatible with the corresponding TF ops. Depending on +// context, getting to these ops may require some raising. //===----------------------------------------------------------------------===// // Check that two values can be broadcasted together @@ -31,36 +35,45 @@ def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>; def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">, "types must be broadcastable">; -foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op], - [HLO_DivOp, TF_DivOp], - [HLO_ShiftLeftOp, TF_LeftShiftOp], - [HLO_MaxOp, TF_MaximumOp], - [HLO_MinOp, TF_MinimumOp], - [HLO_MulOp, TF_MulOp], - [HLO_PowOp, TF_PowOp], - [HLO_SubOp, TF_SubOp], - [HLO_Atan2Op, TF_Atan2Op], - [HLO_RemOp, TF_ModOp]] in - def : Pat<(fromToBinPair[0] $l, $r, $_), (fromToBinPair[1] $l, $r), +foreach fromToBinPair = [[HLO_AddOp, HLOClient_BroadcastAddOp, TF_AddV2Op], + [HLO_DivOp, HLOClient_BroadcastDivOp, TF_DivOp], + [HLO_ShiftLeftOp, HLOClient_BroadcastShiftLeftOp, TF_LeftShiftOp], + [HLO_MaxOp, HLOClient_BroadcastMaxOp, TF_MaximumOp], + [HLO_MinOp, HLOClient_BroadcastMinOp, TF_MinimumOp], + [HLO_MulOp, HLOClient_BroadcastMulOp, TF_MulOp], + [HLO_PowOp, HLOClient_BroadcastPowOp, TF_PowOp], + [HLO_SubOp, HLOClient_BroadcastSubOp, TF_SubOp], + [HLO_Atan2Op, HLOClient_BroadcastAtan2Op, TF_Atan2Op], + [HLO_RemOp, HLOClient_BroadcastRemOp, TF_ModOp]] in { + def : Pat<(fromToBinPair[0] $l, $r), (fromToBinPair[2] $l, $r)>; + def : Pat<(fromToBinPair[1] $l, $r, $_), (fromToBinPair[2] $l, $r), [(AreBroadcastCompatible $l, $r)]>; +} -foreach pair = [[HLO_AndOp, TF_BitwiseAndOp], - [HLO_OrOp, TF_BitwiseOrOp], - [HLO_XorOp, TF_BitwiseXorOp]] in - def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[1] $l, $r), +foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_BitwiseAndOp], + [HLO_OrOp, HLOClient_BroadcastOrOp, TF_BitwiseOrOp], + [HLO_XorOp, HLOClient_BroadcastXorOp, TF_BitwiseXorOp]] in { + def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r), (pair[2] $l, $r)>; + def : Pat<(pair[1] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[2] $l, $r), [(AreBroadcastCompatible $l, $r)]>; +} -foreach pair = [[HLO_AndOp, TF_LogicalAndOp], - [HLO_OrOp, TF_LogicalOrOp]] in - def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r, $_), (pair[1] $l, $r), +foreach pair = [[HLO_AndOp, HLOClient_BroadcastAndOp, TF_LogicalAndOp], + [HLO_OrOp, HLOClient_BroadcastOrOp, TF_LogicalOrOp]] in { + def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r), (pair[2] $l, $r)>; + def : Pat<(pair[1] I1Tensor:$l, I1Tensor:$r, $_), (pair[2] $l, $r), [(AreBroadcastCompatible $l, $r)]>; +} -def : Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r), +def : Pat<(HLO_ShiftRightArithmeticOp $l, $r), (TF_RightShiftOp $l, $r)>; +def : Pat<(HLOClient_BroadcastShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r), [(AreBroadcastCompatible $l, $r)]>; -def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r), +def : Pat<(HLO_ShiftRightLogicalOp $l, $r), (TF_RightShiftOp $l, $r)>; +def : Pat<(HLOClient_BroadcastShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r), [(AreBroadcastCompatible $l, $r)]>; -def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r), +def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r)), (TF_FloorDivOp $l, $r)>; +def : Pat<(HLO_FloorOp (HLOClient_BroadcastDivOp $l, $r, $_)), (TF_FloorDivOp $l, $r), [(AreBroadcastCompatible $l, $r)]>; def : Pat<(HLO_ComplexOp $r, $i), (TF_ComplexOp $r, $i)>; @@ -117,16 +130,23 @@ def : Pat<(HLO_ConcatenateOp $inputs, $dim), //===----------------------------------------------------------------------===// // Compare op patterns. +// Note that these are legalized from chlo.broadcast_* ops, since those are +// semantically compatible with the corresponding TF ops. Depending on +// context, getting to these ops may require some raising. //===----------------------------------------------------------------------===// foreach p = [[TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ], - [TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in - def : Pat<(HLO_CompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue), + [TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in { + def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue), [(AreBroadcastCompatible $l, $r)]>; + def : Pat<(HLO_CompareOp $l, $r, p[1]), (p[0] $l, $r, ConstBoolAttrTrue)>; +} foreach pair = [[TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE], [TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT], [TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE], - [TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in - def : Pat<(HLO_CompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r), + [TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in { + def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r), [(AreBroadcastCompatible $l, $r)]>; + def : Pat<(HLO_CompareOp $l, $r, pair[1]), (pair[0] $l, $r)>; +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index 550100c8ebf..cd8f988fd5f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -278,6 +278,10 @@ void EraseUnusedBoundInputs(ModuleOp module) { void OptimizeGlobalTensorsPass::runOnOperation() { auto module = getOperation(); + if (!tf_saved_model::HasTfSavedModelSemantics(module)) { + return; + } + EraseUnusedBoundInputs(module); ResourceAnalyzer resource_analyzer(module); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 81d0259d2d6..5c140ddd6aa 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -95,6 +95,11 @@ std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass(); // functions. std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass(); +// Creates a pass that converts readonly reference variables to the +// corresponding resource variables. +std::unique_ptr<OperationPass<FuncOp>> +CreateConvertReadonlyReferenceVariablesToResourceVariablesPass(); + // Marks function visibility using tf.entry_function specification. That is, // functions with tf.entry_function attributes are marked with public // visibility while the other functions are marked with private visibility. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc new file mode 100644 index 00000000000..a80b84ddeda --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc @@ -0,0 +1,179 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" + +namespace mlir { +namespace TF { +namespace { + +// Location attribute. +constexpr StringRef kClassAttr = "_class"; +constexpr StringRef kLocationPrefix = "loc:@"; + +// A pass that converts readonly reference variables to the corresponding +// resource variables. +// +// It converts (VariableV2 -> Identity) to (VarHandle -> ReadVariable). +// +// For the background, this pass is a part of hoisting VariableV2 ops by +// re-using the pipeline for hoisting (VarHandle -> ReadVariable) cases, which +// can be done by the following passes: +// - Capturing resource values into global tensors (importing saved model). +// - Promoting VarHandle ops to function input/outputs. +// - Freezing global tensor pass. +// +// This path assumes that all the VariableV2 ops is read-only via verifying the +// heuristic method that assumes that all the users of them is Identity op, +// fed directly. +class ConvertReadonlyReferenceVariablesToResourceVariablesPass + : public PassWrapper< + ConvertReadonlyReferenceVariablesToResourceVariablesPass, + FunctionPass> { + public: + void runOnFunction() override; +}; + +// Parse node name from "_class" attribute. +StringRef GetNodeNameFromClassAttr(Operation *op) { + ArrayAttr classes_attr = op->getAttrOfType<ArrayAttr>(kClassAttr); + if (!classes_attr) { + op->emitOpError() << "has no '_class' attribute"; + return StringRef(); + } + + StringRef result; + for (Attribute class_attr : classes_attr) { + StringRef node_name = class_attr.cast<StringAttr>().getValue(); + if (!node_name.startswith(kLocationPrefix)) { + continue; + } + if (!result.empty()) { + // Invalid case since there are multiple loc:@ attributes. + op->emitOpError() + << "expects only one named location in '_class' attribute, but got " + << classes_attr; + return StringRef(); + } + result = node_name.drop_front(kLocationPrefix.size()); + } + if (result.empty()) { + op->emitOpError() << "expects variable name in '_class' attribute, but got " + << classes_attr; + } + return result; +} + +void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() { + FuncOp func = getFunction(); + + OpBuilder builder(func.getContext()); + SmallVector<VariableV2Op, 4> variable_v2s_to_replace; + + // Checks all the VariableV2 ops is read-only via verifying the heuristic + // method that assumes that all the users of them is Identity op, feeded + // directly. + auto read_only_vars_fn = [&variable_v2s_to_replace]( + VariableV2Op variable_v2_op) { + if (variable_v2_op.getResult().use_empty()) { + // Erase the op when there is no user. + variable_v2_op.erase(); + return mlir::WalkResult::advance(); + } + if (!all_of(variable_v2_op.getResult().getUsers(), [&variable_v2_op]( + Operation *user) { + if (!isa<IdentityOp>(user)) { + variable_v2_op.emitOpError() + << "expects all users to be 'tf.Identity', but got user " + << user->getName(); + return false; + } + return true; + })) { + return mlir::WalkResult::interrupt(); + } + variable_v2s_to_replace.push_back(variable_v2_op); + return mlir::WalkResult::advance(); + }; + + WalkResult walk_res = func.walk(read_only_vars_fn); + if (walk_res.wasInterrupted()) return signalPassFailure(); + + for (VariableV2Op variable_v2_op : variable_v2s_to_replace) { + builder.setInsertionPoint(variable_v2_op); + ShapedType shaped_type = + variable_v2_op.getResult().getType().cast<ShapedType>(); + TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>(); + StringAttr device_attr = variable_v2_op.getAttrOfType<StringAttr>("device"); + if (!device_attr) device_attr = builder.getStringAttr(""); + StringRef variable_name = GetNodeNameFromClassAttr(variable_v2_op); + if (variable_name.empty()) { + return signalPassFailure(); + } + VarHandleOp var_handle_op = builder.create<VarHandleOp>( + variable_v2_op.getLoc(), + ArrayRef<Type>{RankedTensorType::get( + {}, TF::ResourceType::get(ArrayRef<TensorType>{tensor_type}, + builder.getContext()))}, + ArrayRef<Value>{}, + ArrayRef<NamedAttribute>{ + builder.getNamedAttr("device", device_attr), + builder.getNamedAttr("container", variable_v2_op.containerAttr()), + builder.getNamedAttr("shared_name", + builder.getStringAttr(variable_name))}); + for (Operation *user : + make_early_inc_range(variable_v2_op.getResult().getUsers())) { + builder.setInsertionPoint(user); + ReadVariableOp read_variable_op = builder.create<ReadVariableOp>( + user->getLoc(), ArrayRef<Type>{tensor_type}, + ArrayRef<Value>{var_handle_op}, ArrayRef<NamedAttribute>{}); + user->getResult(0).replaceAllUsesWith(read_variable_op.getResult()); + user->erase(); + } + variable_v2_op.erase(); + } +} + +} // namespace + +std::unique_ptr<OperationPass<FuncOp>> +CreateConvertReadonlyReferenceVariablesToResourceVariablesPass() { + return std::make_unique< + ConvertReadonlyReferenceVariablesToResourceVariablesPass>(); +} + +static PassRegistration< + ConvertReadonlyReferenceVariablesToResourceVariablesPass> + pass("readonly-references-to-resources", + "Convert readonly reference variables to resource variables."); + +} // namespace TF + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc index d37dfd14590..21d74d81b20 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc @@ -149,7 +149,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op, } auto walk_res = func_op.walk([&](Operation* op) { if (auto var_handle = llvm::dyn_cast<TF::VarHandleOp>(op)) { - // Record VarHanldeOp's device attribute. + // Record VarHandleOp's device attribute. auto device_attr = var_handle.getAttrOfType<mlir::StringAttr>(kDeviceAttr); if (!device_attr || device_attr.getValue().empty()) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 611c4d2725a..82bc612b1f8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -571,7 +571,7 @@ void AddLoadsStoresOutsideControlFlowOp( } // Lifts loads/stores from while loop's body and cond functions. -LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { +LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { // Remove identity nodes to avoid aliasing. RemoveIdentity(&body.front()); RemoveIdentity(&cond.front()); @@ -985,7 +985,7 @@ LogicalResult HoistForFunctionalControlFlow( lifted_partitioned_call_callees); HoistForFunctionalControlFlow(&cond.front(), module, lifted_partitioned_call_callees); - if (failed(HanldeWhileLoop(while_op, body, cond))) return failure(); + if (failed(HandleWhileLoop(while_op, body, cond))) return failure(); } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) { auto then_branch = llvm::cast<FuncOp>(module.lookupSymbol(if_op.then_branch())); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 5a2cae38062..1e9be76aa66 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -429,7 +429,8 @@ LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, // existing computed values. Attribute ComputeOutputComponent(const ValuePort& value_port, ValueQueryFn values) { - LLVM_DEBUG(value_port.print(llvm::errs() << "\nComputing output for ")); + LLVM_DEBUG(value_port.print(llvm::dbgs() << "Computing output for ") << "\n"); + if (auto known = values(value_port)) return known; auto op = value_port.producer.dyn_cast<Operation*>(); if (!op) return nullptr; @@ -454,6 +455,21 @@ Attribute ComputeOutputComponent(const ValuePort& value_port, ValuePort op_port(op->getOperand(port[1])); return values(op_port); } + + if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) { + if (port.size() == 1) + return ComputeOutputComponent( + ValuePort(graph.GetFetch().fetches()[port[0]]), values); + return nullptr; + } + + if (auto island = dyn_cast<tf_executor::IslandOp>(op)) { + if (port.size() == 1) + return ComputeOutputComponent( + ValuePort(island.GetYield().fetches()[port[0]]), values); + return nullptr; + } + return nullptr; } @@ -462,7 +478,8 @@ Attribute ComputeOutputComponent(const ValuePort& value_port, // TF Graph version, constant values computed, etc.) class ShapeInference { public: - ShapeInference(int64_t graph_version, MLIRContext* context); + ShapeInference(int64_t graph_version, MLIRContext* context, + bool propagate_caller_callee_constants); LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, ValuePortInputs* inputs) { @@ -475,14 +492,19 @@ class ShapeInference { } Attribute ComputeOutputComponent(const ValuePort& value_port) { - return ::mlir::TF::ComputeOutputComponent( + if (auto known_attr = results_[value_port]) return known_attr; + auto attr = ::mlir::TF::ComputeOutputComponent( value_port, [this](const ValuePort& port) { return results_[port]; }); + RecordValue(value_port, attr); + return attr; } // Returns ShapeHandle if the op result could be computed as shape. ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic); void RecordValue(const ValuePort& value_port, Attribute value) { + LLVM_DEBUG(value_port.print(llvm::dbgs() << "\trecording ") + << value << "\n"); results_[value_port] = value; } @@ -520,19 +542,41 @@ class ShapeInference { LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, int64_t max_iteration); + // Propagates any constant operand of call_op to the called function body's + // corresponding argument if the callee has only one use. + // + // TODO(b/154065712): Move this to a more general inter-procedural constant + // folding pass. + void PropagateConstantToCallee(CallOpInterface call_op, + SymbolRefAttr callee_sym, ModuleOp module); + + // Propagates any constant return value of the callee function to the call + // op's corresponding result. + void PropagateConstantFromCallee(CallOpInterface call_op, + SymbolRefAttr callee_sym, ModuleOp module); + + // Tries to compute the result of folding the op. This doesn't actually + // perform constant folding, it is just computes the equivalent constants. + // Returns whether it was able to compute constant values. + LogicalResult TryToFold(Operation* op); + private: // Mapping between ValuePort (which corresponds to an OpResult or smaller, - // e.g., first element of OpResult produded) to an Attribute if the ValuePort + // e.g., first element of OpResult produced) to an Attribute if the ValuePort // corresponds to a constant value. ValuePortResultMap results_; int64_t graph_version_; - MLIRContext* context_; Dialect* tf_dialect_; + + // TODO(b/154065712): Remove propagate_caller_callee_constants once using + // SCCP pass instead. + bool propagate_caller_callee_constants_; }; -ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context) - : graph_version_(graph_version) { - context_ = context; +ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context, + bool propagate_caller_callee_constants) + : graph_version_(graph_version), + propagate_caller_callee_constants_(propagate_caller_callee_constants) { tf_dialect_ = context->getRegisteredDialect<TensorFlowDialect>(); } @@ -581,7 +625,6 @@ ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result, auto ret = ComputeOutputComponent(front); if (!ret) continue; - RecordValue(front, ret); LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = ")); // If worklist is empty, then this is the root query op. @@ -602,6 +645,8 @@ ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result, } bool ShapeInference::InferShapeForSingleOperation(Operation* op) { + LLVM_DEBUG(op->print(llvm::dbgs() << "InferShapeForSingleOperation for "); + llvm::dbgs() << "\n"); assert(tf_dialect_ == op->getDialect()); // The shape function of these ops sometimes does not propagate subtypes // (handle shapes) for resource and variant types. We use a simple passthrough @@ -686,10 +731,14 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { size_t index = it.index(); // If the operand is constant, then convert it to Tensor. - ElementsAttr attr; - if (matchPattern(operand, m_Constant(&attr))) { + ValuePort vp(operand); + Attribute attr = ComputeOutputComponent(vp); + if (!attr && matchPattern(operand, m_Constant(&attr))) + RecordValue(vp, attr); + if (attr) { tensorflow::Tensor* input_tensor = &tensors[index]; - auto status = tensorflow::ConvertToTensor(attr, input_tensor); + auto status = + tensorflow::ConvertToTensor(attr.cast<ElementsAttr>(), input_tensor); if (status.ok()) { input_tensors[index] = input_tensor; } else { @@ -728,10 +777,12 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { !input_tensors[input]; }); if (requires_inputs) { + LLVM_DEBUG(llvm::dbgs() << "\trequired input\n"); std::vector<ShapeHandle> input_tensors_as_shapes; for (int input : llvm::seq<int>(0, c.num_inputs())) { if (c.requested_input_tensor_as_partial_shape(input) && !input_tensors[input]) { + LLVM_DEBUG(llvm::dbgs() << "Requesting " << input << " as shape\n"); auto op_result = op->getOperand(input).dyn_cast<OpResult>(); if (!op_result) continue; // Resize on first valid shape computed. @@ -865,45 +916,62 @@ LogicalResult ShapeInference::PropagateShapeToFunctions( return success(all_succeeded); } -// If the callee has only one use, propagates any constant operand of call_op to -// the called function body's corresponding argument. -// -// TODO(b/154065712): Move this to a more general inter-procedural constant -// folding pass. -void PropagateConstantToCallee(CallOpInterface call_op, - SymbolRefAttr callee_sym, ModuleOp module) { +void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op, + SymbolRefAttr callee_sym, + ModuleOp module) { auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference()); auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion()); int num_uses = std::distance(func_uses->begin(), func_uses->end()); + if (num_uses != 1) return; + OpBuilder builder(&func.front().front()); Operation* op = call_op.getOperation(); - if (num_uses == 1) { - // If this is the only caller, and an operand is a constant, propagate - // the constant inside the function. - for (auto arg : func.getArguments()) { - auto operand = op->getOperand(arg.getArgNumber()).getDefiningOp(); - if (isa_and_nonnull<TF::ConstOp>(operand)) { - arg.replaceAllUsesWith(builder.clone(*operand)->getResult(0)); + // If this is the only caller, and an operand is a constant, propagate + // the constant value inside the function. + for (auto arg : func.getArguments()) { + auto operand = op->getOperand(arg.getArgNumber()); + if (propagate_caller_callee_constants_) { + if (isa_and_nonnull<TF::ConstOp>(operand.getDefiningOp())) { + arg.replaceAllUsesWith( + builder.clone(*operand.getDefiningOp())->getResult(0)); } + continue; } + + auto known_constant = ComputeOutputComponent(ValuePort(operand)); + if (!known_constant) continue; + LLVM_DEBUG(call_op.print(llvm::dbgs() << "Propagate to calee: "); + known_constant.print(llvm::dbgs() << " constant "); + llvm::dbgs() << "\n"); + RecordValue(ValuePort(arg), known_constant); } } -// Propagates any constant return value of the callee function to the call op's -// corresponding result. -void PropagateConstantFromCallee(CallOpInterface call_op, - SymbolRefAttr callee_sym, ModuleOp module) { +void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op, + SymbolRefAttr callee_sym, + ModuleOp module) { auto func = module.lookupSymbol<FuncOp>(callee_sym.getRootReference()); - // If the return value is a constant, replace the call result with a constant. + // If the return value is a constant, use the constant as the value of + // the call return. Operation* op = call_op.getOperation(); OpBuilder builder(op); builder.setInsertionPointAfter(op); for (auto retval : llvm::enumerate(func.front().getTerminator()->getOperands())) { - auto retval_op = retval.value().getDefiningOp(); - if (isa_and_nonnull<TF::ConstOp>(retval_op)) { - op->getResult(retval.index()) - .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0)); + if (propagate_caller_callee_constants_) { + auto retval_op = retval.value().getDefiningOp(); + if (isa_and_nonnull<TF::ConstOp>(retval_op)) { + op->getResult(retval.index()) + .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0)); + } + continue; + } + + ValuePort vp(retval.value()); + if (auto known_constant = ComputeOutputComponent(vp)) { + LLVM_DEBUG(known_constant.print(llvm::dbgs() << "Propagate constant "); + call_op.print(llvm::dbgs() << "from "); llvm::dbgs() << "\n"); + RecordValue(ValuePort(op->getResult(retval.index())), known_constant); } } } @@ -938,10 +1006,71 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions( return success(); } +LogicalResult ShapeInference::TryToFold(Operation* op) { + LLVM_DEBUG(op->print(llvm::dbgs() << "TryToFold "); llvm::dbgs() << "\n"); + // If any output result is known, then the op probably has been computed + // before. + if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))]) + return success(); + + SmallVector<Attribute, 8> constant_operands(op->getNumOperands()); + SmallVector<OpFoldResult, 8> fold_results; + + // Check to see if any operands to the operation is constant and whether + // the operation knows how to constant fold itself. + bool some_unknown = false; + for (int i = 0, e = op->getNumOperands(); i != e; ++i) { + if (!(constant_operands[i] = + ComputeOutputComponent(ValuePort(op->getOperand(i))))) + some_unknown = true; + } + + // Attempt to constant fold the operation. + auto* abstract_op = op->getAbstractOperation(); + LogicalResult folded = failure(); + if (abstract_op) { + folded = abstract_op->foldHook(op, constant_operands, fold_results); + } + // Attempt dialect fallback if op's fold hook failed. + if (failed(folded)) { + Dialect* dialect = op->getDialect(); + if (!dialect) return failure(); + // Only attempt TF dialect fallback if there are no unknown operands. + if (some_unknown && dialect == tf_dialect_) return failure(); + SmallVector<Attribute, 8> constants; + if (failed(dialect->constantFoldHook(op, constant_operands, constants))) + return failure(); + fold_results.assign(constants.begin(), constants.end()); + } + + for (auto result : zip(op->getResults(), fold_results)) { + auto fold_result = std::get<1>(result); + Attribute attr = nullptr; + if ((attr = fold_result.dyn_cast<Attribute>())) { + RecordValue(ValuePort(std::get<0>(result)), attr); + } else { + auto value = fold_result.get<Value>(); + if ((attr = ComputeOutputComponent(ValuePort(value)))) + RecordValue(ValuePort(std::get<0>(result)), attr); + } + + if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) { + if (std::get<0>(result).getType() == eattr.getType()) continue; + + // Inserts a cast back to the original type if any user is not in the + // TF dialect. + Type old_type = std::get<0>(result).getType(); + std::get<0>(result).setType(eattr.getType()); + AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result), tf_dialect_, + old_type); + } + } + + return success(); +} + LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region, int64_t max_iteration) { - // An operation folder that is used to attempt folding before inference._ - OperationFolder folder(context_); bool changed = true; // TODO(aminim): we could have a more efficient traversal by guiding the @@ -955,9 +1084,7 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region, region->walk([&](Operation* op) { if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) { changed |= RefineWithInferTypeOpInterface(infer_ti, tf_dialect_); - // TODO(jpienaar): Debug why we can't just return here. We end up with - // additional constant due to the propagation of constant into attached - // function if we return already. + return; } if (op->getDialect() != tf_dialect_) { @@ -965,8 +1092,9 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region, return; } - // Before attempting inference, just try to fold the operation. - if (succeeded(folder.tryToFold(op))) return; + // Before attempting inference, just try to compute the folded + // value/shape. + if (succeeded(TryToFold(op))) return; // Best-effort shape inference in attached functions. Do not return // failure even if it doesn't get to fixed point. @@ -989,8 +1117,10 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region, LogicalResult InferShapeForFunction(FuncOp func, ArrayRef<ArrayRef<int64_t>> arg_shapes, - int64_t graph_version) { - ShapeInference context(graph_version, func.getContext()); + int64_t graph_version, + bool propagate_caller_callee_constants) { + ShapeInference context(graph_version, func.getContext(), + propagate_caller_callee_constants); if (arg_shapes.empty()) { if (failed(context.InferShapeUntilFixPoint(&func.getBody()))) return failure(); @@ -1014,7 +1144,7 @@ LogicalResult InferShapeForFunction(FuncOp func, ArrayRef<int64_t> shape = arg_shapes[i]; Type element_type; if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) { - if (!input_ty || input_ty.getShape().size() != shape.size()) { + if (input_ty.getRank() != shape.size()) { return failure(); } element_type = input_ty.getElementType(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h index e36d8d56d6d..7486fd77388 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h @@ -30,9 +30,11 @@ namespace TF { // Given a list of refined shapes matching the function arguments of func, runs // shape inference over the function to propagate this updated information. // If arg_shapes are empty, then argument shapes will be left unchanged. -LogicalResult InferShapeForFunction(FuncOp func, - ArrayRef<ArrayRef<int64_t>> arg_shapes, - int64_t graph_version); +// TODO(b/154065712): Remove propagate_caller_callee_constants once using +// SCCP pass instead. +LogicalResult InferShapeForFunction( + FuncOp func, ArrayRef<ArrayRef<int64_t>> arg_shapes, int64_t graph_version, + bool propagate_caller_callee_constants = true); } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc index acdfc0eb039..1a846398412 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc @@ -47,8 +47,15 @@ namespace { // This transformation pass propagate shapes on the TensorFlow graph. // It is a ModulePass in order to be able to change function types. -struct ShapeInference +class ShapeInference : public PassWrapper<ShapeInference, OperationPass<ModuleOp>> { + public: + ShapeInference() = default; + ShapeInference(const ShapeInference& that) { + propagate_caller_callee_constants_ = + that.propagate_caller_callee_constants_; + } + void runOnOperation() override { auto module = getOperation(); auto producer_or = tensorflow::GetTfGraphProducerVersion(module); @@ -58,10 +65,17 @@ struct ShapeInference } int64_t producer = producer_or.ValueOrDie(); for (auto func : module.getOps<FuncOp>()) { - if (failed(InferShapeForFunction(func, /*arg_shapes=*/{}, producer))) + if (failed(InferShapeForFunction(func, /*arg_shapes=*/{}, producer, + propagate_caller_callee_constants_))) return signalPassFailure(); } } + + private: + Option<bool> propagate_caller_callee_constants_{ + *this, "propagate-caller-callee-constants", + llvm::cl::desc("Propagate constants between callers and callees"), + llvm::cl::init(true)}; }; PassRegistration<ShapeInference> pass( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc index b9e214470cd..5a059ce507c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc @@ -14,23 +14,30 @@ limitations under the License. ==============================================================================*/ #include <memory> +#include <tuple> #include <type_traits> -#include "llvm/ADT/Optional.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" namespace mlir { namespace TFTPU { @@ -46,146 +53,184 @@ bool HasOutsideCompilationAttribute(Operation* op) { return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr; } -// Returns whether all operands of `op` are from values inside the -// `input_value_set`. -bool OpContainsOperandsFromSet(Operation* op, - const llvm::SetVector<Value>& input_value_set) { - for (auto operand : op->getOperands()) - if (input_value_set.count(operand) == 0) return false; +Operation* GetOpOfValue(Value value) { + if (auto block_arg = value.dyn_cast<BlockArgument>()) + return block_arg.getOwner()->getParentOp(); - return true; + return value.getDefiningOp(); } -void RecordOutsideCompiledOpsAndUsages( - Operation* op, llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops, - llvm::SetVector<Value>* outside_compiled_op_usages) { - if (HasOutsideCompilationAttribute(op) && - OpContainsOperandsFromSet(op, *outside_compiled_op_usages)) { - outside_compiled_ops->insert(op); - outside_compiled_op_usages->insert(op->getResults().begin(), - op->getResults().end()); +// Returns a set of ops that are outside compiled and can be extracted to before +// the TPU computation. These ops are either connected to the inputs of the TPU +// computation or other ops that can be extracted, and have no dependencies with +// other ops in the TPU computation that cannot be extracted. +llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead( + tf_device::ClusterOp cluster) { + llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops; + + auto cluster_ops = cluster.GetBody().without_terminator(); + for (Operation& cluster_op : cluster_ops) { + if (!HasOutsideCompilationAttribute(&cluster_op)) continue; + // An outside compiled op can be extracted if its operands are not from + // other ops in the cluster that cannot be extracted. + auto result = cluster_op.walk([&](Operation* op) { + for (Value operand : op->getOperands()) { + Operation* operand_op = GetOpOfValue(operand); + if (operand_op->isProperAncestor(cluster) || + cluster_op.isAncestor(operand_op) || + head_outside_compiled_ops.count(operand_op)) + continue; + + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + if (!result.wasInterrupted()) head_outside_compiled_ops.insert(&cluster_op); } + + return head_outside_compiled_ops.takeVector(); } -// Traverses the MLIR graph and returns a set of ops that -// are connected to inputs of TPU computation and outside compiled. -void ExtractOutsideCompiledOpsConnectedToHead( - Value input_value, llvm::SetVector<Value>* values_used_in_host_cluster, - llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops) { - llvm::SmallSetVector<Operation*, 4> parent_outside_compiled_ops_at_head; - for (auto& usage : input_value.getUses()) { - auto head_operation = usage.getOwner(); - RecordOutsideCompiledOpsAndUsages(head_operation, - &parent_outside_compiled_ops_at_head, - values_used_in_host_cluster); +// Parses TPU compilation and execution devices from a TPU cluster and returns +// the host device for the head and tail computations. If the TPU computation is +// replicated, kTPUReplicatedHost is returned instead. +LogicalResult GetHostDeviceForHeadTailComputation( + mlir::TF::RuntimeDevices devices, tf_device::ClusterOp cluster, + std::string* host_device) { + auto replicate = cluster.getParentOfType<tf_device::ReplicateOp>(); + if (replicate) { + *host_device = tensorflow::kTPUReplicatedHost; + return success(); } - // Traverse the graph and find all outside compiled ops connected from - // the `input_value`. - while (!parent_outside_compiled_ops_at_head.empty()) { - llvm::SmallSetVector<Operation*, 4> connected_outside_compiled_ops; - for (auto head_outside_compiled_op : parent_outside_compiled_ops_at_head) { - auto op_results = head_outside_compiled_op->getOpResults(); - for (auto op_result : op_results) { - for (auto& use : op_result.getUses()) { - auto connected_op = use.getOwner(); - RecordOutsideCompiledOpsAndUsages(connected_op, - &connected_outside_compiled_ops, - values_used_in_host_cluster); + auto num_cores_per_replica_attr = + cluster.getAttrOfType<IntegerAttr>(tensorflow::kNumCoresPerReplicaAttr); + if (!num_cores_per_replica_attr) + return cluster.emitOpError( + "cluster op missing `num_cores_per_replica` attribute"); + + if (num_cores_per_replica_attr.getInt() != 1) + return cluster.emitOpError( + "outside compilation is not supported with model parallelism."); + + auto topology_attr = + cluster.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr); + if (!topology_attr) + return cluster.emitOpError("cluster op missing `topology` attribute"); + + auto device_assignment_attr = + cluster.getAttrOfType<mlir::ArrayAttr>(tensorflow::kDeviceAssignmentAttr); + if (!device_assignment_attr) + return cluster.emitOpError(llvm::formatv("requires attribute '{0}'", + tensorflow::kDeviceAssignmentAttr) + .str()); + + auto status_or_device_coodinates = + tensorflow::GetDeviceCoordinates(device_assignment_attr); + + if (!status_or_device_coodinates.ok()) + return cluster.emitError() + << "error in fetching tpu device coordinates: " + << status_or_device_coodinates.status().error_message(); + + // Determine compilation and execution devices. + auto status_or_tpu_device_assignment = + tensorflow::GetTPUCompilationAndExecutionDevices( + devices.device_names(), /*num_replicas=*/1, + /*num_cores_per_replica=*/1, topology_attr.getValue(), + status_or_device_coodinates.ConsumeValueOrDie()); + if (!status_or_tpu_device_assignment.ok()) + return cluster.emitError() + << "error in fetching TPU compilation/execution devices: " + << status_or_tpu_device_assignment.status().error_message(); + auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie(); + + *host_device = tpu_device_assignment.tpu_devices[0][0].host; + return success(); +} + +// Moves head outside compiled ops into its own `tf_device.LaunchOp` +// computation. +tf_device::LaunchOp CreateHeadComputation( + OpBuilder* builder, tf_device::ClusterOp cluster, + llvm::ArrayRef<Operation*> head_outside_compiled_ops, + llvm::StringRef host_device) { + Block* launch_block = new Block; + for (Operation* head_outside_compiled_op : head_outside_compiled_ops) + head_outside_compiled_op->moveBefore(launch_block, launch_block->end()); + + // Find results of ops in head computation that needs to returned. + llvm::SmallVector<Value, 4> launch_results; + llvm::SmallVector<Type, 4> launch_result_types; + for (Operation& head_outside_compiled_op : *launch_block) { + for (Value result : head_outside_compiled_op.getResults()) { + bool has_uses_in_cluster = false; + for (Operation* user : result.getUsers()) { + if (user->getParentRegion() && + cluster.body().isAncestor(user->getParentRegion())) { + has_uses_in_cluster = true; + break; } } - } - - outside_compiled_ops->insert(parent_outside_compiled_ops_at_head.begin(), - parent_outside_compiled_ops_at_head.end()); - std::swap(parent_outside_compiled_ops_at_head, - connected_outside_compiled_ops); - } -} - -// TODO(hongjunchoi): Also handle ops without inputs that are outside -// compiled. -// -// Returns set of ops that are outside compiled and are directly connected -// to inputs to the TPU computation. -llvm::SmallSetVector<Operation*, 4> IdentifyOutsideCompiledOpsAtHead( - tf_device::ClusterOp tpu_cluster) { - llvm::SmallSetVector<Operation*, 4> outside_compiled_at_head_ops; - llvm::SetVector<Value> values_used_in_cluster; - auto& cluster_region = tpu_cluster.body(); - getUsedValuesDefinedAbove(cluster_region, cluster_region, - values_used_in_cluster); - - auto input_value_list = llvm::to_vector<8>(values_used_in_cluster); - for (auto input_value : input_value_list) - ExtractOutsideCompiledOpsConnectedToHead( - input_value, &values_used_in_cluster, &outside_compiled_at_head_ops); - return outside_compiled_at_head_ops; -} - -// Returns output values of extracted outside compiled cluster at head that -// are used by the TPU computation. -llvm::SmallVector<Value, 8> GetHeadExtractedClusterOutputs( - const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) { - llvm::SmallVector<Value, 8> outputs; - outputs.reserve(head_outside_compiled_ops.size()); - - for (auto op : head_outside_compiled_ops) { - for (Operation* user : op->getUsers()) { - if (!head_outside_compiled_ops.count(user)) { - outputs.append(op->result_begin(), op->result_end()); - break; + if (has_uses_in_cluster) { + launch_results.push_back(result); + launch_result_types.push_back(result.getType()); } } } - return outputs; + builder->setInsertionPoint(cluster); + auto launch = builder->create<tf_device::LaunchOp>( + cluster.getLoc(), builder->getStringAttr(host_device), + launch_result_types); + launch.body().push_back(launch_block); + + builder->setInsertionPointToEnd(&launch.GetBody()); + builder->create<tf_device::ReturnOp>(cluster.getLoc(), launch_results); + + for (auto result : llvm::zip(launch_results, launch.getResults())) + replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result), + cluster.body()); + + return launch; } -// Creates new tf_device.launch op with outside compiled ops extracted -// from the head of TPU computation. -llvm::Optional<tf_device::LaunchOp> IsolateHeadExtractedOpsToLaunchOp( - OpBuilder* builder, tf_device::ClusterOp cluster, - const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) { - if (head_outside_compiled_ops.empty()) - return llvm::Optional<tf_device::LaunchOp>(); - - // Create tf_device.launch op to separate all extracted outside compiled ops - // before the tf_device.cluster. - auto output_values = - GetHeadExtractedClusterOutputs(head_outside_compiled_ops); - - llvm::SmallVector<Type, 8> output_return_types; - output_return_types.reserve(output_values.size()); - for (auto output : output_values) - output_return_types.emplace_back(output.getType()); - - builder->setInsertionPoint(cluster); - auto host_launch_op = builder->create<tf_device::LaunchOp>( - cluster.getLoc(), builder->getStringAttr(""), output_return_types); - - // Replace all usages of outside compiled ops that are used in TPU - // computation with the results of the above created launch op. - for (auto output_and_index : llvm::enumerate(output_values)) { - auto output_index = output_and_index.index(); - auto output = output_and_index.value(); - for (auto& use : output.getUses()) { - if (!head_outside_compiled_ops.count(use.getOwner())) - use.set(host_launch_op.getResult(output_index)); +// Removes aliased outputs in cluster from head computation after head +// computation has been extracted. +void RemoveHeadComputationAliasedOutputs(OpBuilder* builder, + tf_device::LaunchOp head_computation, + tf_device::ClusterOp cluster) { + llvm::SmallVector<Value, 4> used_old_cluster_results; + llvm::SmallVector<Value, 4> new_cluster_results; + llvm::SmallVector<Type, 4> new_cluster_result_types; + Operation* cluster_terminator = cluster.GetBody().getTerminator(); + for (auto result : + llvm::zip(cluster_terminator->getOperands(), cluster.getResults())) { + Value cluster_terminator_operand = std::get<0>(result); + if (cluster_terminator_operand.getDefiningOp() == head_computation) { + std::get<1>(result).replaceAllUsesWith(cluster_terminator_operand); + } else { + new_cluster_results.push_back(cluster_terminator_operand); + new_cluster_result_types.push_back(cluster_terminator_operand.getType()); + used_old_cluster_results.push_back(std::get<1>(result)); } } - // Create terminator op for the newly created launch op. - host_launch_op.body().push_back(new Block()); - builder->setInsertionPointToEnd(&host_launch_op.GetBody()); - auto terminator = builder->create<tf_device::ReturnOp>( - host_launch_op.getLoc(), output_values); + if (new_cluster_results.size() == cluster.getNumResults()) return; - // Move all outside compile ops from cluster op to launch op. - for (auto outside_compiled_op : head_outside_compiled_ops) - outside_compiled_op->moveBefore(terminator); + builder->setInsertionPoint(cluster); + auto new_cluster = builder->create<tf_device::ClusterOp>( + cluster.getLoc(), new_cluster_result_types, + /*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs()); + new_cluster.body().takeBody(cluster.body()); + new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results); - return host_launch_op; + for (auto result : + llvm::zip(used_old_cluster_results, new_cluster.getResults())) + std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); + + cluster.erase(); } struct TPUExtractHeadTailOutsideCompilation @@ -202,17 +247,25 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() { return signalPassFailure(); OpBuilder builder(&getContext()); - module.walk([&](tf_device::ClusterOp cluster) { - auto head_outside_compiled_ops = IdentifyOutsideCompiledOpsAtHead(cluster); - IsolateHeadExtractedOpsToLaunchOp(&builder, cluster, - head_outside_compiled_ops); + llvm::SmallVector<tf_device::ClusterOp, 4> clusters; + module.walk( + [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); }); - // TODO(b/156030523): Update device attribute of newly created host launch - // op as well as enclosing Replicate op (if TPU computation is replicated) - // with host device names. + for (tf_device::ClusterOp cluster : clusters) { + llvm::SmallVector<Operation*, 4> head_outside_compiled_ops = + FindOutsideCompiledOpsAtHead(cluster); + if (head_outside_compiled_ops.empty()) continue; + std::string host_device; + if (failed(GetHostDeviceForHeadTailComputation(devices, cluster, + &host_device))) + return signalPassFailure(); - // TODO(b/155115766): Implement tail outside compiled op extraction. - }); + tf_device::LaunchOp head_computation = CreateHeadComputation( + &builder, cluster, head_outside_compiled_ops, host_device); + RemoveHeadComputationAliasedOutputs(&builder, head_computation, cluster); + + // TODO(b/157160906): Implement tail outside compiled op extraction. + } } } // anonymous namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc index 4281b85bd7f..58b3bf8bf7d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc @@ -19,19 +19,24 @@ limitations under the License. #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/FormatVariadic.h" #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/core/platform/logging.h" namespace mlir { namespace TFTPU { namespace { -constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; +constexpr char kAncestorsAttr[] = "ancestors"; constexpr char kDeviceAttr[] = "device"; +constexpr char kKeyAttr[] = "key"; +constexpr char kShapesAttr[] = "shapes"; +constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; // Mapping for `_xla_outside_compilation` attribute to ops of a cluster. using OutsideClusterMap = @@ -116,6 +121,85 @@ void PropagateParallelExecuteReturnToReplicate( parallel_execute_op.execute_outputs()); } +// Extracts all externally provided operands of `cluster_ops`. +llvm::SmallSetVector<Value, 4> GetExternalOperands( + const llvm::SmallVector<Operation*, 8>& cluster_ops) { + llvm::SmallSetVector<Value, 4> external_values; + + for (Operation* op : cluster_ops) { + for (Value v : op->getOperands()) { + Operation* defining_op = v.getDefiningOp(); + if (!defining_op) continue; + bool is_external = llvm::none_of(cluster_ops, [&](Operation* cluster_op) { + return defining_op == cluster_op; + }); + + if (is_external) external_values.insert(v); + } + } + + return external_values; +} + +void MoveOutsideCompiledOps( + tf_device::ClusterOp tpu_cluster, llvm::StringRef outside_cluster_name, + tf_device::LaunchOp host_launch_op, + const llvm::SmallVector<Operation*, 8>& cluster_ops, + const llvm::SmallSetVector<Value, 4>& external_inputs, + const llvm::SmallVector<Value, 4>& external_outputs) { + if (external_inputs.empty() && external_outputs.empty()) { + MoveOutsideClusterOpsToLaunchOp(host_launch_op, cluster_ops); + return; + } + + OpBuilder builder(host_launch_op.GetBody().getTerminator()); + auto result_type = + RankedTensorType::get({}, builder.getType<TF::StringType>()); + + std::string txt_metadata; + std::string txt_module; + // TODO(b/157054714): Use a better abstraction instead of _TPUCompileMlirOp + // and _XlaRecvAtHostOp and _XlaSendFromHostOp. + + // A placeholder _TpuCompileMlirOp is created because it is required input to + // XlaRecvAtHostOp and XlaSendFromHostOp but the _TpuCompileMlirOp has not yet + // been created for the TPU cluster that contains the outside compiled ops. + // This placeholder should be replaced by the TPU cluster _TPUCompileMlirOp in + // a subsequent pass. + auto compile_op = builder.create<TF::_TPUCompileMlirOp>( + tpu_cluster.getLoc(), /*compilation_status=*/result_type, /*program=*/ + llvm::ArrayRef<Type>{result_type}, llvm::ArrayRef<Value>{}, txt_module, + txt_metadata); + + llvm::SmallVector<Type, 4> host_output_types; + for (const auto& external_input : external_inputs) + host_output_types.push_back(external_input.getType()); + + std::string communication_key = + llvm::formatv("host_compute_channel_{0}", outside_cluster_name).str(); + // XlaRecvAtHostOp takes both the program key(dynamic_key) from the + // _TpuCompileMlirOp and the communication_key. + auto recv_at_host = builder.create<TF::_XlaRecvAtHostOp>( + tpu_cluster.getLoc(), host_output_types, + /*dynamic_key=*/compile_op.getResult(1), + builder.getStringAttr(communication_key), + builder.getIntegerAttr(builder.getIntegerType(64), 0)); + + // TODO(b/156006200): Handle host->device outputs. + builder.setInsertionPoint(cluster_ops.front()); + auto host_compute = builder.create<TF::_HostComputeMlirOp>( + tpu_cluster.getLoc(), llvm::ArrayRef<Type>{}, + external_inputs.getArrayRef(), llvm::ArrayRef<NamedAttribute>{}); + host_compute.setAttr(kAncestorsAttr, builder.getArrayAttr({})); + host_compute.setAttr(kShapesAttr, builder.getArrayAttr({})); + host_compute.setAttr(kKeyAttr, builder.getStringAttr(communication_key)); + MoveOutsideClusterOpsToLaunchOp(host_launch_op, cluster_ops); + + for (auto result : llvm::zip(external_inputs, recv_at_host.getResults())) + mlir::replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result), + host_launch_op.body()); +} + // Creates a `parallel_execute` op in place of launch with 'clusters` and // 'launch` as regions. void CreateParallelExecuteFromOutsideClusters( @@ -123,7 +207,7 @@ void CreateParallelExecuteFromOutsideClusters( OpBuilder builder(tpu_cluster); // Create parallel_execute regions. The original TPU cluster computation // is the extra region. - int num_regions = 1 + clusters.size(); + const int num_regions = 1 + clusters.size(); auto parallel_execute_op = builder.create<tf_device::ParallelExecuteOp>( tpu_cluster.getLoc(), num_regions, tpu_cluster.results().getTypes()); @@ -134,9 +218,18 @@ void CreateParallelExecuteFromOutsideClusters( Block& outside_block = parallel_execute_op.GetRegionBlockWithIndex(cluster.index()); builder.setInsertionPointToEnd(&outside_block); - tf_device::LaunchOp launch_op = + tf_device::LaunchOp host_launch_op = CreateLaunchOpForOutsideCluster(&builder, cluster_ops.back()); - MoveOutsideClusterOpsToLaunchOp(launch_op, cluster_ops); + + // Determine if there are any inputs that are provided out of cluster. + auto external_inputs = GetExternalOperands(cluster_ops); + llvm::SmallVector<Value, 4> external_outputs; + // TODO(b/156006200): Compute the external outputs. + + MoveOutsideCompiledOps(tpu_cluster, cluster.value().getFirst(), + host_launch_op, cluster_ops, external_inputs, + external_outputs); + builder.setInsertionPointToEnd(&outside_block); // TODO(b/154363171): Handle returns from OutsideCompiled parallel_execute // regions either through communication with TPU parallel_execute regions @@ -146,12 +239,13 @@ void CreateParallelExecuteFromOutsideClusters( } // Move the launch body to last parallel_execute block. - Block& inside_block = + Block& parallel_execute_tpu_block = parallel_execute_op.GetRegionBlockWithIndex(num_regions - 1); - builder.setInsertionPointToEnd(&inside_block); + builder.setInsertionPointToEnd(¶llel_execute_tpu_block); builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(), tpu_cluster.getResults()); - tpu_cluster.getOperation()->moveBefore(inside_block.getTerminator()); + tpu_cluster.getOperation()->moveBefore( + parallel_execute_tpu_block.getTerminator()); PropagateParallelExecuteReturnToReplicate(parallel_execute_op); // TODO(b/154363171): Handle returns from OutsideCompiled parallel_execute diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index f5e9da915c8..696882cd105 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -64,19 +64,14 @@ static llvm::cl::opt<bool> tpu_compile_metadata_debug( "'tf._TPUCompileMlir' op as a proto debug string")); constexpr char kNumReplicasAttr[] = "num_replicas"; -constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica"; constexpr char kStepMarkerLocationAttr[] = "step_marker_location"; constexpr char kPaddingMapAttr[] = "padding_map"; -constexpr char kTopologyAttr[] = "topology"; -constexpr char kDeviceAssignmentAttr[] = "device_assignment"; constexpr char kDeviceAttr[] = "device"; constexpr char kDevicesAttr[] = "devices"; constexpr char kVersionsAttr[] = "tf.versions"; constexpr char kBadStringArrayElementMsg[] = "bad '{0}' attribute at index {1}, not a string"; -constexpr char kBadIntArrayElementMsg[] = - "bad '{0}' attribute at index {1}, not an int"; constexpr char kBadArrayElementMsg[] = "bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}"; constexpr char kBadArrayAttrLengthMsg[] = @@ -163,32 +158,6 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func, return success(); } -// Extracts device coordinates from a device assignment attribute on an op. -LogicalResult GetDeviceCoordinates( - tf_device::ClusterFuncOp op, - llvm::SmallVectorImpl<int64_t>* device_assignment) { - auto device_assignment_attr = - op.getAttrOfType<ArrayAttr>(kDeviceAssignmentAttr); - if (!device_assignment_attr) - return op.emitOpError(CreateMissingAttributeMsg(kDeviceAssignmentAttr)); - - device_assignment->reserve(device_assignment_attr.size()); - - for (auto device_coordinate_and_idx : - llvm::enumerate(device_assignment_attr)) { - auto device_coordinate = - device_coordinate_and_idx.value().dyn_cast<IntegerAttr>(); - if (!device_coordinate) - return op.emitOpError(llvm::formatv(kBadIntArrayElementMsg, - kDeviceAssignmentAttr, - device_coordinate_and_idx.index())); - - device_assignment->push_back(device_coordinate.getInt()); - } - - return success(); -} - // Populates a TPUCompileMetadataProto with StepMarkerLocation from a // `tf_device::ClusterFuncOp`. LogicalResult SetMetadataProtoStepMarkerLocation( @@ -468,6 +437,18 @@ void AssignDevicesToReplicate( builder->getStrArrayAttr(devices_by_core))); } + // For data parallelism, also add replicated host devices, as these are + // necessary for outside compilation. + if (num_cores_per_replica == 1) { + llvm::SmallVector<StringRef, 8> hosts; + hosts.reserve(num_replicas); + for (int replica = 0; replica < num_replicas; ++replica) + hosts.push_back(tpu_devices[replica][0].host); + + device_attrs.push_back(builder->getNamedAttr( + tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts))); + } + replicate.setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs)); } @@ -661,27 +642,41 @@ LogicalResult Rewrite( : nullptr; if (replicate) num_replicas = replicate.n().getLimitedValue(); - auto num_cores_per_replica_attr = - cluster_func.getAttrOfType<IntegerAttr>(kNumCoresPerReplicaAttr); + auto num_cores_per_replica_attr = cluster_func.getAttrOfType<IntegerAttr>( + tensorflow::kNumCoresPerReplicaAttr); if (!num_cores_per_replica_attr) return cluster_func.emitOpError( - CreateMissingAttributeMsg(kNumCoresPerReplicaAttr)); + CreateMissingAttributeMsg(tensorflow::kNumCoresPerReplicaAttr)); int num_cores_per_replica = num_cores_per_replica_attr.getInt(); - auto topology_attr = cluster_func.getAttrOfType<StringAttr>(kTopologyAttr); + auto topology_attr = + cluster_func.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr); if (!topology_attr) - return cluster_func.emitOpError(CreateMissingAttributeMsg(kTopologyAttr)); + return cluster_func.emitOpError( + CreateMissingAttributeMsg(tensorflow::kTopologyAttr)); - llvm::SmallVector<int64_t, 6> device_assignment; - if (failed(GetDeviceCoordinates(cluster_func, &device_assignment))) - return failure(); + auto device_assignment_attr = cluster_func.getAttrOfType<mlir::ArrayAttr>( + tensorflow::kDeviceAssignmentAttr); + if (!device_assignment_attr) + return cluster_func.emitOpError( + llvm::formatv("requires attribute '{0}'", + tensorflow::kDeviceAssignmentAttr) + .str()); + + auto status_or_device_coodinates = + tensorflow::GetDeviceCoordinates(device_assignment_attr); + if (!status_or_device_coodinates.ok()) + return cluster_func.emitError() + << "error in fetching tpu device coordinates: " + << status_or_device_coodinates.status().error_message(); // Determine compilation and execution devices. auto status_or_tpu_device_assignment = tensorflow::GetTPUCompilationAndExecutionDevices( devices, num_replicas, num_cores_per_replica, - topology_attr.getValue(), device_assignment); + topology_attr.getValue(), + status_or_device_coodinates.ConsumeValueOrDie()); if (!status_or_tpu_device_assignment.ok()) return cluster_func.emitError() << "error in fetching TPU compilation/execution devices: " @@ -706,6 +701,19 @@ LogicalResult Rewrite( std::move(tpu_device_assignment.xla_device_assignment), builder); if (!compile_op) return failure(); + // This replaces _TPUCompileMlir placeholder ops that are required + // by XlaRecvAtHost and XlaSendFromHost ops add in earlier pass. + // TODO(b/157054714): When a better abstraction instead of _TPUCompileMlirOp + // and _XlaRecvAtHostOp and _XlaSendFromHostOp are used, update to a more + // structured lowering. + if (auto parallel_op = llvm::dyn_cast<tf_device::ParallelExecuteOp>( + cluster_func.getParentOp())) { + parallel_op.walk([&](TF::_TPUCompileMlirOp parallel_compile_op) { + parallel_compile_op.replaceAllUsesWith(compile_op); + parallel_compile_op.erase(); + }); + } + // After rewrite, find if there is a TPUCompilationResultOp in the block with // the same _tpu_replicate attribute and replace it with the result of the // compile op. This op is used as a placeholder to hook during graph creation diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index 9e8745918e3..ec4a25c6fdd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -229,7 +229,7 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( mapping.emplace_back(it->second, std::move(while_args)); } // Sort the mapping according to execute operand order. - llvm::sort(mapping); + llvm::sort(mapping, llvm::less_first()); // Populate the `retval_index_for_sharding` field of the argument metadate. for (auto entry : llvm::enumerate(execute.device_var_reads_indices())) { int64_t arg_index = entry.value().cast<IntegerAttr>().getInt(); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 75fcede8fbb..2bf55922d4b 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -782,4 +782,22 @@ StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef( return graphdef; } +stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef( + mlir::FuncOp func, const GraphExportConfig& configs, + FunctionDef* function_def) { + Dialect* tf_dialect = func.getContext()->getRegisteredDialect("tf"); + FunctionDefLibrary flib; + TF_RETURN_IF_ERROR( + Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib)); + for (auto& func_def : flib.function()) { + if (func_def.signature().name() == func.getName()) { + *function_def = func_def; + return Status::OK(); + } + } + return errors::InvalidArgument( + "Function couldn't be found in the FunctionDefLibrary after converting " + "from MLIR"); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h index 2d522f6031e..a5aebd16146 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "llvm/ADT/StringRef.h" +#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project @@ -50,6 +51,12 @@ stream_executor::port::Status ConvertMlirToGraph( stream_executor::port::Status ConvertMlirToGraph( mlir::ModuleOp module, const GraphExportConfig& configs, std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def); + +// Converts an MLIR function and adds it to a FunctionLibraryDefinition. +stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef( + mlir::FuncOp func, const GraphExportConfig& configs, + FunctionDef* function_def); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_GRAPHDEF_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index a613ce1f920..bd63a3b224f 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -978,7 +978,6 @@ StatusOr<mlir::Type> ImporterBase::InferOutputType(const Node& node, int idx, if (dtype == DT_RESOURCE) { const AttrValue* dtype_attr = node.attrs().Find("_handle_dtypes"); const AttrValue* shape_attr = node.attrs().Find("_handle_shapes"); - LOG(INFO) << dtype_attr << " " << shape_attr; if (dtype_attr && shape_attr) { if (dtype_attr->list().type().empty()) { return errors::InvalidArgument( @@ -1169,8 +1168,18 @@ StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue( return builder_.getArrayAttr( llvm::makeArrayRef(attrs.begin(), attrs.end())); } - case AttrValue::kFunc: - return errors::Unknown("kFunc type should be handled separately!"); + case AttrValue::kFunc: { + // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation. + // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue + // will not use this representation. + NamedAttrList attrs; + for (const auto& func_attr : value.func().attr()) { + TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(func_attr.second)); + attrs.push_back(builder_.getNamedAttr(func_attr.first, attr)); + } + auto func_attrs = builder_.getDictionaryAttr(attrs); + return mlir::TF::FuncAttr::get(context_, value.func().name(), func_attrs); + } case AttrValue::VALUE_NOT_SET: return builder_.getUnitAttr(); // kPlaceholder is not implemented. diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc index 06805e633e2..d7b511094d3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc @@ -38,6 +38,7 @@ inline static void Log(BridgeLoggerConfig::PrintCallbackFn print_callback, std::unique_ptr<llvm::raw_ostream> os; std::string filepath; if (CreateFileForDumping(name, &os, &filepath).ok()) print_callback(*os); + VLOG(1) << "Dumped MLIR module to " << filepath; } void BridgeLoggerConfig::printBeforeIfEnabled(mlir::Pass* pass, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index e8ca691f961..fd1ba3b1901 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project @@ -247,9 +248,10 @@ Status RefineShapes(llvm::ArrayRef<TensorShape> arg_shapes, static void RegisterDialects() { static bool init_once = []() { - mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>(); - mlir::registerDialect<mlir::TF::TensorFlowDialect>(); mlir::registerDialect<mlir::StandardOpsDialect>(); + mlir::registerDialect<mlir::TF::TensorFlowDialect>(); + mlir::registerDialect<mlir::shape::ShapeDialect>(); + mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>(); mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>(); return true; }(); @@ -305,6 +307,10 @@ Status ConvertMLIRToXlaComputation( // invocation. tf2xla.addNestedPass<mlir::FuncOp>( mlir::xla_hlo::createLegalizeTFPass(false)); + // In order to export to XLA, we must sink constants to control flow regions, + // since XLA uses functional control flow. + tf2xla.addNestedPass<mlir::FuncOp>( + mlir::xla_hlo::createSinkConstantsToControlFlowPass()); if (VLOG_IS_ON(1)) { // Print the whole module after each pass which requires disabling diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index 06c10c26835..282b7ad3139 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -26,9 +26,9 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Attributes.h" // from @llvm-project #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -39,6 +39,12 @@ limitations under the License. #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { + +const char* const kTPUReplicatedHost = "TPU_REPLICATED_HOST"; +const char* const kNumCoresPerReplicaAttr = "num_cores_per_replica"; +const char* const kTopologyAttr = "topology"; +const char* const kDeviceAssignmentAttr = "device_assignment"; + // Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4 // topology. constexpr int kTPUTopologyRank = 4; @@ -46,8 +52,8 @@ constexpr int kTPUTopologyRank = 4; constexpr char kDeviceTPUSystem[] = "TPU_SYSTEM"; constexpr char kDeviceTPU[] = "TPU"; constexpr char kTPUReplicatedCore[] = "TPU_REPLICATED_CORE"; -constexpr char kTopologyAttr[] = "topology"; -constexpr char kDeviceAssignmentAttr[] = "device_assignment"; +constexpr char kBadIntArrayElementMsg[] = + "bad '{0}' attribute at index {1}, not an int"; using Device = DeviceNameUtils::ParsedName; using Devices = llvm::ArrayRef<DeviceNameUtils::ParsedName>; @@ -417,6 +423,27 @@ GetGeneralTPUExecutionDeviceAssignment( } // anonymous namespace +StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates( + mlir::ArrayAttr device_assignment_attr) { + llvm::SmallVector<int64_t, 8> device_coordinates; + device_coordinates.reserve(device_assignment_attr.size()); + + for (auto device_coordinate_and_idx : + llvm::enumerate(device_assignment_attr)) { + auto device_coordinate = + device_coordinate_and_idx.value().dyn_cast<mlir::IntegerAttr>(); + if (!device_coordinate) + return errors::InvalidArgument( + llvm::formatv(kBadIntArrayElementMsg, kDeviceAssignmentAttr, + device_coordinate_and_idx.index()) + .str()); + + device_coordinates.push_back(device_coordinate.getInt()); + } + + return device_coordinates; +} + StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices( Devices devices, int num_replicas, int num_cores_per_replica, llvm::StringRef topology_attr, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h index 5fdb6b8768b..6bb541ab683 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/device_name_utils.h" @@ -30,6 +31,11 @@ limitations under the License. namespace tensorflow { using stream_executor::port::StatusOr; +extern const char* const kTPUReplicatedHost; +extern const char* const kNumCoresPerReplicaAttr; +extern const char* const kTopologyAttr; +extern const char* const kDeviceAssignmentAttr; + // A TPU device for execution alongside its associated host CPU device. struct TPUDeviceAndHost { TPUDeviceAndHost() {} @@ -67,6 +73,10 @@ struct TPUDeviceAssignment { llvm::Optional<xla::DeviceAssignmentProto> xla_device_assignment; }; +// Extracts device coordinates from a device assignment attribute on an op. +StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates( + mlir::ArrayAttr device_assignment_attr); + // Finds the TPU compilation device and execution devices from `devices` for a // TPU computation subgraph. Compilation device is determined from looking up // all TPU_SYSTEM:0 devices and choosing the CPU device associated to the first diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index 7ac5635a6e4..a70e93a0195 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include <tuple> #include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/tpu/topology.pb.h" @@ -596,5 +598,29 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) { EXPECT_EQ(computation_device_2.replica_device_ids(1), 3); } +TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) { + mlir::MLIRContext context; + mlir::Builder builder(&context); + auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3}); + auto status_or_device_coodinates = + GetDeviceCoordinates(device_assignment_attr); + ASSERT_TRUE(status_or_device_coodinates.ok()); + auto device_coordinates = status_or_device_coodinates.ConsumeValueOrDie(); + EXPECT_EQ(device_coordinates[0], 1); + EXPECT_EQ(device_coordinates[1], 2); + EXPECT_EQ(device_coordinates[2], 3); +} + +TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) { + mlir::MLIRContext context; + mlir::Builder builder(&context); + auto device_assignment_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0}); + auto status_or_device_coodinates = + GetDeviceCoordinates(device_assignment_attr); + ASSERT_TRUE(!status_or_device_coodinates.ok()); + EXPECT_EQ(status_or_device_coodinates.status().error_message(), + "bad 'device_assignment' attribute at index 0, not an int"); +} + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h index 545183a052b..9c98c9b0e19 100644 --- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h +++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h @@ -26,7 +26,7 @@ limitations under the License. #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc index b1c4b1beae1..f47485d0214 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc @@ -231,8 +231,14 @@ StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode( xla::mlir_gpu::LowerLHLOToGPU(module.get(), tile_sizes, unroll_factors, /*collapseParallelLoops=*/false)); TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get())); - TF_RETURN_IF_ERROR( - PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape)); + // TODO(b/156985522): Figure out why we get a segfault when generating Tanh + // with 'same_shape' containing {0, 1}. We would also get the crash if we + // unconditionally call PropagateStaticShapeKnowledgeToKernel while + // 'same_shape' is empty. + if (!same_shape.empty()) { + TF_RETURN_IF_ERROR( + PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape)); + } mlir::OwningModuleRef kernel_module = xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie(); diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 12334e463fa..179a637ec7b 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -193,6 +193,24 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_sink_constants_to_control_flow", + srcs = [ + "transforms/sink_constants_to_control_flow.cc", + ], + deps = [ + ":hlo", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + cc_library( name = "map_xla_to_scalar_op", hdrs = ["transforms/map_xla_to_scalar_op.h"], @@ -873,6 +891,7 @@ cc_library( ":xla_legalize_to_standard", ":xla_lower", ":xla_materialize_broadcasts", + ":xla_sink_constants_to_control_flow", ":xla_test_passes", ], ) diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 5dc610a5670..22a0b038833 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -420,15 +420,37 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction( } case HloOpcode::kConditional: { llvm::SmallVector<Type, 4> rets; - TF_RETURN_IF_ERROR(GetMlirTypes( - {instruction->true_computation()->root_instruction()}, &rets)); + mlir::Type pred_or_index_type = + operands[0].getType().cast<mlir::TensorType>().getElementType(); + // It is a predicated conditional if first argument is a boolean and + // should be mapped to If op. + if (pred_or_index_type.isInteger(1)) { + TF_RETURN_IF_ERROR(GetMlirTypes( + {instruction->true_computation()->root_instruction()}, &rets)); - auto op = func_builder->create<mlir::xla_hlo::ConditionalOp>( - loc, rets, operands, attributes); - TF_RETURN_IF_ERROR(ImportComputation(instruction->true_computation(), - &op.true_branch())); - TF_RETURN_IF_ERROR(ImportComputation(instruction->false_computation(), - &op.false_branch())); + auto op = func_builder->create<mlir::xla_hlo::IfOp>(loc, rets, operands, + attributes); + TF_RETURN_IF_ERROR(ImportComputation(instruction->true_computation(), + &op.true_branch())); + TF_RETURN_IF_ERROR(ImportComputation(instruction->false_computation(), + &op.false_branch())); + return op.getOperation(); + } + + // Otherwise, it is a indexed conditional and should be mapped to Case op. + TF_RETURN_IF_ERROR(GetMlirTypes( + {instruction->branch_computation(0)->root_instruction()}, &rets)); + + int num_branches = instruction->branch_count(); + auto op = func_builder->create<mlir::xla_hlo::CaseOp>( + loc, rets, operands, attributes, num_branches); + for (auto index_and_computation : + llvm::enumerate(instruction->branch_computations())) { + auto index = index_and_computation.index(); + HloComputation* computation = index_and_computation.value(); + TF_RETURN_IF_ERROR( + ImportComputation(computation, &op.branches()[index])); + } return op.getOperation(); } case HloOpcode::kConcatenate: { diff --git a/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc index 5322668aa2e..26db4549a2a 100644 --- a/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc @@ -185,6 +185,16 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes( // BroadcastCompareOp (has custom type inference due to different result type). //===----------------------------------------------------------------------===// +void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result, + Value lhs, Value rhs, + DenseIntElementsAttr broadcast_dimensions, + StringAttr comparison_direction) { + auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(), + builder.getI1Type(), broadcast_dimensions); + build(builder, result, new_type, lhs, rhs, broadcast_dimensions, + comparison_direction); +} + LogicalResult BroadcastCompareOp::inferReturnTypeComponents( MLIRContext* context, Optional<Location> location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, diff --git a/tensorflow/compiler/mlir/xla/ir/chlo_ops.h b/tensorflow/compiler/mlir/xla/ir/chlo_ops.h index 474d4b7d95a..a5337907579 100644 --- a/tensorflow/compiler/mlir/xla/ir/chlo_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.h @@ -25,7 +25,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project namespace mlir { namespace xla_chlo { diff --git a/tensorflow/compiler/mlir/xla/ir/chlo_ops.td b/tensorflow/compiler/mlir/xla/ir/chlo_ops.td index f9672c1a95a..febc99f6b72 100644 --- a/tensorflow/compiler/mlir/xla/ir/chlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.td @@ -360,6 +360,11 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp< HLO_ComparisonDirectionAttr:$comparison_direction ); let results = (outs HLO_PredTensor); + + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " + "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction" + >]; } #endif // CHLO_OPS diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index 68eafb8b33e..d20f1713eba 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -1358,19 +1358,23 @@ static LogicalResult Verify(PadOp op) { //===----------------------------------------------------------------------===// static LogicalResult Verify(ReshapeOp op) { - auto operand_ty = op.operand().getType().cast<TensorType>(); + // If the operand type is dynamically shaped there is nothing to verify. + auto operand_ty = op.operand().getType().cast<RankedTensorType>(); if (!operand_ty || !operand_ty.hasStaticShape()) return success(); - int64_t num_input_elements = operand_ty.getNumElements(); - auto out_ty = op.getType().cast<RankedTensorType>(); - if (out_ty && out_ty.hasStaticShape()) { - int64_t num_output_elements = out_ty.getNumElements(); - if (num_input_elements != num_output_elements) - return op.emitOpError() - << "number of output elements (" << num_output_elements - << ") doesn't match expected number of elements (" - << num_input_elements << ")"; - } + // If the operand type is statically shaped (not required) the number of + // elements must match that of the result type. + auto result_ty = op.getType().cast<RankedTensorType>(); + assert(result_ty && result_ty.hasStaticShape() && + "result type must be statically shaped"); + int64_t num_result_elements = result_ty.getNumElements(); + int64_t num_operand_elements = operand_ty.getNumElements(); + if (num_result_elements != num_operand_elements) + return op.emitOpError() + << "number of output elements (" << num_result_elements + << ") doesn't match expected number of elements (" + << num_operand_elements << ")"; + return success(); } @@ -1392,94 +1396,71 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) { return {}; } +//===----------------------------------------------------------------------===// +// Case Op +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(CaseOp op) { + auto num_branches = op.branches().size(); + if (op.branch_operands().size() != num_branches) + return op.emitOpError() << "expects number of branches " << num_branches + << " to be same as number of branch operands " + << op.branch_operands().size(); + + MutableArrayRef<Region> branches = op.branches(); + OperandRange branch_operands = op.branch_operands(); + for (unsigned i = 0; i < num_branches; ++i) { + mlir::Region& branch_region = branches[i]; + if (branch_region.empty()) + return op.emitOpError() << "cannot have empty regions"; + mlir::Block& entry_block = branch_region.front(); + if (entry_block.getNumArguments() != 1) + return op.emitOpError() + << "expects branch regions to have single argument, but found " + << entry_block.getNumArguments() << " for branch " << i; + auto operand = branch_operands[i]; + if (entry_block.getArgument(0).getType() != operand.getType()) + return op.emitOpError() + << "expects operand " << i + 1 << " to be of type " + << entry_block.getArgument(0).getType() << ", but found " + << operand.getType(); + WalkResult walker = branch_region.walk([&](ReturnOp return_op) { + if (return_op.getOperands().getTypes() != op.getResultTypes()) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (walker.wasInterrupted()) + return op.emitOpError() + << "branch " << i + << " returned values do not match op result types"; + } + return success(); +} + //===----------------------------------------------------------------------===// // BinaryOps //===----------------------------------------------------------------------===// namespace { -// Gets the resulting type from a broadcast between two types. -static Type GetBroadcastType(Builder* builder, Type x, Type y, - Type element_type, - DenseIntElementsAttr broadcast_dimensions) { + +// Updates the element type of a (presumed) tensor type 'x', returning either +// a permuted UnrankedTensorType or RankedTensorType. +static Type UpdateResultElementType(Builder* builder, Type x, + Type element_type) { auto x_ranked = x.dyn_cast<RankedTensorType>(); - auto y_ranked = y.dyn_cast<RankedTensorType>(); - if (!x_ranked || !y_ranked) { + if (!x_ranked) { return UnrankedTensorType::get(element_type); } auto shape_x = x_ranked.getShape(); - auto shape_y = y_ranked.getShape(); - - if (shape_x.size() == shape_y.size()) { - llvm::SmallVector<int64_t, 4> out_shape(shape_x.size()); - for (int i = 0; i < shape_x.size(); i++) { - auto x_val = shape_x[i]; - auto y_val = shape_y[i]; - if (x_val == -1 || y_val == -1) { - out_shape[i] = -1; - } else { - out_shape[i] = std::max(x_val, y_val); - } - } - return RankedTensorType::get(out_shape, element_type); - } - - // Return unranked tensor for invalid broadcast dimensions. - if (!broadcast_dimensions) return UnrankedTensorType::get(element_type); - - auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y; - auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y; - - llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(), - shape_large.end()); - - // Update according to the broadcast dimensions. - for (auto index_pair : llvm::enumerate(broadcast_dimensions.getIntValues())) { - auto old_value = out_shape[index_pair.value().getSExtValue()]; - auto new_value = shape_small[index_pair.index()]; - if (old_value != -1 && (new_value == -1 || new_value > old_value)) { - out_shape[index_pair.value().getSExtValue()] = new_value; - } - } - - return RankedTensorType::get(out_shape, element_type); + return RankedTensorType::get(shape_x, element_type); } } // namespace -#define BINARY_BUILDER(Op) \ - void Op::build(OpBuilder& builder, OperationState& result, Value left, \ - Value right, DenseIntElementsAttr broadcast_dimensions) { \ - auto type = GetBroadcastType(&builder, left.getType().cast<ShapedType>(), \ - right.getType().cast<ShapedType>(), \ - getElementTypeOrSelf(right.getType()), \ - broadcast_dimensions); \ - return Op::build(builder, result, type, left, right, \ - broadcast_dimensions); \ - } - -BINARY_BUILDER(AddOp); -BINARY_BUILDER(AndOp); -BINARY_BUILDER(Atan2Op); -BINARY_BUILDER(DivOp); -BINARY_BUILDER(MaxOp); -BINARY_BUILDER(MinOp); -BINARY_BUILDER(MulOp); -BINARY_BUILDER(OrOp); -BINARY_BUILDER(PowOp); -BINARY_BUILDER(RemOp); -BINARY_BUILDER(ShiftLeftOp); -BINARY_BUILDER(ShiftRightArithmeticOp); -BINARY_BUILDER(ShiftRightLogicalOp); -BINARY_BUILDER(SubOp); -BINARY_BUILDER(XorOp); - -#undef BINARY_BUILDER - template <typename Op, typename ElementType = Type, typename ValType, typename Convert> static Attribute BinaryFolder(Op* op, ArrayRef<Attribute> attrs) { if (!attrs[0] || !attrs[1]) return {}; - if (op->broadcast_dimensions().hasValue()) return {}; DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>(); DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>(); @@ -1889,12 +1870,10 @@ void UnaryEinsumOp::getCanonicalizationPatterns( //===----------------------------------------------------------------------===// void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, - Value rhs, DenseIntElementsAttr broadcast_dimensions, - StringAttr comparison_direction) { - auto new_type = GetBroadcastType(&builder, lhs.getType(), rhs.getType(), - builder.getI1Type(), broadcast_dimensions); - build(builder, result, new_type, lhs, rhs, broadcast_dimensions, - comparison_direction); + Value rhs, StringAttr comparison_direction) { + auto new_type = + UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type()); + build(builder, result, new_type, lhs, rhs, comparison_direction); } #define GET_OP_CLASSES diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h index 25b2f009cc6..9725a0684f6 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h @@ -29,7 +29,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project namespace mlir { class OpBuilder; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index f78ac7624d2..6c54e3fbf90 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -241,15 +241,9 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> : HLO_Op<mnemonic, !listconcat(traits, [InferShapedTypeOpInterface])> { let arguments = (ins HLO_Tensor:$lhs, - HLO_Tensor:$rhs, - OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions + HLO_Tensor:$rhs ); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value left, Value right, " - "DenseIntElementsAttr broadcast_dimensions" - >]; - let extraClassDeclaration = [{ static LogicalResult inferReturnTypeComponents( MLIRContext* context, Optional<Location> location, ValueRange operands, @@ -270,15 +264,15 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> : } def HLO_AddOp : HLO_BinaryElementwiseOp<"add", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_AddOp { + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_AddOp { let hasFolder = 1; } def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_Atan2Op; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op; def HLO_ComplexOp: HLO_Op<"complex", - [NoSideEffect, SameOperandsElementType, SameOperandsAndResultShape]>, + [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ComplexOp { let builders = [OpBuilder< "OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">]; @@ -289,39 +283,39 @@ def HLO_ComplexOp: HLO_Op<"complex", } def HLO_DivOp : HLO_BinaryElementwiseOp<"divide", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp { + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_DivOp { } def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MaxOp { + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MaxOp { } def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MinOp { + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MinOp { } def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MulOp { + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MulOp { let hasFolder = 1; } def HLO_PowOp : HLO_BinaryElementwiseOp<"power", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_PowOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp; def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_RemOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp; def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftLeftOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp; def HLO_ShiftRightArithmeticOp : HLO_BinaryElementwiseOp<"shift_right_arithmetic", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightArithmeticOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftRightArithmeticOp; def HLO_ShiftRightLogicalOp : HLO_BinaryElementwiseOp<"shift_right_logical", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ShiftRightLogicalOp; + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftRightLogicalOp; def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp { + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_SubOp { let hasFolder = 1; } @@ -331,11 +325,11 @@ def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract", // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations class HLO_BinaryLogicalElementwiseOp<string mnemonic> : - HLO_BinaryElementwiseOp<mnemonic, [Commutative, NoSideEffect]> { + HLO_BinaryElementwiseOp< + mnemonic, [Commutative, NoSideEffect, SameOperandsAndResultType]> { let arguments = (ins HLO_PredOrIntTensor:$lhs, - HLO_PredOrIntTensor:$rhs, - OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions + HLO_PredOrIntTensor:$rhs ); } @@ -467,7 +461,7 @@ def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>, // XLA control flow op definitions. //===----------------------------------------------------------------------===// -def HLO_AfterAllOp : HLO_Op<"after_all", []> { +def HLO_AfterAllOp : HLO_Op<"after_all", [NoSideEffect]> { string summary = "AfterAll operator"; @@ -484,8 +478,11 @@ def HLO_AfterAllOp : HLO_Op<"after_all", []> { let results = (outs HLO_Token); } -def HLO_ConditionalOp: HLO_Op<"conditional", []> { - string summary = "Conditional operator"; +// Xla Client API has two separate calls for indexed and predicated conditional, +// although both eventually map to kConditional HLO. IfOp maps to predicated +// conditional use of kConditional HLO. +def HLO_IfOp: HLO_Op<"if", [RecursiveSideEffects]> { + string summary = "If operator"; string description = [{ Returns the result of executing either a true or false function depending on @@ -500,7 +497,8 @@ def HLO_ConditionalOp: HLO_Op<"conditional", []> { HLO_TensorOrTuple:$false_arg ); - let regions = (region AnyRegion:$true_branch, AnyRegion:$false_branch); + let regions = (region AnyRegion:$true_branch, + AnyRegion:$false_branch); let results = (outs HLO_TensorOrTuple); @@ -508,7 +506,27 @@ def HLO_ConditionalOp: HLO_Op<"conditional", []> { let hasCustomHLOConverter = 1; } -def HLO_WhileOp: HLO_Op<"while", [SameOperandsAndResultType]> { +// Xla Client API has two separate calls for indexed and predicated conditional, +// although both eventually map to kConditional HLO. CaseOp maps to indexed +// conditional use of kConditional HLO. +def HLO_CaseOp: HLO_Op<"case", [RecursiveSideEffects]>, + BASE_HLO_CaseOp { + + let arguments = (ins + I32Tensor:$index, + Variadic<HLO_TensorOrTuple>:$branch_operands + ); + + let regions = (region VariadicRegion<AnyRegion>:$branches); + + let results = (outs Variadic<HLO_TensorOrTuple>); + + let hasCustomHLOConverter = 1; +} + + +def HLO_WhileOp: HLO_Op<"while", [RecursiveSideEffects, + SameOperandsAndResultType]> { string summary = "While operator"; string description = [{ @@ -529,7 +547,7 @@ def HLO_WhileOp: HLO_Op<"while", [SameOperandsAndResultType]> { } def HLO_AllReduceOp : HLO_Op<"all_reduce", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_AllReduceOp { + [SameOperandsAndResultType]>, BASE_HLO_AllReduceOp { let arguments = (ins HLO_Tensor:$operand, @@ -556,7 +574,7 @@ def HLO_AllToAllOp : HLO_Op<"all_to_all", } def HLO_ReduceOp: HLO_Op<"reduce", [ - NoSideEffect, + RecursiveSideEffects, SameVariadicOperandSize, SingleBlockImplicitTerminator<"ReturnOp"> ]>, BASE_HLO_ReduceOp { @@ -614,23 +632,18 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { } def HLO_CompareOp: HLO_Op<"compare", - [NoSideEffect, SameOperandsElementType]>, BASE_HLO_CompareOp { + [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]>, + BASE_HLO_CompareOp { let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, - OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions, HLO_ComparisonDirectionAttr:$comparison_direction ); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value left, Value right, " - "DenseIntElementsAttr broadcast_dimensions, " - "StringAttr comparison_direction" - >]; let results = (outs HLO_PredTensor); let builders = [OpBuilder< "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " - "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction" + "StringAttr comparison_direction" >]; } @@ -785,7 +798,7 @@ def HLO_ScalarsToDimensionTensorOp : HLO_Op<"scalars_to_dimension_tensor", compute shape arguments to dynamic operations. }]; - let arguments = (ins Variadic<AnySignlessInteger>:$scalars); + let arguments = (ins Variadic<HLO_DimensionValue>:$scalars); let results = (outs HLO_DimensionTensor); // Cannot be exported to legacy formats. @@ -1042,8 +1055,8 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>, } def HLO_MapOp: HLO_Op<"map", - [NoSideEffect, SameOperandsElementType, SameOperandsAndResultShape, - SingleBlockImplicitTerminator<"ReturnOp">]>, + [RecursiveSideEffects, SameOperandsElementType, + SameOperandsAndResultShape, SingleBlockImplicitTerminator<"ReturnOp">]>, BASE_HLO_MapOp { let arguments = (ins Variadic<HLO_Tensor>:$operands, @@ -1058,13 +1071,13 @@ def HLO_ReshapeOp: HLO_Op<"reshape", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ReshapeOp { let arguments = (ins HLO_Tensor:$operand); - let results = (outs HLO_Tensor); + let results = (outs HLO_StaticShapeTensor); let hasFolder = 1; let hasCustomHLOConverter = 1; } -def HLO_DynamicReshapeOp: HLO_Op<"dynamic_reshape", []> { +def HLO_DynamicReshapeOp: HLO_Op<"dynamic_reshape", [NoSideEffect]> { let summary = "Reshape a tensor to a given, possibly dynamic, shape."; let description = [{ Reshapes `operand` to `output_shape`. @@ -1092,7 +1105,8 @@ def ScatterDimensionNumbers : StructAttr<"ScatterDimensionNumbers", HLO_Dialect, let description = "Structure of dimension information for scatter"; } -def HLO_ScatterOp: HLO_Op<"scatter", [NoSideEffect]>, BASE_HLO_ScatterOp { +def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>, + BASE_HLO_ScatterOp { let arguments = (ins HLO_Tensor:$operand, HLO_Tensor:$scatter_indices, @@ -1121,7 +1135,7 @@ def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, DeclareOpInterfaceMethods<Infe } def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter", - [NoSideEffect]>, BASE_HLO_SelectAndScatterOp { + [RecursiveSideEffects]>, BASE_HLO_SelectAndScatterOp { let arguments = (ins HLO_Tensor:$operand, HLO_Tensor:$source, @@ -1148,7 +1162,7 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>, let results = (outs HLO_Tensor); } -def HLO_SortOp : HLO_Op<"sort", [NoSideEffect]>, BASE_HLO_SortOp { +def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects]>, BASE_HLO_SortOp { let arguments = (ins Variadic<HLO_Tensor>:$operands, DefaultValuedAttr<I64Attr, "-1">:$dimension, @@ -1200,7 +1214,7 @@ def HLO_PadOp: HLO_Op<"pad", let hasCustomHLOConverter = 1; } -def HLO_TraceOp: HLO_Op<"trace", [NoSideEffect]>, BASE_HLO_TraceOp { +def HLO_TraceOp: HLO_Op<"trace", []>, BASE_HLO_TraceOp { let arguments = (ins HLO_Tensor:$operand, StrAttr:$tag @@ -1234,7 +1248,7 @@ def HLO_TriangularSolveOp: HLO_Op<"triangular_solve", } def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ - NoSideEffect, + RecursiveSideEffects, SingleBlockImplicitTerminator<"ReturnOp"> ]>, BASE_HLO_ReduceWindowOp { @@ -1265,7 +1279,7 @@ def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ // TODO(hinsu): Implement custom printer and parser. } -def HLO_ReturnOp : HLO_Op<"return", [Terminator]> { +def HLO_ReturnOp : HLO_Op<"return", [NoSideEffect, Terminator]> { let summary = [{ The `hlo.return` operation terminates a region and returns values. }]; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td index b5de675f13f..bad1bf16ec3 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -62,9 +62,11 @@ def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>; def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>; +def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>; + // Dynamic representation of a shape vector as a tensor. def HLO_DimensionTensor : ShapedContainerType< - [Index, HLO_Pred, HLO_Int], + [HLO_DimensionValue], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>, "a 1D tensor of dimensions">; @@ -553,6 +555,29 @@ class BASE_HLO_XorOp { }]; } +//===----------------------------------------------------------------------===// +// XLA control flow related op definitions. +//===----------------------------------------------------------------------===// + +class BASE_HLO_CaseOp { + string summary = "Switch-Case operator"; + + string description = [{ + Returns the result of executing `branches[index]`. If + `index` is < 0 or >= N, then `branches[N-1] is executed as + the default branch. + + Each branch `branches[b]` must take in a single argument of same type as + `branch_operands[b]` and will be invoked with `branch_operands[b]`. The type + of the returned value of each branch must be the same. + + Note that only one of the branches will be executed depending on the value + of index. + See https://www.tensorflow.org/xla/operation_semantics#conditional. + }]; + +} + //===----------------------------------------------------------------------===// // XLA parallelism related op definitions. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h index 079169e9c5c..03e41f6432c 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_utils.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_utils.h @@ -35,22 +35,33 @@ mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b, mlir::Value y, bool allow_empty = true); -/// Get a constant splat for the given value type. +// Get a constant splat for the given value of type. Requires value to be of +// type static shaped RankedTensorType. +template <typename T> +static ElementsAttr getSplat(Builder* b, RankedTensorType ty, T constant) { + Type element_ty = getElementTypeOrSelf(ty); + + if (element_ty.isSignlessInteger()) + return DenseElementsAttr::get(ty, b->getIntegerAttr(element_ty, constant)); + + if (element_ty.isa<FloatType>()) + return DenseElementsAttr::get(ty, b->getFloatAttr(element_ty, constant)); + + if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) { + auto complex_element_ty = complex_ty.getElementType(); + if (complex_element_ty.isF32()) + return DenseElementsAttr::get(ty, + static_cast<std::complex<float>>(constant)); + if (complex_element_ty.isF64()) + return DenseElementsAttr::get( + ty, static_cast<std::complex<double>>(constant)); + } + llvm_unreachable("unhandled element type"); +} + template <typename T> static ElementsAttr getSplat(Builder* b, Value val, T constant) { - auto valType = val.getType().cast<TensorType>(); - auto valElementType = getElementTypeOrSelf(val.getType()); - - // Handle integer elements. - Attribute elementAttr; - if (valElementType.isSignlessInteger()) - elementAttr = b->getIntegerAttr(valElementType, constant); - else if (valElementType.isa<FloatType>()) - elementAttr = b->getFloatAttr(valElementType, constant); - else - llvm_unreachable("unhandled element type"); - - return DenseElementsAttr::get(valType, elementAttr); + return getSplat(b, val.getType().cast<RankedTensorType>(), constant); } // Returns DenseElementsAttr of rank zero with the given element type and the diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h index 190c5ff832d..1c4ccaae214 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h @@ -27,7 +27,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Interfaces/SideEffects.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project namespace mlir { class OpBuilder; diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td index db75bbd1f67..9a2168d3088 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td @@ -196,6 +196,19 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [ let regions = (region SizedRegion<1>:$body); } +def LHLO_CaseOp: LHLO_Op<"case", [ + SingleBlockImplicitTerminator<"TerminatorOp"> + ]>, BASE_HLO_CaseOp { + + let arguments = (ins + Arg<LHLO_Buffer, "", [MemRead]>:$index, + Arg<Variadic<LHLO_BufferOrTuple>, "", [MemRead]>:$branch_operands, + Arg<LHLO_BufferOrTuple, "", [MemWrite]>:$out + ); + + let regions = (region VariadicRegion<SizedRegion<1>>:$branches); +} + //===----------------------------------------------------------------------===// // XLA tuple op definitions. //===----------------------------------------------------------------------===// @@ -431,6 +444,10 @@ def TerminatorOp : let description = [{ Terminator operation for the LHLO dialect. }]; + let builders = [OpBuilder< + "OpBuilder &b, OperationState &result, ValueRange operands", + [{ build(b, result, llvm::None, operands, llvm::None); }] + >]; } #endif // LHLO_OPS diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 461c357e509..774caab77fb 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -209,7 +209,6 @@ StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs, shape, builder_)); auto op = builder_.create<mlir::xla_hlo::CompareOp>( loc_, ty, GetValue(lhs), GetValue(rhs), - /*broadcast_dimensions=*/mlir::DenseIntElementsAttr(), builder_.getStringAttr(ComparisonDirectionToString(direction))); return MakeXlaOp(op.getResult()); } diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 228a26b5abd..8150d719f3e 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -618,7 +618,7 @@ LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) { return failure(); } -LogicalResult ExportXlaOp(ConditionalOp op, OpLoweringContext ctx) { +LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) { xla::XlaComputation true_branch; xla::XlaComputation false_branch; auto& value_map = *ctx.values; @@ -636,6 +636,33 @@ LogicalResult ExportXlaOp(ConditionalOp op, OpLoweringContext ctx) { return success(); } +LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) { + llvm::DenseMap<mlir::Value, xla::XlaOp>& value_map = *ctx.values; + OperandRange operands = op.branch_operands(); + MutableArrayRef<Region> branches = op.branches(); + llvm::SmallVector<xla::XlaOp, 4> branch_operands(branches.size()); + std::vector<xla::XlaComputation> computations(branches.size()); + std::vector<xla::XlaComputation*> computations_p(branches.size()); + + for (unsigned i = 0; i < branches.size(); ++i) { + branch_operands[i] = value_map[operands[i]]; + computations_p[i] = &computations[i]; + if (failed(ctx.converter->LowerRegionAsComputation(&branches[i], + computations_p[i]))) + return failure(); + } + xla::XlaOp result = + xla::Conditional(value_map[op.index()], computations_p, branch_operands); + if (op.getNumResults() == 1) { + value_map[op.getResult(0)] = result; + } else { + for (auto item : llvm::enumerate(op.getResults())) { + value_map[item.value()] = xla::GetTupleElement(result, item.index()); + } + } + return success(); +} + LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) { return failure(); } diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir index 30255586002..afe3e1b73a5 100644 --- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir @@ -387,8 +387,8 @@ func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor< return %0 : tensor<4x1xf32> } -// CHECK-LABEL: do_not_dce_while -func @do_not_dce_while(%arg0: tensor<i64>) -> tensor<i64> { +// CHECK-LABEL: do_not_dce_while_with_outfeed +func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> { // CHECK: xla_hlo.while %0 = "xla_hlo.while"(%arg0) ( { ^bb0(%arg1: tensor<i64>): @@ -404,3 +404,19 @@ func @do_not_dce_while(%arg0: tensor<i64>) -> tensor<i64> { return %arg0 : tensor<i64> } + +// CHECK-LABEL: dce_while_without_side_effect +func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> { + // CHECK-NOT: xla_hlo.while + %0 = "xla_hlo.while"(%arg0) ( { + ^bb0(%arg1: tensor<i64>): + %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1> + "xla_hlo.return"(%1) : (tensor<i1>) -> () + }, { + ^bb0(%arg1: tensor<i64>): + %1 = "xla_hlo.create_token"() : () -> !xla_hlo.token + "xla_hlo.return"(%arg1) : (tensor<i64>) -> () + }) : (tensor<i64>) -> tensor<i64> + + return %arg0 : tensor<i64> +} diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir index 53296b257ae..68f6d172afc 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir @@ -395,3 +395,15 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) { // CHECK: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () return } + +// ----- + +// CHECK-LABEL: func @dot +func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { +// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], +// CHECK-SAME: %[[RESULT:.*]]: [[TYPE]]) +// CHECK: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %{{.*}}) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () + %dot = "xla_hlo.dot"(%arg0, %arg0) + : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + return %dot : tensor<1024x1024xf32> + } diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir index a856ee5e83c..a27bf2cff79 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir @@ -542,3 +542,16 @@ func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> { // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32): // CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @reverse +func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> { + %result = "xla_hlo.reverse"(%input) { + dimensions = dense<1> : tensor<1xi64> + } : (tensor<2x3xf32>) -> tensor<2x3xf32> + return %result : tensor<2x3xf32> +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir index 83c3f765dc3..83880bc8ce9 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir @@ -35,7 +35,7 @@ func @conditional(%arg0: tensor<f32>) -> tensor<f32> { // CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor<i1> // CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor<f32>), ^bb2(%arg0 : tensor<f32>) - %1 = "xla_hlo.conditional"(%0, %arg0, %arg0) ( { + %1 = "xla_hlo.if"(%0, %arg0, %arg0) ( { ^bb0(%arg1: tensor<f32>): // CHECK: ^bb1([[VAL2:%.+]]: tensor<f32>): @@ -131,7 +131,7 @@ func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, % // CHECK: ^[[EXIT]](%6: tensor<f32>): // CHECK: return %6 : tensor<f32> // CHECK: } - %1 = "xla_hlo.conditional"(%pred, %arg0, %arg1) ( { + %1 = "xla_hlo.if"(%pred, %arg0, %arg1) ( { ^then_entry(%arg2: tensor<f32>): br ^then_succ(%arg2: tensor<f32>) ^then_succ(%0: tensor<f32>): diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir new file mode 100644 index 00000000000..c114b8c50a5 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir @@ -0,0 +1,334 @@ +// Note that binary elementwise tests are run with chlo legalization enabled +// (unlike the rest), since this is the primary use case for such ops and +// verification of shapes and broadcasts is desired. +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" %s | FileCheck %s --dump-input-on-failure + +//===----------------------------------------------------------------------===// +// Binary op legalizations. +// Most of these expand from the same pattern. Full semantics are +// verified for tf.Add and pattern application only for the rest. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @add +func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %[[SUM0:.*]] = xla_hlo.add %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %[[SUM1:.*]] = xla_hlo.add %[[SUM0]], %arg0 : tensor<2xi32> + // CHECK-NEXT: return %[[SUM1]] : tensor<2xi32> + %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %1: tensor<2xi32> +} + +// CHECK-LABEL: func @broadcast_add +// TODO(laurenzo): Change this to a (5 + 2x1) shaped add to make the check +// patterns unambiguous and more interesting (once broadcastable trait is +// fixed upstream). +func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { + // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] + %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + return %0: tensor<1x2xi32> +} + +// CHECK-LABEL: func @broadcast_multi_dim_add +// TODO(laurenzo): Change this to a (4x1x1 + 1x4x4x4) shaped add once upstream +// broadcastable bug is fixed (helps make the CHECK matching unambiguous) +func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { + // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [4, 1, 1] + // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} + // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] + %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + return %0: tensor<4x4x4x4xi32> +} + +// CHECK-LABEL: func @add_dynamic +func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: xla_hlo.add %4, %5 : tensor<?x?xi32> + %0 = "tf.Add"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32> + return %0: tensor<?x?xi32> +} + +// CHECK-LABEL: func @div +func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: return %0 : tensor<2xi32> + %0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @shift_left +func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> + %0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @div_unranked +func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> { + // CHECK: tf.Div + %0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor<?x?xi32>) -> tensor<?x?xi32> + return %0: tensor<?x?xi32> +} + +// CHECK-LABEL: func @maximum +func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> + %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @minimum +func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> + %0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @mul +func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: return %0 : tensor<2xi32> + %0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @real_div +func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> + %0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @sub +func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { + // CHECK-NEXT: %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: return %0 : tensor<2xi32> + %0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// CHECK-LABEL: func @shift_right +func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @shift_right_unsigned +func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> { + // CHECK: tf.RightShift + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8> + return %0 : tensor<4xui8> +} + +// CHECK-LABEL: func @broadcast_shift_right_unsigned +func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> { + // CHECK: tf.RightShift + %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8> + return %0 : tensor<2x4xui8> +} + +// CHECK-LABEL: func @and +func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { + // CHECK-NEXT: xla_hlo.and + %0 = "tf.LogicalAnd"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @and_unranked +func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> { + // CHECK: tf.LogicalAnd + %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @or +func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { + // CHECK-NEXT: xla_hlo.or + %0 = "tf.LogicalOr"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @bitwise_or +func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-NEXT: xla_hlo.or + %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0: tensor<4xi32> +} + +// CHECK-LABEL: func @bitwise_and +func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-NEXT: xla_hlo.and + %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0: tensor<4xi32> +} + +// CHECK-LABEL: func @pow +func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK-NEXT: xla_hlo.power + %0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return %0: tensor<2xf32> +} + +//===----------------------------------------------------------------------===// +// Equality op legalizations. +// tf.Equal and tf.NotEqual expand from the same pattern. Full semantics are +// verified for tf.Equal and pattern application only for tf.NotEqual +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @equal +func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} + %0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @equal_dynamic +func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} + %0 = "tf.Equal"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1> + return %0: tensor<?xi1> +} + +// CHECK-LABEL: func @equal_broadcast +func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} + %0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0: tensor<1x2xi1> +} + +// CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error +func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0: tensor<1x2xi1> +} + +// CHECK-LABEL: func @equal_incompatible_shape_broadcastable +func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1> + return %0: tensor<?xi1> +} + +// CHECK-LABEL: func @equal_incompatible_shape_dynamic +func @equal_incompatible_shape_dynamic(%arg0: tensor<2xi32>, %arg1: tensor<?xi32>) -> tensor<*xi1> { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<?xi32>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @equal_incompatible_shape_both_dynamic +func @equal_incompatible_shape_both_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<*xi1> { + // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<?xi32>, tensor<?xi32>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @equal_unranked +func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> { + // CHECK: "tf.Equal" + %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @notequal +func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} + %0 = "tf.NotEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +//===----------------------------------------------------------------------===// +// Compare op legalizations. +// These expand from the same pattern. Full semantics are checked for +// tf.Greater. Others just check that the pattern applied. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @greater +func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @broadcast_greater +func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + return %0: tensor<1x2xi1> +} + +// CHECK-LABEL: func @greater_dynamic +func @greater_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi1> { + // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 + // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) + // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) + // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi1> + return %0: tensor<?xi1> +} + +// CHECK-LABEL: func @greater_uranked +func @greater_uranked(%arg0: tensor<*xi32>) -> tensor<*xi1> { + // CHECK: "tf.Greater" + %0 = "tf.Greater"(%arg0, %arg0) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> + return %0: tensor<*xi1> +} + +// CHECK-LABEL: func @greater_equal +func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} + %0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @less +func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} + %0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} + +// CHECK-LABEL: func @less_equal +func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} + %0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + return %0: tensor<2xi1> +} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir index 2984ba46993..b3307a8f52a 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir @@ -1,12 +1,12 @@ // RUN: tf-opt -xla-legalize-tf-control-flow %s | FileCheck %s --dump-input-on-failure -// CHECK-LABEL: @conditional -func @conditional(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) +// CHECK-LABEL: @if +func @if(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { // CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> // CHECK: [[VAL1:%.+]] = "xla_hlo.tuple"(%arg0, %arg1) - // CHECK: [[VAL2:%.+]] = "xla_hlo.conditional"([[VAL0]], [[VAL1]], [[VAL1]]) ( { + // CHECK: [[VAL2:%.+]] = "xla_hlo.if"([[VAL0]], [[VAL1]], [[VAL1]]) ( { // CHECK: ^bb0(%arg2: tuple<tensor<f32>, tensor<f32>>): // CHECK: [[VAL4:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 0 : i32} // CHECK: [[VAL5:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 1 : i32} @@ -40,7 +40,52 @@ attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { return %0 : tensor<f32> } -// CHECK-LABEL: @while + +// CHECK-LABEL: func @case +// CHECK-SAME: %[[BRANCH_INDEX:.*]]: tensor<i32>, %[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>) +func @case(%index: tensor<i32>, %arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) { + %0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor]} : (tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>) + // CHECK: %[[TUPLE_INPUT:.*]] = "xla_hlo.tuple"(%[[ARG0]], %[[ARG1]]) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>> + // CHECK: %[[CASE:.*]]:2 = "xla_hlo.case"(%[[BRANCH_INDEX]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]]) ( { + // CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>): + // CHECK: %[[TUPLE_ELEMENT_0:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32> + // CHECK: %[[TUPLE_ELEMENT_1:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32> + // CHECK: %[[CALL_EXP:.*]]:2 = call @exponential(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>) + // CHECK: "xla_hlo.return"(%[[CALL_EXP]]#0, %[[CALL_EXP]]#1) : (tensor<f32>, tensor<f32>) -> () + // CHECK: }, { + // CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>): + // CHECK: %[[TUPLE_ELEMENT_0:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32> + // CHECK: %[[TUPLE_ELEMENT_1:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32> + // CHECK: %[[CALL_LOG:.*]]:2 = call @log(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>) + // CHECK: "xla_hlo.return"(%[[CALL_LOG]]#0, %[[CALL_LOG]]#1) : (tensor<f32>, tensor<f32>) -> () + // CHECK: }, { + // CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>): + // CHECK: %[[TUPLE_ELEMENT_0:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32> + // CHECK: %[[TUPLE_ELEMENT_1:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32> + // CHECK: %[[CALL_FLOOR:.*]]:2 = call @floor(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>) + // CHECK: "xla_hlo.return"(%[[CALL_FLOOR]]#0, %[[CALL_FLOOR]]#1) : (tensor<f32>, tensor<f32>) -> () + // CHECK: }) : (tensor<i32>, tuple<tensor<f32>, tensor<f32>>, tuple<tensor<f32>, tensor<f32>>, tuple<tensor<f32>, tensor<f32>>) -> (tensor<f32>, tensor<f32>) + return %0#0, %0#1 : tensor<f32>, tensor<f32> +// CHECK: return %[[CASE]]#0, %[[CASE]]#1 : tensor<f32>, tensor<f32> +} + +func @exponential(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) { + %0 = "xla_hlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32> + return %0, %arg1 : tensor<f32>, tensor<f32> +} + +func @log(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) { + %0 = "xla_hlo.log"(%arg0) : (tensor<f32>) -> tensor<f32> + return %0, %arg1 : tensor<f32>, tensor<f32> +} + +func @floor(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) { + %0 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32> + return %0, %arg1 : tensor<f32>, tensor<f32> +} + + +// CHECK-LABEL: func @while func @while(%arg0: tensor<f32> {tf_saved_model.index_path = [0]}) -> (tensor<i32> {tf_saved_model.index_path = []}) attributes {tf._input_shapes = ["tfshape$"]} { // CHECK: [[VAL0:%.+]] = xla_hlo.constant dense<0> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 450910b2e4d..363e60eb341 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -1,4 +1,11 @@ -// RUN: tf-opt -xla-legalize-tf=allow-partial-conversion %s | FileCheck %s --dump-input-on-failure +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" %s | FileCheck %s --dump-input-on-failure +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -verify-diagnostics %s +// This test runs twice: +// 1. Through FileCheck with chlo legalization disabled since verifying +// that the chlo ops emit produces more useful tests. +// 2. With chlo legalization enabled, verifying diagnostics to pick up any +// issues with the full lowering (can catch some broadcasting corner +// cases which emit with a warning). //===----------------------------------------------------------------------===// // BatchNorm op legalizations. @@ -47,7 +54,7 @@ func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32> // CHECK: "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK: %[[VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK: xla_hlo.constant - // CHECK: "xla_hlo.multiply"(%[[VAR]], {{.*}}) : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> + // CHECK: xla_chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> return %0#0 : tensor<8x8x8x8xf32> } @@ -68,18 +75,18 @@ func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, // CHECK-DAG: %[[BATCH_VAR:.*]] = "xla_hlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} // CHECK: %[[FACTOR:.*]] = xla_hlo.constant dense<1.00195694> - // CHECK: %[[CORRECTED_VAR:.*]] = "xla_hlo.multiply"(%[[BATCH_VAR]], %[[FACTOR]]) + // CHECK: %[[CORRECTED_VAR:.*]] = xla_chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]] // CHECK-DAG: %[[ALPHA:.*]] = xla_hlo.constant dense<0.199999988> // CHECK-DAG: %[[BETA:.*]] = xla_hlo.constant dense<8.000000e-01> - // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = "xla_hlo.multiply"(%[[ALPHA]], %arg3) - // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = "xla_hlo.multiply"(%[[BETA]], %[[BATCH_MEAN]]) - // CHECK: %[[NEW_BATCH_MEAN:.*]] = xla_hlo.add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] + // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg3 + // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]] + // CHECK: %[[NEW_BATCH_MEAN:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] - // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = "xla_hlo.multiply"(%[[ALPHA]], %arg4) - // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = "xla_hlo.multiply"(%[[BETA]], %[[CORRECTED_VAR]]) - // CHECK: %[[NEW_BATCH_VAR:.*]] = xla_hlo.add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] + // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg4 + // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]] + // CHECK: %[[NEW_BATCH_VAR:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[BATCH_MEAN]], %[[BATCH_VAR]] return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> @@ -127,11 +134,12 @@ func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32> - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32> @@ -142,10 +150,10 @@ func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -185,11 +193,12 @@ func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32> - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32> @@ -200,10 +209,11 @@ func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -270,11 +280,12 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32> - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32> @@ -285,10 +296,11 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 1, 2]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -355,11 +367,12 @@ func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: %[[act:.*]] = "xla_hlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32> - // CHECK-NEXT: %[[add:.*]] = "xla_hlo.add"(%arg4, %[[eps]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "xla_hlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[sub:.*]] = "xla_hlo.subtract"(%[[act]], %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8x8x8x8xf32> + // CHECK: %[[bcast_arg3:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[sub:.*]] = xla_hlo.subtract %[[act]], %[[bcast_arg3]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul:.*]] = xla_hlo.multiply %[[grad]], %[[sub]] : tensor<8x8x8x8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64> // CHECK-NEXT: %[[cmul:.*]] = "xla_hlo.convert"(%[[mul]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[init:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32> @@ -370,10 +383,11 @@ func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<8x8x8x8xf32>, tensor<f32>) -> tensor<8xf32> // CHECK-NEXT: %[[scr2:.*]] = "xla_hlo.convert"(%[[red1]]) : (tensor<8xf32>) -> tensor<8xf32> - // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> - // CHECK-NEXT: %[[mul3:.*]] = "xla_hlo.multiply"(%[[grad]], %[[mul2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul2:.*]] = xla_hlo.multiply %arg2, %[[scr1]] : tensor<8xf32> + // CHECK: %[[bcast_mul2:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[mul2]], {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[mul3:.*]] = xla_hlo.multiply %[[grad]], %[[bcast_mul2]] : tensor<8x8x8x8xf32> - // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<8xf32> + // CHECK-NEXT: %[[scale_backprop:.*]] = xla_hlo.multiply %[[scr1]], %[[scr2]] : tensor<8xf32> // CHECK-NEXT: xla_hlo.constant dense<[0, 2, 3]> : tensor<3xi64> // CHECK-NEXT: %[[cgrad:.*]] = "xla_hlo.convert"(%[[grad]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> @@ -405,207 +419,41 @@ func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tens // CHECK-LABEL: func @biasAdd_NHWC func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - // CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]]) + // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } // CHECK-LABEL: func @biasAdd_NCHW func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - // CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]]) + // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } // CHECK-LABEL: func @biasAdd_dynamic func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> { - // CHECK: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 + // CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]]) + // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) + // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32> return %0 : tensor<?x?x?x?xi32> } //===----------------------------------------------------------------------===// -// Binary op legalizations. -// Most of these expand from the same pattern. Full semantics are -// verified for tf.Add and pattern application only for the rest. +// DiagPart //===----------------------------------------------------------------------===// -// CHECK-LABEL: func @add -func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %[[SUM0:.*]] = xla_hlo.add %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: %[[SUM1:.*]] = xla_hlo.add %[[SUM0]], %arg0 : tensor<2xi32> - // CHECK-NEXT: return %[[SUM1]] : tensor<2xi32> - %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %1: tensor<2xi32> -} - -// CHECK-LABEL: func @broadcast_add -// TODO(laurenzo): Change this to a (5 + 2x1) shaped add to make the check -// patterns unambiguous and more interesting (once broadcastable trait is -// fixed upstream). -func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [1] - // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [1, 2] - // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] - %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> - return %0: tensor<1x2xi32> -} - -// CHECK-LABEL: func @broadcast_multi_dim_add -// TODO(laurenzo): Change this to a (4x1x1 + 1x4x4x4) shaped add once upstream -// broadcastable bug is fixed (helps make the CHECK matching unambiguous) -func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [4, 1, 1] - // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] - // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} - // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] - %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> - return %0: tensor<4x4x4x4xi32> -} - -// CHECK-LABEL: func @add_dynamic -func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> { - // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 - // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK: xla_hlo.add %4, %5 : tensor<?x?xi32> - %0 = "tf.Add"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32> - return %0: tensor<?x?xi32> -} - -// CHECK-LABEL: func @div -func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: return %0 : tensor<2xi32> - %0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @shift_left -func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> - %0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0 : tensor<4xi32> -} - -// CHECK-LABEL: func @div_unranked -func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> { - // CHECK: tf.Div - %0 = "tf.Div"(%arg0, %arg1) : (tensor<*xi32>, tensor<?x?xi32>) -> tensor<?x?xi32> - return %0: tensor<?x?xi32> -} - -// CHECK-LABEL: func @maximum -func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> - %0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @minimum -func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> - %0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: func @mul -func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: return %0 : tensor<2xi32> - %0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @real_div -func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> - %0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @sub -func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: return %0 : tensor<2xi32> - %0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @shift_right -func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0 : tensor<4xi32> -} - -// CHECK-LABEL: func @shift_right_unsigned -func @shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<4xui8>) -> tensor<4xui8> { - // CHECK: tf.RightShift - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<4xui8>) -> tensor<4xui8> - return %0 : tensor<4xui8> -} - -// CHECK-LABEL: func @broadcast_shift_right_unsigned -func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8>) -> tensor<2x4xui8> { - // CHECK: tf.RightShift - %0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xui8>, tensor<2x4xui8>) -> tensor<2x4xui8> - return %0 : tensor<2x4xui8> -} - -// CHECK-LABEL: func @and -func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { - // CHECK-NEXT: xla_hlo.and - %0 = "tf.LogicalAnd"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @and_unranked -func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> { - // CHECK: tf.LogicalAnd - %0 = "tf.LogicalAnd"(%arg0, %arg1) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @or -func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.LogicalOr"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @bitwise_or -func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK-NEXT: xla_hlo.or - %0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0: tensor<4xi32> -} - -// CHECK-LABEL: func @bitwise_and -func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { - // CHECK-NEXT: xla_hlo.and - %0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - return %0: tensor<4xi32> -} - -// CHECK-LABEL: func @pow -func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: xla_hlo.power - %0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - return %0: tensor<2xf32> -} - // CHECK-LABEL: func @diag_part // CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32> func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { @@ -625,6 +473,10 @@ func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { return %0: tensor<4x3xf32> } +//===----------------------------------------------------------------------===// +// Einsum. +//===----------------------------------------------------------------------===// + // CHECK-LABEL: func @einsum func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { // CHECK: xla_hlo.einsum @@ -639,22 +491,26 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { return %0: tensor<2x2xf32> } +//===----------------------------------------------------------------------===// +// FloorDiv and FloorMod. +//===----------------------------------------------------------------------===// + // CHECK-LABEL: func @floordiv_broadcast_i32 func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"(%arg0, [[ZEROS1]]) {comparison_direction = "LT"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} // CHECK-DAG: [[ZEROS2:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZEROS2]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[CMP1]], [[CMP2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"(%arg0) // CHECK-DAG: [[ABS2:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[ZEROS3:%.+]] = xla_hlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ABS2]], [[ZEROS3]] - // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"([[ABS1]], [[SUB]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ONES:%.+]] = xla_hlo.constant dense<1> + // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.negate"([[ADD]]) // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = "xla_hlo.divide"([[NEG]], [[ABS3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> @@ -664,19 +520,19 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te // CHECK-LABEL: func @floordiv_reverse_broadcast_i32 func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"(%arg0, [[ZEROS1]]) {comparison_direction = "LT"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} // CHECK-DAG: [[ZEROS2:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZEROS2]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[CMP1]], [[CMP2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"(%arg0) // CHECK-DAG: [[ABS2:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[ZEROS3:%.+]] = xla_hlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ABS2]], [[ZEROS3]] - // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"([[ABS1]], [[SUB]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ONES:%.+]] = xla_hlo.constant dense<1> + // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.negate"([[ADD]]) // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = xla_hlo.divide [[NEG]], [[ABS3]] + // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> @@ -685,7 +541,7 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 // CHECK-LABEL: func @floordiv_f32 func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: %[[DIV:.*]] = xla_hlo.divide %arg0, %arg0 + // CHECK-NEXT: %[[DIV:.*]] = xla_chlo.broadcast_divide %arg0, %arg0 // CHECK-NEXT: %[[FLOOR:.*]] = "xla_hlo.floor"(%[[DIV]]) // CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32> %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> @@ -696,7 +552,7 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { // CHECK-NEXT: xla_hlo.convert // CHECK-NEXT: xla_hlo.convert - // CHECK-NEXT: xla_hlo.divide + // CHECK-NEXT: xla_chlo.broadcast_divide // CHECK-NEXT: xla_hlo.floor // CHECK-NEXT: xla_hlo.convert // CHECK-NEXT: return @@ -706,7 +562,7 @@ func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { // CHECK-LABEL: func @floordiv_f16_broadcast func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - // CHECK-NEXT: xla_hlo.divide + // CHECK-NEXT: xla_chlo.broadcast_divide // CHECK-NEXT: xla_hlo.floor // CHECK-NEXT: return %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> @@ -715,7 +571,22 @@ func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> te // CHECK-LABEL: func @floordiv_dynamic func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> { - // CHECK: tf.FloorDiv + // CHECK-DAG: [[ZEROS1:%.+]] = xla_hlo.constant dense<0> + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} + // CHECK-DAG: [[ZEROS2:%.+]] = xla_hlo.constant dense<0> + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ABS1:%.+]] = "xla_hlo.abs"(%arg0) + // CHECK-DAG: [[ABS2:%.+]] = "xla_hlo.abs"(%arg1) + // CHECK-DAG: [[ONES:%.+]] = xla_hlo.constant dense<1> + // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[NEG:%.+]] = "xla_hlo.negate"([[ADD]]) + // CHECK-DAG: [[ABS3:%.+]] = "xla_hlo.abs"(%arg1) + // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) + // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32> return %0: tensor<?x?xi32> } @@ -729,15 +600,15 @@ func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x // CHECK-LABEL: func @floormod_broadcast_numerator func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[REM:%.+]] = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ZL:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"([[REM]], [[ZL]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} // CHECK-DAG: [[ZR:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZR:%.+]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[REM:%.+]], [[ZR]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP4:%.+]] = "xla_hlo.compare"([[CMP2]], [[CMP3]]) {comparison_direction = "NE"} - // CHECK-DAG: [[AND:%.+]] = xla_hlo.and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = xla_hlo.add %arg1, [[REM]] + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[AND]], [[ADD]], [[REM]]) // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> @@ -746,15 +617,15 @@ func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) // CHECK-LABEL: func @floormod_broadcast_denominator func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[REM:%.+]] = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ZL:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = "xla_hlo.compare"([[REM]], [[ZL]]) {comparison_direction = "NE"} + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"} // CHECK-DAG: [[ZR:%.+]] = xla_hlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = "xla_hlo.compare"(%arg1, [[ZR:%.+]]) {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = "xla_hlo.compare"([[REM:%.+]], [[ZR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} - // CHECK-DAG: [[CMP4:%.+]] = "xla_hlo.compare"([[CMP2]], [[CMP3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} - // CHECK-DAG: [[AND:%.+]] = xla_hlo.and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = "xla_hlo.add"(%arg1, [[REM]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[AND]], [[ADD]], [[REM]]) // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> @@ -763,7 +634,17 @@ func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32 // CHECK-LABEL: func @floormod_dynamic func @floormod_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> { - // CHECK: tf.FloorMod + // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[ZL:%.+]] = xla_hlo.constant dense<0> + // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"} + // CHECK-DAG: [[ZR:%.+]] = xla_hlo.constant dense<0> + // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SELECT:%.+]] = "xla_hlo.select"([[AND]], [[ADD]], [[REM]]) + // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32> return %0: tensor<?x?xi32> } @@ -775,6 +656,10 @@ func @floormod_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x return %0: tensor<*xi32> } +//===----------------------------------------------------------------------===// +// BroadcastTo. +//===----------------------------------------------------------------------===// + // CHECK-LABEL: func @broadcast_to func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32> @@ -787,155 +672,6 @@ func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { return %0 : tensor<16x16x16x16xf32> } -//===----------------------------------------------------------------------===// -// Equality op legalizations. -// tf.Equal and tf.NotEqual expand from the same pattern. Full semantics are -// verified for tf.Equal and pattern application only for tf.NotEqual -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @equal -func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} - %0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @equal_dynamic -func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> { - // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1] - // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} - %0 = "tf.Equal"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1> - return %0: tensor<?xi1> -} - -// CHECK-LABEL: func @equal_broadcast -func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] - // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] - // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} - %0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @equal_broadcast_no_incompatible_shapes_error -func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @equal_incompatible_shape_broadcastable -func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1> - return %0: tensor<?xi1> -} - -// CHECK-LABEL: func @equal_incompatible_shape_dynamic -func @equal_incompatible_shape_dynamic(%arg0: tensor<2xi32>, %arg1: tensor<?xi32>) -> tensor<*xi1> { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<2xi32>, tensor<?xi32>) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @equal_incompatible_shape_both_dynamic -func @equal_incompatible_shape_both_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<*xi1> { - // CHECK-NEXT: "tf.Equal"(%arg0, %arg1) {incompatible_shape_error = false} - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<?xi32>, tensor<?xi32>) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @equal_unranked -func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> { - // CHECK: "tf.Equal" - %0 = "tf.Equal"(%arg0, %arg1) { incompatible_shape_error = false } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @notequal -func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} - %0 = "tf.NotEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -//===----------------------------------------------------------------------===// -// Compare op legalizations. -// These expand from the same pattern. Full semantics are checked for -// tf.Greater. Others just check that the pattern applied. -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @greater -func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @broadcast_greater -func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] - // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] - // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> - return %0: tensor<1x2xi1> -} - -// CHECK-LABEL: func @greater_dynamic -func @greater_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi1> { - // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 - // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) - // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) - // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi1> - return %0: tensor<?xi1> -} - -// CHECK-LABEL: func @greater_uranked -func @greater_uranked(%arg0: tensor<*xi32>) -> tensor<*xi1> { - // CHECK: "tf.Greater" - %0 = "tf.Greater"(%arg0, %arg0) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1> - return %0: tensor<*xi1> -} - -// CHECK-LABEL: func @greater_equal -func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} - %0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @less -func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} - %0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - -// CHECK-LABEL: func @less_equal -func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} - %0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - return %0: tensor<2xi1> -} - - //===----------------------------------------------------------------------===// // Complex op legalizations. //===----------------------------------------------------------------------===// @@ -1224,12 +960,12 @@ func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: ten // CHECK: %[[X:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xbf16> // CHECK: %[[Y:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xbf16> // CHECK: %[[OFFSET:.*]] = xla_hlo.subtract %[[X]], %[[Y]] : tensor<64x64xbf16> - // CHECK: %[[G:.*]] = "xla_hlo.compare"(%[[F]], %[[OFFSET]]) {comparison_direction = "LE"} : (tensor<bf16>, tensor<64x64xbf16>) -> tensor<*xi1> + // CHECK: %[[G:.*]] = xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<bf16>, tensor<64x64xbf16>) -> tensor<64x64xi1> // CHECK: %[[H:.*]] = "xla_hlo.convert"(%[[D]]) : (tensor<i64>) -> tensor<bf16> - // CHECK: %[[I:.*]] = "xla_hlo.compare"(%[[OFFSET]], %[[H]]) {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor<bf16>) -> tensor<*xi1> + // CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor<bf16>) -> tensor<64x64xi1> - // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] : tensor<*xi1> + // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] : tensor<64x64xi1> // CHECK: %[[ZERO2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<64x64xbf16> // CHECK: %[[R:.*]] = "xla_hlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) @@ -1245,11 +981,11 @@ func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor<i64>, %arg2 // CHECK: %[[Y:.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xbf16> // CHECK: %[[OFFSET:.*]] = xla_hlo.subtract %[[X]], %[[Y]] : tensor<24x48xbf16> - // CHECK: %[[G:.*]] = "xla_hlo.compare"(%[[F]], %[[OFFSET]]) {comparison_direction = "LE"} : (tensor<bf16>, tensor<24x48xbf16>) -> tensor<*xi1> + // CHECK: %[[G:.*]] = xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<bf16>, tensor<24x48xbf16>) -> tensor<24x48xi1> // CHECK: %[[H:.*]] = "xla_hlo.convert"(%[[D]]) : (tensor<i64>) -> tensor<bf16> - // CHECK: %[[I:.*]] = "xla_hlo.compare"(%[[OFFSET]], %[[H]]) {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor<bf16>) -> tensor<*xi1> - // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : tensor<*xi1> + // CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor<bf16>) -> tensor<24x48xi1> + // CHECK: %[[J:.*]] = xla_hlo.and %[[G]], %[[I]] : tensor<24x48xi1> // CHECK: %[[ZERO2:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16> // CHECK: %[[R:.*]] = "xla_hlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]]) @@ -1396,7 +1132,8 @@ func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: // CHECK-LABEL:one_hot func @one_hot(%indices: tensor<3xi32>, %on_value: tensor<f32>, %off_value: tensor<f32>) -> tensor<3x5xf32> { // CHECK: %[[IOTA:.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x5xi32> - // CHECK: %[[COMPARE:.*]] = "xla_hlo.compare"(%arg0, %[[IOTA]]) {broadcast_dimensions = dense<0> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> + // CHECK: %[[BCAST_ARG0:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<3x5xi32> + // CHECK: %[[COMPARE:.*]] = "xla_hlo.compare"(%[[BCAST_ARG0]], %[[IOTA]]) {comparison_direction = "EQ"} : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> // CHECK: %[[ON_VALUE:.*]] = "xla_hlo.broadcast"(%arg1) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor<f32>) -> tensor<3x5xf32> // CHECK: %[[OFF_VALUE:.*]] = "xla_hlo.broadcast"(%arg2) {broadcast_sizes = dense<[3, 5]> : tensor<2xi64>} : (tensor<f32>) -> tensor<3x5xf32> // CHECK: %[[RESULT:.*]] = "xla_hlo.select"(%[[COMPARE]], %[[ON_VALUE]], %[[OFF_VALUE]]) : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32> @@ -1561,7 +1298,7 @@ func @stateful_pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (te // CHECK-LABEL: func @relu func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor<i32> - // CHECK: "xla_hlo.maximum"(%[[ZERO]], %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32> %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -1569,7 +1306,7 @@ func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @relu_unranked func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> { // CHECK: %[[ZERO:.*]] = xla_hlo.constant dense<0> : tensor<i32> - // CHECK: "xla_hlo.maximum"(%[[ZERO]], %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32> + // CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32> %0 = "tf.Relu"(%arg0) : (tensor<?xi32>) -> tensor<?xi32> return %0: tensor<?xi32> } @@ -1597,8 +1334,8 @@ func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> { func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tensor<4x8xf32> { // CHECK-DAG: %[[ZERO_SCALAR:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32> // CHECK-DAG: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32> - // CHECK-DAG: %[[PRED:.*]] = "xla_hlo.compare"(%[[FEATURES]], %[[ZERO_SCALAR]]) {comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<*xi1> - // CHECK-DAG: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<*xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> + // CHECK-DAG: %[[PRED:.*]] = xla_chlo.broadcast_compare %[[FEATURES]], %[[ZERO_SCALAR]] {comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1> + // CHECK-DAG: %[[RESULT:.*]] = "xla_hlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-DAG: return %[[RESULT]] : tensor<4x8xf32> %2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32> return %2 : tensor<4x8xf32> @@ -1608,27 +1345,6 @@ func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tens // Select op legalizations. //===----------------------------------------------------------------------===// -// CHECK-LABEL: func @select -func @select(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2) - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> -} - -// CHECK-LABEL: func @select_float -func @select_float(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2) - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - return %0: tensor<2xf32> -} - -// CHECK-LABEL: func @select_multidimensional -func @select_multidimensional(%arg0: tensor<3x2xi1>, %arg1: tensor<3x2xi32>, %arg2: tensor<3x2xi32>) -> tensor<3x2xi32> { - // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2) - %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> - return %0: tensor<3x2xi32> -} - // CHECK-LABEL: func @selectv2 func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2) @@ -1667,6 +1383,14 @@ func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %ar return %0: tensor<2x8x8xi32> } +// CHECK-LABEL: func @selectv2_broadcast_tensor_pred +func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> { + // CHECK: %[[BROADCAST:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi1>) -> tensor<2x3xi1> + // CHECK: "xla_hlo.select"(%[[BROADCAST]], %arg1, %arg2) + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16> + return %0: tensor<2x3xf16> +} + // CHECK-LABEL: func @selectv2_broadcast_all func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> { // CHECK-DAG: %[[BROADCAST_0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1> @@ -1708,7 +1432,10 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32> // CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.subtract"(%[[ARG0]], %[[CASTED_MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] + // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex> + // CHECK: %[[BCAST_MAX:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_MAX]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[SHIFTED_INP:.*]] = xla_hlo.subtract %[[ARG0]], %[[BCAST_MAX]] // CHECK: %[[EXP:.*]] = "xla_hlo.exponential"(%[[SHIFTED_INP]]) // Verify reduce op for summation and its body. @@ -1720,8 +1447,11 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: {dimensions = dense<1> : tensor<1xi64>} // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: %[[RESULT:.*]] = "xla_hlo.divide"(%[[EXP]], %[[CASTED_SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // return %[[RESULT]] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] + // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex> + // CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_SUM]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT:.*]] = xla_hlo.divide %[[EXP]], %[[BCAST_SUM]] + // CHECK: return %[[RESULT]] %0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> return %0: tensor<2x3xf32> @@ -1730,7 +1460,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // Verify intermediate and final shape are correct with dynamic shapes. // CHECK-LABEL: func @dynamic_softmax func @dynamic_softmax(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> { - // CHECK: "xla_hlo.divide"({{.*}}) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<?xf32>) -> tensor<?x?xf32> + // CHECK: xla_hlo.divide {{.*}} : tensor<?x?xf32> %0 = "tf.Softmax"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32> return %0: tensor<?x?xf32> } @@ -1756,43 +1486,29 @@ func @rank4_softmax(%arg0: tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> { // CHECK: "xla_hlo.reduce" // CHECK: dimensions = dense<3> - // CHECK: "xla_hlo.divide"{{.*}} {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} + // CHECK: {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} + // CHECK: xla_hlo.divide {{.*}} %0 = "tf.Softmax"(%arg0) : (tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> return %0: tensor<2x3x4x5xf16> } //===----------------------------------------------------------------------===// // LogSoftmax op legalizations. +// This just changes the tail of the regular Softmax legalization //===----------------------------------------------------------------------===// // CHECK-LABEL: func @simple_logsoftmax // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - - // Verify reduce op for max computation and its body. - // CHECK-DAG: %[[CASTED_INP:.*]] = "xla_hlo.convert"(%[[ARG0]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> - // CHECK-DAG: %[[NEG_INF:.*]] = xla_hlo.constant dense<0xFF800000> : tensor<f32> - // CHECK: %[[MAX:.*]] = "xla_hlo.reduce"(%[[CASTED_INP]], %[[NEG_INF]]) - // CHECK: xla_hlo.maximum - // CHECK: "xla_hlo.return" - // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32> - // CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> - - // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.subtract"(%[[ARG0]], %[[CASTED_MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // CHECK: %[[EXP:.*]] = "xla_hlo.exponential"(%[[SHIFTED_INP]]) - - // Verify reduce op for summation and its body. - // CHECK-DAG: %[[CASTED_EXP:.*]] = "xla_hlo.convert"(%[[EXP]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> - // CHECK-DAG: %[[ZERO:.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32> - // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[CASTED_EXP]], %[[ZERO]]) - // CHECK: xla_hlo.add - // CHECK: "xla_hlo.return" - // CHECK: {dimensions = dense<1> : tensor<1xi64>} + // CHECK: %{{.*}} = "xla_hlo.reduce"({{.*}}) + // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"({{.*}}) // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[LOG:.*]] = "xla_hlo.log"(%[[CASTED_SUM]]) : (tensor<2xf32>) -> tensor<2xf32> - - // CHECK: %[[RESULT:.*]] = "xla_hlo.subtract"(%[[SHIFTED_INP]], %[[LOG]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} - // return %[[RESULT]] + // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] + // CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex> + // CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[LOG]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} + // CHECK: %[[RESULT:.*]] = xla_hlo.subtract {{.*}}, %[[BCAST_SUM]] + // CHECK: return %[[RESULT]] %0 = "tf.LogSoftmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> return %0: tensor<2x3xf32> @@ -2093,16 +1809,41 @@ func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: @sigmoid func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-DAG: [[R0:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor<f32> - // CHECK-DAG: [[R1:%.+]] = "xla_hlo.broadcast"([[R0]]) {broadcast_sizes = dense<2> : tensor<1xi64>} : (tensor<f32>) -> tensor<2xf32> - // CHECK-DAG: [[R2:%.+]] = xla_hlo.multiply %arg0, [[R1]] : tensor<2xf32> - // CHECK-DAG: [[R3:%.+]] = "xla_hlo.tanh"([[R2]]) : (tensor<2xf32>) -> tensor<2xf32> - // CHECK-DAG: [[R4:%.+]] = xla_hlo.multiply [[R3]], [[R1]] : tensor<2xf32> - // CHECK-DAG: [[R5:%.+]] = xla_hlo.add [[R4]], [[R1]] : tensor<2xf32> + // CHECK-DAG: [[SCALAR:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor<f32> + // CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<2xf32> + // CHECK-DAG: [[SHAPE_VAL:%.+]] = "shape.to_extent_tensor"([[SHAPE]]) : (!shape.shape) -> tensor<1xindex> + // CHECK-DAG: [[HALF:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<2xf32> + // CHECK-DAG: [[R1:%.+]] = xla_hlo.multiply %arg0, [[HALF]] : tensor<2xf32> + // CHECK-DAG: [[R2:%.+]] = "xla_hlo.tanh"([[R1]]) : (tensor<2xf32>) -> tensor<2xf32> + // CHECK-DAG: [[R3:%.+]] = xla_hlo.multiply [[R2]], [[HALF]] : tensor<2xf32> + // CHECK-DAG: [[R4:%.+]] = xla_hlo.add [[R3]], [[HALF]] : tensor<2xf32> %0 = "tf.Sigmoid"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: @sigmoid_complex +func @sigmoid_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> { + // CHECK: [[R0:%.+]] = xla_hlo.constant dense<(5.000000e-01,0.000000e+00)> : tensor<complex<f32>> + // CHECK-NOT: tf.Sigmoid + %0 = "tf.Sigmoid"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> + return %0 : tensor<2xcomplex<f32>> +} + +// CHECK-LABEL: @sigmoid_unranked +func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK-DAG: [[SCALAR:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor<f32> + // CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<*xf32> + // CHECK-DAG: [[SHAPE_VAL:%.+]] = "shape.to_extent_tensor"([[SHAPE]]) : (!shape.shape) -> tensor<?xindex> + // CHECK-DAG: [[HALF:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<?xindex>) -> tensor<*xf32> + // CHECK-DAG: [[R1:%.+]] = xla_hlo.multiply %arg0, [[HALF]] : tensor<*xf32> + // CHECK-DAG: [[R2:%.+]] = "xla_hlo.tanh"([[R1]]) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK-DAG: [[R3:%.+]] = xla_hlo.multiply [[R2]], [[HALF]] : tensor<*xf32> + // CHECK-DAG: [[R4:%.+]] = xla_hlo.add [[R3]], [[HALF]] : tensor<*xf32> + %0 = "tf.Sigmoid"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + + // CHECK-LABEL: @sigmoid_grad func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { // CHECK-DAG: [[MUL0:%.+]] = xla_hlo.multiply %arg1, %arg0 : tensor<2xf32> @@ -2114,6 +1855,17 @@ func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> return %0 : tensor<2xf32> } +// CHECK-LABEL: @sigmoid_grad_complex +func @sigmoid_grad_complex(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> { + // CHECK-DAG: [[MUL0:%.+]] = xla_hlo.multiply %arg1, %arg0 : tensor<2xcomplex<f32>> + // CHECK-DAG: [[ONE:%.+]] = xla_hlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<2xcomplex<f32>> + // CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ONE]], %arg0 : tensor<2xcomplex<f32>> + // CHECK-DAG: [[MUL1:%.+]] = xla_hlo.multiply [[MUL0]], [[SUB]] : tensor<2xcomplex<f32>> + // CHECK: return [[MUL1]] + %0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> + return %0 : tensor<2xcomplex<f32>> +} + // CHECK-LABEL: @sin func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> @@ -2643,10 +2395,10 @@ func @strided_slice_nonconstant_begin_end(%arg0: tensor<i32>, %arg1: tensor<32x1 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> // CHECK-NEXT: %[[INDEX2:.*]] = "xla_hlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor<i32> - // CHECK-NEXT: %[[CMP:.*]] = "xla_hlo.compare"(%[[INDEX2]], %[[ZERO]]) + // CHECK-NEXT: %[[CMP:.*]] = xla_chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] // CHECK-DAG-SAME: {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1> // CHECK-NEXT: %[[DIM:.*]] = xla_hlo.constant dense<32> : tensor<i32> - // CHECK-NEXT: %[[WRAP:.*]] = xla_hlo.add %[[DIM]], %[[INDEX2]] : tensor<i32> + // CHECK-NEXT: %[[WRAP:.*]] = xla_chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor<i32>, tensor<i32>) -> tensor<i32> // CHECK-NEXT: %[[INDEX3:.*]] = "xla_hlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) : // CHECK-DAG-SAME: (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32> // CHECK-NEXT: %[[SLICED:.*]] = "xla_hlo.dynamic-slice" @@ -2775,7 +2527,7 @@ func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK: "xla_hlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32> // CHECK: %[[DIVISOR:.*]] = xla_hlo.constant dense<8.000000e+00> : tensor<f32> - // CHECK: %[[MEAN:.*]] = "xla_hlo.divide"(%[[REDUCED]], %[[DIVISOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> + // CHECK: %[[MEAN:.*]] = xla_chlo.broadcast_divide %[[REDUCED]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> // CHECK: %[[CAST_BACK:.*]] = "xla_hlo.convert"(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16> // CHECK: %[[RESULT:.*]] = "xla_hlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> // CHECK: return %[[RESULT]] : tensor<4x1xf16> @@ -3079,8 +2831,8 @@ func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { func @range(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<5xf32> { %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor<f32>} : () -> tensor<f32> // CHECK-DAG: [[IOTA:%.*]] = "xla_hlo.iota" - // CHECK-DAG: [[MUL:%.*]] = "xla_hlo.multiply"([[IOTA]], [[DELTA]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} - // CHECK: "xla_hlo.add"([[MUL]], [[START]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK: xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32> return %3 : tensor<5xf32> } @@ -3092,12 +2844,12 @@ func @linspace_static(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<4xf32> { // CHECK-DAG: [[NUM_CAST:%.*]] = tensor_cast [[NUM]] // CHECK-DAG: [[NUM_F32:%.*]] = "xla_hlo.convert"([[NUM_CAST]]) // CHECK-DAG: [[ONE:%.*]] = xla_hlo.constant dense<1.000000e+00> - // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_hlo.subtract [[NUM_F32]], [[ONE]] - // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_hlo.subtract [[STOP]], [[START]] - // CHECK-DAG: [[STEP:%.*]] = xla_hlo.divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] + // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_chlo.broadcast_subtract [[NUM_F32]], [[ONE]] + // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_chlo.broadcast_subtract [[STOP]], [[START]] + // CHECK-DAG: [[STEP:%.*]] = xla_chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] // CHECK-DAG: [[IOTA:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} - // CHECK-DAG: [[MUL:%.*]] = "xla_hlo.multiply"([[IOTA]], [[STEP]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} - // CHECK-DAG: [[LINSPACE:%.*]] = "xla_hlo.add"([[MUL]], [[START]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[LINSPACE:%.*]] = xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} // CHECK: return [[LINSPACE]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor<i32>} : () -> tensor<i32> %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<4xf32> @@ -3392,13 +3144,13 @@ func @size_ranked(%input: tensor<2x?x8xf32>) -> (tensor<i32>) { // CHECK: %[[CONST:.*]] = xla_hlo.constant dense<1> // CHECK: %[[DIM_0:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 0 - // CHECK: %[[MUL_0:.*]] = xla_hlo.multiply %[[CONST]], %[[DIM_0]] + // CHECK: %[[MUL_0:.*]] = xla_chlo.broadcast_multiply %[[CONST]], %[[DIM_0]] // CHECK: %[[DIM_1:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 1 - // CHECK: %[[MUL_1:.*]] = xla_hlo.multiply %[[MUL_0]], %[[DIM_1]] + // CHECK: %[[MUL_1:.*]] = xla_chlo.broadcast_multiply %[[MUL_0]], %[[DIM_1]] // CHECK: %[[DIM_2:.*]] = "xla_hlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 2 - // CHECK: %[[MUL_2:.*]] = xla_hlo.multiply %[[MUL_1]], %[[DIM_2]] + // CHECK: %[[MUL_2:.*]] = xla_chlo.broadcast_multiply %[[MUL_1]], %[[DIM_2]] %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<2x?x8xf32>) -> tensor<i32> // CHECK: return %[[MUL_2]] return %size : tensor<i32> @@ -3555,30 +3307,31 @@ func @assert(%arg0: tensor<i1>, %arg1: tensor<*xf32>) { // tf.Unpack legalization //===----------------------------------------------------------------------===// -// CHECK-LABEL: @unpack -func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { - // CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> - // CHECK: %[[RES1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x?xf32> - // CHECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> - // CHECK: %[[RES2:.*]] = "xla_hlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[SLICE3:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> - // CHECK: %[[RES3:.*]] = "xla_hlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> +// TODO(b/156340000): Re-enable when fixed. +// // C-HECK-LABEL: @unpack +// func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { +// // C-HECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> +// // C-HECK: %[[RES1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<4x1x6xf32>) -> tensor<4x?xf32> +// // C-HECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> +// // C-HECK: %[[RES2:.*]] = "xla_hlo.reshape"(%[[SLICE2]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> +// // C-HECK: %[[SLICE3:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> +// // C-HECK: %[[RES3:.*]] = "xla_hlo.reshape"(%[[SLICE3]]) : (tensor<4x1x6xf32>) -> tensor<4x6xf32> - %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) - // return %[[RES1]], %[[RES2]], %[[RES3]] - return %0#0, %0#1, %0#2 : tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32> -} +// %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32>) +// // return %[[RES1]], %[[RES2]], %[[RES3]] +// return %0#0, %0#1, %0#2 : tensor<4x?xf32>, tensor<4x6xf32>, tensor<4x6xf32> +// } -// CHECK-LABEL: @unpack_dynamic -func @unpack_dynamic(%input: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) { - // CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32> - // CHECK: "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32> - // CHECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 2]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32> - // CHECK: "xla_hlo.reshape"(%[[SLICE2]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32> +// // C-HECK-LABEL: @unpack_dynamic +// func @unpack_dynamic(%input: tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) { +// // C-HECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32> +// // C-HECK: "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32> +// // C-HECK: %[[SLICE2:.*]] = "xla_hlo.slice"(%{{.*}}) {limit_indices = dense<[-1, -1, 2]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<?x?x2xf32>) -> tensor<?x?x1xf32> +// // C-HECK: "xla_hlo.reshape"(%[[SLICE2]]) : (tensor<?x?x1xf32>) -> tensor<?x?xf32> - %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) - return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32> -} +// %0:2 = "tf.Unpack"(%input) {axis = -1} : (tensor<?x?x2xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) +// return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32> +// } //===----------------------------------------------------------------------===// // tf.UnsortedSegment{Max|Min|Prod|Sum} legalization @@ -3914,7 +3667,7 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK: [[INDICES1:%.*]] = "xla_hlo.dynamic-update-slice"([[INDICES]], [[TGT_IDX]], [[IV]]) : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32> // CHECK: [[INDICES2:%.*]] = "xla_hlo.dynamic-update-slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32> // CHECK: [[ONE:%.*]] = xla_hlo.constant dense<1> : tensor<i32> - // CHECK: [[NEW_IV:%.*]] = xla_hlo.add [[IV]], [[ONE]] + // CHECK: [[NEW_IV:%.*]] = xla_chlo.broadcast_add [[IV]], [[ONE]] // CHECK: [[NEW_TUPLE:%.*]] = "xla_hlo.tuple"([[NEW_IV]], [[SWAPS]], [[INDICES2]]) // CHECK: "xla_hlo.return"([[NEW_TUPLE]]) // CHECK: }) : (tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>) -> tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>> @@ -3983,7 +3736,7 @@ func @avgpool_valid_padding(%arg0: tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> // CHECK: "xla_hlo.return"([[ADD]]) // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} : (tensor<2x12x20x7xf32>, tensor<f32>) -> tensor<2x3x5x7xf32> // CHECK: [[COUNT:%.+]] = xla_hlo.constant dense<4.000000e+00> : tensor<f32> - // CHECK: [[DIV:%.+]] = "xla_hlo.divide"([[REDUCE]], [[COUNT]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<2x3x5x7xf32>, tensor<f32>) -> tensor<2x3x5x7xf32> + // CHECK: [[DIV:%.+]] = xla_chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor<f32>) -> tensor<2x3x5x7xf32> // CHECK: [[CONV16:%.+]] = "xla_hlo.convert"([[DIV]]) : (tensor<2x3x5x7xf32>) -> tensor<2x3x5x7xf16> // CHECK: return [[CONV16]] %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> @@ -4123,177 +3876,11 @@ func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32> // CHECK: func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) { -// CHECK: [[VAL_1:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<100x100xi32> -// CHECK: [[VAL_2:%.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<100x100xi32> -// CHECK: [[VAL_3:%.*]] = "xla_hlo.compare"([[VAL_1]], [[VAL_2]]) {comparison_direction = "EQ"} : (tensor<100x100xi32>, tensor<100x100xi32>) -> tensor<100x100xi1> -// CHECK: [[VAL_4:%.*]] = "xla_hlo.convert"([[VAL_3]]) : (tensor<100x100xi1>) -> tensor<100x100xf32> -// CHECK: [[VAL_5:%.*]] = "xla_hlo.broadcast"([[VAL_4]]) {broadcast_sizes = dense<500> : tensor<1xi64>} : (tensor<100x100xf32>) -> tensor<500x100x100xf32> -// CHECK: [[VAL_6:%.*]] = "xla_hlo.slice"([[VAL_0]]) {limit_indices = dense<[500, 100, 75]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_7:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32> -// CHECK: [[VAL_8:%.*]] = "xla_hlo.broadcast"([[VAL_7]]) {broadcast_sizes = dense<[500, 100, 75]> : tensor<3xi64>} : (tensor<f32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_9:%.*]] = "xla_hlo.broadcast"([[VAL_7]]) {broadcast_sizes = dense<[500, 75]> : tensor<2xi64>} : (tensor<f32>) -> tensor<500x75xf32> -// CHECK: [[VAL_10:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_11:%.*]] = "xla_hlo.tuple"([[VAL_10]], [[VAL_6]], [[VAL_8]], [[VAL_9]]) : (tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>) -> tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: [[VAL_12:%.*]] = "xla_hlo.while"([[VAL_11]]) ( { -// CHECK: ^bb0([[VAL_13:%.*]]: tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>): -// CHECK: [[VAL_14:%.*]] = "xla_hlo.get_tuple_element"([[VAL_13]]) {index = 0 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<i32> -// CHECK: [[VAL_15:%.*]] = xla_hlo.constant dense<75> : tensor<i32> -// CHECK: [[VAL_16:%.*]] = "xla_hlo.compare"([[VAL_14]], [[VAL_15]]) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1> -// CHECK: "xla_hlo.return"([[VAL_16]]) : (tensor<i1>) -> () -// CHECK: }, { -// CHECK: ^bb0([[VAL_17:%.*]]: tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>): -// CHECK: [[VAL_18:%.*]] = "xla_hlo.get_tuple_element"([[VAL_17]]) {index = 0 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<i32> -// CHECK: [[VAL_19:%.*]] = "xla_hlo.get_tuple_element"([[VAL_17]]) {index = 1 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_20:%.*]] = "xla_hlo.get_tuple_element"([[VAL_17]]) {index = 2 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_21:%.*]] = "xla_hlo.get_tuple_element"([[VAL_17]]) {index = 3 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x75xf32> -// CHECK: [[VAL_22:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_23:%.*]] = "xla_hlo.dynamic-slice"([[VAL_19]], [[VAL_22]], [[VAL_22]], [[VAL_18]]) {slice_sizes = dense<[500, 100, 1]> : tensor<3xi64>} : (tensor<500x100x75xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_24:%.*]] = "xla_hlo.reshape"([[VAL_23]]) : (tensor<500x100x1xf32>) -> tensor<500x100xf32> -// CHECK: [[VAL_25:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32> -// CHECK: [[VAL_26:%.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32> -// CHECK: [[VAL_27:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_28:%.*]] = "xla_hlo.dynamic-slice"([[VAL_24]], [[VAL_27]], [[VAL_18]]) {slice_sizes = dense<[500, 1]> : tensor<2xi64>} : (tensor<500x100xf32>, tensor<i32>, tensor<i32>) -> tensor<500x1xf32> -// CHECK: [[VAL_29:%.*]] = "xla_hlo.reshape"([[VAL_28]]) : (tensor<500x1xf32>) -> tensor<500xf32> -// CHECK: [[VAL_30:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<100xi32> -// CHECK: [[VAL_31:%.*]] = "xla_hlo.compare"([[VAL_30]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<100xi32>, tensor<i32>) -> tensor<100xi1> -// CHECK: [[VAL_32:%.*]] = "xla_hlo.convert"([[VAL_31]]) : (tensor<100xi1>) -> tensor<100xf32> -// CHECK: [[VAL_33:%.*]] = "xla_hlo.multiply"([[VAL_24]], [[VAL_32]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<500x100xf32>, tensor<100xf32>) -> tensor<500x100xf32> -// CHECK: [[VAL_34:%.*]] = xla_hlo.multiply [[VAL_33]], [[VAL_33]] : tensor<500x100xf32> -// CHECK: [[VAL_35:%.*]] = "xla_hlo.reduce"([[VAL_34]], [[VAL_25]]) ( { -// CHECK: ^bb0([[VAL_36:%.*]]: tensor<f32>, [[VAL_37:%.*]]: tensor<f32>): -// CHECK: [[VAL_38:%.*]] = xla_hlo.add [[VAL_36]], [[VAL_37]] : tensor<f32> -// CHECK: "xla_hlo.return"([[VAL_38]]) : (tensor<f32>) -> () -// CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<500x100xf32>, tensor<f32>) -> tensor<500xf32> -// CHECK: [[VAL_39:%.*]] = xla_hlo.multiply [[VAL_29]], [[VAL_29]] : tensor<500xf32> -// CHECK: [[VAL_40:%.*]] = xla_hlo.add [[VAL_39]], [[VAL_41:%.*]] : tensor<500xf32> -// CHECK: [[VAL_42:%.*]] = "xla_hlo.sqrt"([[VAL_40]]) : (tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_43:%.*]] = "xla_hlo.compare"([[VAL_41]], [[VAL_25]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "EQ"} : (tensor<500xf32>, tensor<f32>) -> tensor<500xi1> -// CHECK: [[VAL_44:%.*]] = "xla_hlo.compare"([[VAL_29]], [[VAL_25]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} : (tensor<500xf32>, tensor<f32>) -> tensor<500xi1> -// CHECK: [[VAL_45:%.*]] = "xla_hlo.broadcast"([[VAL_26]]) {broadcast_sizes = dense<500> : tensor<1xi64>} : (tensor<f32>) -> tensor<500xf32> -// CHECK: [[VAL_46:%.*]] = "xla_hlo.negate"([[VAL_45]]) : (tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_47:%.*]] = "xla_hlo.select"([[VAL_44]], [[VAL_45]], [[VAL_46]]) : (tensor<500xi1>, tensor<500xf32>, tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_48:%.*]] = xla_hlo.multiply [[VAL_47]], [[VAL_42]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<500xf32> -// CHECK: [[VAL_49:%.*]] = "xla_hlo.select"([[VAL_43]], [[VAL_29]], [[VAL_48]]) : (tensor<500xi1>, tensor<500xf32>, tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_50:%.*]] = xla_hlo.subtract [[VAL_49]], [[VAL_29]] : tensor<500xf32> -// CHECK: [[VAL_51:%.*]] = xla_hlo.divide [[VAL_50]], [[VAL_49]] : tensor<500xf32> -// CHECK: [[VAL_52:%.*]] = "xla_hlo.broadcast"([[VAL_25]]) {broadcast_sizes = dense<500> : tensor<1xi64>} : (tensor<f32>) -> tensor<500xf32> -// CHECK: [[VAL_53:%.*]] = "xla_hlo.select"([[VAL_43]], [[VAL_52]], [[VAL_51]]) : (tensor<500xi1>, tensor<500xf32>, tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_54:%.*]] = xla_hlo.subtract [[VAL_29]], [[VAL_49]] : tensor<500xf32> -// CHECK: [[VAL_55:%.*]] = "xla_hlo.select"([[VAL_43]], [[VAL_45]], [[VAL_54]]) : (tensor<500xi1>, tensor<500xf32>, tensor<500xf32>) -> tensor<500xf32> -// CHECK: [[VAL_56:%.*]] = "xla_hlo.compare"([[VAL_30]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "EQ"} : (tensor<100xi32>, tensor<i32>) -> tensor<100xi1> -// CHECK: [[VAL_57:%.*]] = "xla_hlo.convert"([[VAL_56]]) : (tensor<100xi1>) -> tensor<100xf32> -// CHECK: [[VAL_58:%.*]] = "xla_hlo.broadcast"([[VAL_57]]) {broadcast_sizes = dense<1> : tensor<1xi64>} : (tensor<100xf32>) -> tensor<1x100xf32> -// CHECK: [[VAL_59:%.*]] = "xla_hlo.divide"([[VAL_33]], [[VAL_55]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<500x100xf32>, tensor<500xf32>) -> tensor<500x100xf32> -// CHECK: [[VAL_60:%.*]] = "xla_hlo.add"([[VAL_58]], [[VAL_59]]) : (tensor<1x100xf32>, tensor<500x100xf32>) -> tensor<500x100xf32> -// CHECK: [[VAL_61:%.*]] = "xla_hlo.reshape"([[VAL_60]]) : (tensor<500x100xf32>) -> tensor<500x1x100xf32> -// CHECK: [[VAL_62:%.*]] = "xla_hlo.dot_general"([[VAL_61]], [[VAL_19]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x1x100xf32>, tensor<500x100x75xf32>) -> tensor<500x1x75xf32> -// CHECK: [[VAL_63:%.*]] = "xla_hlo.dot_general"([[VAL_61]], [[VAL_62]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x1x100xf32>, tensor<500x1x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_64:%.*]] = "xla_hlo.multiply"([[VAL_53]], [[VAL_63]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<500xf32>, tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_65:%.*]] = xla_hlo.subtract [[VAL_19]], [[VAL_64]] : tensor<500x100x75xf32> -// CHECK: [[VAL_66:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<100x1xi32> -// CHECK: [[VAL_67:%.*]] = "xla_hlo.compare"([[VAL_66]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} : (tensor<100x1xi32>, tensor<i32>) -> tensor<100x1xi1> -// CHECK: [[VAL_68:%.*]] = "xla_hlo.convert"([[VAL_67]]) : (tensor<100x1xi1>) -> tensor<100x1xf32> -// CHECK: [[VAL_69:%.*]] = "xla_hlo.compare"([[VAL_66]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "EQ"} : (tensor<100x1xi32>, tensor<i32>) -> tensor<100x1xi1> -// CHECK: [[VAL_70:%.*]] = "xla_hlo.convert"([[VAL_69]]) : (tensor<100x1xi1>) -> tensor<100x1xf32> -// CHECK: [[VAL_71:%.*]] = "xla_hlo.broadcast"([[VAL_70]]) {broadcast_sizes = dense<1> : tensor<1xi64>} : (tensor<100x1xf32>) -> tensor<1x100x1xf32> -// CHECK: [[VAL_72:%.*]] = "xla_hlo.multiply"([[VAL_23]], [[VAL_68]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<500x100x1xf32>, tensor<100x1xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_73:%.*]] = "xla_hlo.multiply"([[VAL_49]], [[VAL_71]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<500xf32>, tensor<1x100x1xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_74:%.*]] = xla_hlo.add [[VAL_72]], [[VAL_73]] : tensor<500x100x1xf32> -// CHECK: [[VAL_75:%.*]] = "xla_hlo.broadcast_in_dim"([[VAL_74]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<500x100x1xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_76:%.*]] = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<500x100x75xi32> -// CHECK: [[VAL_77:%.*]] = "xla_hlo.compare"([[VAL_76]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "EQ"} : (tensor<500x100x75xi32>, tensor<i32>) -> tensor<500x100x75xi1> -// CHECK: [[VAL_78:%.*]] = "xla_hlo.select"([[VAL_77]], [[VAL_75]], [[VAL_65]]) : (tensor<500x100x75xi1>, tensor<500x100x75xf32>, tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_79:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32> -// CHECK: [[VAL_80:%.*]] = "xla_hlo.broadcast"([[VAL_79]]) {broadcast_sizes = dense<[500, 100, 75]> : tensor<3xi64>} : (tensor<f32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_81:%.*]] = "xla_hlo.add"([[VAL_80]], [[VAL_60]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<500x100x75xf32>, tensor<500x100xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_82:%.*]] = "xla_hlo.select"([[VAL_77]], [[VAL_81]], [[VAL_80]]) : (tensor<500x100x75xi1>, tensor<500x100x75xf32>, tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_83:%.*]] = xla_hlo.add [[VAL_20]], [[VAL_82]] : tensor<500x100x75xf32> -// CHECK: [[VAL_84:%.*]] = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<500x75xi32> -// CHECK: [[VAL_85:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32> -// CHECK: [[VAL_86:%.*]] = "xla_hlo.broadcast"([[VAL_85]]) {broadcast_sizes = dense<[500, 75]> : tensor<2xi64>} : (tensor<f32>) -> tensor<500x75xf32> -// CHECK: [[VAL_87:%.*]] = "xla_hlo.compare"([[VAL_84]], [[VAL_18]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "EQ"} : (tensor<500x75xi32>, tensor<i32>) -> tensor<500x75xi1> -// CHECK: [[VAL_88:%.*]] = "xla_hlo.add"([[VAL_86]], [[VAL_53]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<500x75xf32>, tensor<500xf32>) -> tensor<500x75xf32> -// CHECK: [[VAL_89:%.*]] = "xla_hlo.select"([[VAL_87]], [[VAL_88]], [[VAL_86]]) : (tensor<500x75xi1>, tensor<500x75xf32>, tensor<500x75xf32>) -> tensor<500x75xf32> -// CHECK: [[VAL_90:%.*]] = xla_hlo.add [[VAL_21]], [[VAL_89]] : tensor<500x75xf32> -// CHECK: [[VAL_91:%.*]] = xla_hlo.constant dense<1> : tensor<i32> -// CHECK: [[VAL_92:%.*]] = xla_hlo.add [[VAL_18]], [[VAL_91]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<i32> -// CHECK: [[VAL_93:%.*]] = "xla_hlo.tuple"([[VAL_92]], [[VAL_78]], [[VAL_83]], [[VAL_90]]) : (tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>) -> tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: "xla_hlo.return"([[VAL_93]]) : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> () -// CHECK: }) : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: [[VAL_94:%.*]] = "xla_hlo.get_tuple_element"([[VAL_95:%.*]]) {index = 1 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_96:%.*]] = "xla_hlo.get_tuple_element"([[VAL_95]]) {index = 2 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_97:%.*]] = "xla_hlo.get_tuple_element"([[VAL_95]]) {index = 3 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x75xf32> -// CHECK: [[VAL_98:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_99:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_100:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_101:%.*]] = "xla_hlo.dynamic-update-slice"([[VAL_0]], [[VAL_94]], [[VAL_100]], [[VAL_98]], [[VAL_99]]) : (tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_102:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32> -// CHECK: [[VAL_103:%.*]] = "xla_hlo.broadcast"([[VAL_102]]) {broadcast_sizes = dense<[500, 100, 75]> : tensor<3xi64>} : (tensor<f32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_104:%.*]] = "xla_hlo.slice"([[VAL_96]]) {limit_indices = dense<[500, 100, 1]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x75xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_105:%.*]] = "xla_hlo.slice"([[VAL_97]]) {limit_indices = dense<[500, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<500x75xf32>) -> tensor<500x1xf32> -// CHECK: [[VAL_106:%.*]] = "xla_hlo.negate"([[VAL_105]]) : (tensor<500x1xf32>) -> tensor<500x1xf32> -// CHECK: [[VAL_107:%.*]] = "xla_hlo.multiply"([[VAL_106]], [[VAL_104]]) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<500x1xf32>, tensor<500x100x1xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_108:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_109:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_110:%.*]] = "xla_hlo.dynamic-update-slice"([[VAL_103]], [[VAL_107]], [[VAL_109]], [[VAL_109]], [[VAL_108]]) : (tensor<500x100x75xf32>, tensor<500x100x1xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_111:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_112:%.*]] = "xla_hlo.tuple"([[VAL_111]], [[VAL_110]], [[VAL_96]], [[VAL_97]]) : (tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>) -> tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: [[VAL_113:%.*]] = "xla_hlo.while"([[VAL_112]]) ( { -// CHECK: ^bb0([[VAL_114:%.*]]: tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>): -// CHECK: [[VAL_115:%.*]] = "xla_hlo.get_tuple_element"([[VAL_114]]) {index = 0 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<i32> -// CHECK: [[VAL_116:%.*]] = xla_hlo.constant dense<74> : tensor<i32> -// CHECK: [[VAL_117:%.*]] = "xla_hlo.compare"([[VAL_115]], [[VAL_116]]) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1> -// CHECK: "xla_hlo.return"([[VAL_117]]) : (tensor<i1>) -> () -// CHECK: }, { -// CHECK: ^bb0([[VAL_118:%.*]]: tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>): -// CHECK: [[VAL_119:%.*]] = "xla_hlo.get_tuple_element"([[VAL_118]]) {index = 0 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<i32> -// CHECK: [[VAL_120:%.*]] = "xla_hlo.get_tuple_element"([[VAL_118]]) {index = 1 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_121:%.*]] = "xla_hlo.get_tuple_element"([[VAL_118]]) {index = 2 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_122:%.*]] = "xla_hlo.get_tuple_element"([[VAL_118]]) {index = 3 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x75xf32> -// CHECK: [[VAL_123:%.*]] = xla_hlo.constant dense<1> : tensor<i32> -// CHECK: [[VAL_124:%.*]] = xla_hlo.add [[VAL_119]], [[VAL_123]] : tensor<i32> -// CHECK: [[VAL_125:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_126:%.*]] = "xla_hlo.dynamic-slice"([[VAL_121]], [[VAL_125]], [[VAL_125]], [[VAL_124]]) {slice_sizes = dense<[500, 100, 1]> : tensor<3xi64>} : (tensor<500x100x75xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_127:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_128:%.*]] = "xla_hlo.dynamic-slice"([[VAL_122]], [[VAL_127]], [[VAL_124]]) {slice_sizes = dense<[500, 1]> : tensor<2xi64>} : (tensor<500x75xf32>, tensor<i32>, tensor<i32>) -> tensor<500x1xf32> -// CHECK: [[VAL_129:%.*]] = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<500x100x75xi32> -// CHECK: [[VAL_130:%.*]] = xla_hlo.constant dense<0.000000e+00> : tensor<f32> -// CHECK: [[VAL_131:%.*]] = "xla_hlo.broadcast"([[VAL_130]]) {broadcast_sizes = dense<[500, 100, 75]> : tensor<3xi64>} : (tensor<f32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_132:%.*]] = "xla_hlo.compare"([[VAL_129]], [[VAL_124]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GE"} : (tensor<500x100x75xi32>, tensor<i32>) -> tensor<500x100x75xi1> -// CHECK: [[VAL_133:%.*]] = "xla_hlo.select"([[VAL_132]], [[VAL_131]], [[VAL_121]]) : (tensor<500x100x75xi1>, tensor<500x100x75xf32>, tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_134:%.*]] = "xla_hlo.dot_general"([[VAL_133]], [[VAL_126]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x75xf32>, tensor<500x100x1xf32>) -> tensor<500x75x1xf32> -// CHECK: [[VAL_135:%.*]] = "xla_hlo.dot_general"([[VAL_120]], [[VAL_134]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x75xf32>, tensor<500x75x1xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_136:%.*]] = "xla_hlo.negate"([[VAL_128]]) : (tensor<500x1xf32>) -> tensor<500x1xf32> -// CHECK: [[VAL_137:%.*]] = xla_hlo.add [[VAL_126]], [[VAL_135]] : tensor<500x100x1xf32> -// CHECK: [[VAL_138:%.*]] = "xla_hlo.multiply"([[VAL_136]], [[VAL_137]]) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<500x1xf32>, tensor<500x100x1xf32>) -> tensor<500x100x1xf32> -// CHECK: [[VAL_139:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_140:%.*]] = "xla_hlo.dynamic-update-slice"([[VAL_120]], [[VAL_138]], [[VAL_139]], [[VAL_139]], [[VAL_124]]) : (tensor<500x100x75xf32>, tensor<500x100x1xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_141:%.*]] = xla_hlo.constant dense<1> : tensor<i32> -// CHECK: [[VAL_142:%.*]] = xla_hlo.add [[VAL_119]], [[VAL_141]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : tensor<i32> -// CHECK: [[VAL_143:%.*]] = "xla_hlo.tuple"([[VAL_142]], [[VAL_140]], [[VAL_121]], [[VAL_122]]) : (tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>) -> tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: "xla_hlo.return"([[VAL_143]]) : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> () -// CHECK: }) : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>> -// CHECK: [[VAL_144:%.*]] = "xla_hlo.get_tuple_element"([[VAL_145:%.*]]) {index = 1 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_146:%.*]] = "xla_hlo.get_tuple_element"([[VAL_145]]) {index = 2 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_147:%.*]] = "xla_hlo.get_tuple_element"([[VAL_145]]) {index = 3 : i32} : (tuple<tensor<i32>, tensor<500x100x75xf32>, tensor<500x100x75xf32>, tensor<500x75xf32>>) -> tensor<500x75xf32> -// CHECK: [[VAL_148:%.*]] = "xla_hlo.slice"([[VAL_101]]) {limit_indices = dense<[500, 100, 75]> : tensor<3xi64>, start_indices = dense<[0, 0, 75]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x75xf32>) -> tensor<500x100x0xf32> -// CHECK: [[VAL_149:%.*]] = "xla_hlo.dot_general"([[VAL_144]], [[VAL_148]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x75xf32>, tensor<500x100x0xf32>) -> tensor<500x75x0xf32> -// CHECK: [[VAL_150:%.*]] = "xla_hlo.dot_general"([[VAL_96]], [[VAL_149]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x75xf32>, tensor<500x75x0xf32>) -> tensor<500x100x0xf32> -// CHECK: [[VAL_151:%.*]] = xla_hlo.add [[VAL_148]], [[VAL_150]] : tensor<500x100x0xf32> -// CHECK: [[VAL_152:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_153:%.*]] = xla_hlo.constant dense<75> : tensor<i32> -// CHECK: [[VAL_154:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_155:%.*]] = "xla_hlo.dynamic-update-slice"([[VAL_101]], [[VAL_151]], [[VAL_154]], [[VAL_152]], [[VAL_153]]) : (tensor<500x100x75xf32>, tensor<500x100x0xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_156:%.*]] = "xla_hlo.slice"([[VAL_5]]) {limit_indices = dense<[500, 100, 100]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x100xf32>) -> tensor<500x100x100xf32> -// CHECK: [[VAL_157:%.*]] = "xla_hlo.dot_general"([[VAL_156]], [[VAL_144]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x100xf32>, tensor<500x100x75xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_158:%.*]] = "xla_hlo.dot_general"([[VAL_157]], [[VAL_96]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["HIGHEST", "HIGHEST"]} : (tensor<500x100x75xf32>, tensor<500x100x75xf32>) -> tensor<500x100x100xf32> -// CHECK: [[VAL_159:%.*]] = xla_hlo.add [[VAL_156]], [[VAL_158]] : tensor<500x100x100xf32> -// CHECK: [[VAL_160:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_161:%.*]] = xla_hlo.constant dense<0> : tensor<i32> -// CHECK: [[VAL_162:%.*]] = "xla_hlo.dynamic-update-slice"([[VAL_5]], [[VAL_159]], [[VAL_161]], [[VAL_161]], [[VAL_160]]) : (tensor<500x100x100xf32>, tensor<500x100x100xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<500x100x100xf32> -// CHECK: [[VAL_163:%.*]] = "xla_hlo.slice"([[VAL_162]]) {limit_indices = dense<[500, 100, 75]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x100xf32>) -> tensor<500x100x75xf32> -// CHECK: [[VAL_164:%.*]] = "xla_hlo.slice"([[VAL_155]]) {limit_indices = dense<[500, 75, 75]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<500x100x75xf32>) -> tensor<500x75x75xf32> -// CHECK: return [[VAL_163]], [[VAL_164]] : tensor<500x100x75xf32>, tensor<500x75x75xf32> + // The tf.Qr lowering is a full algorithm that is not effective to verify with + // FileCheck. Just verify that it converted. + // TODO(laurenzo): Move this out of the mainline tf2xla conversion as it is + // really only applicable to certain legacy uses. + // CHECK-NOT: "tf.Qr" %0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir index d25a84d0e25..9f27a204baf 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s +// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s --dump-input-on-failure // CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { @@ -42,40 +42,6 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32 return %4 : tensor<4xi32> } -// Broadcasting is not currently supported. -// TODO(suderman):Future pass should take all broadcasted binary ops and convert -// them to separate broadcast and binary op. -// CHECK-LABEL: func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> { -func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "add.3"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) { - name = "add.3", broadcast_dimensions = dense<1> : tensor<1xi64>} : - (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - - // CHECK-NEXT: %1 = "xla_hlo.multiply"(%0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "mul.4"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %1 = "xla_hlo.multiply"(%0, %arg1) { - name = "mul.4", broadcast_dimensions = dense<1> : tensor<1xi64>} : - (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - - // CHECK-NEXT: %2 = "xla_hlo.subtract"(%1, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "sub.5"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %2 = "xla_hlo.subtract"(%1, %arg1) { - name = "sub.5", broadcast_dimensions = dense<1> : tensor<1xi64>} : - (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - - // CHECK-NEXT: %3 = "xla_hlo.divide"(%2, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "div.6"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %3 = "xla_hlo.divide"(%2, %arg1) { - name = "div.6", broadcast_dimensions = dense<1> : tensor<1xi64>} : - (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - - // CHECK-NEXT: %4 = "xla_hlo.remainder"(%3, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %4 = "xla_hlo.remainder"(%3, %arg1) { - broadcast_dimensions = dense<1> : tensor<1xi64>} : - (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - - // CHECK-NEXT: return %4 : tensor<4x4xf32> - return %4 : tensor<4x4xf32> -} - // CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32> diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir index bb8010b520c..626e905695c 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir @@ -636,3 +636,16 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) { return } // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @reverse +func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) { + "xla_lhlo.reverse"(%arg0, %arg1) { + dimensions = dense<1> : tensor<1xi64> + } : (memref<2x3xf32>, memref<2x3xf32>) -> () + return +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir index 23e9d9b68e0..d4d775731c8 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir @@ -178,3 +178,24 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m } ) : () -> () return } + +// ----- + +// CHECK-LABEL: func @case_memref +func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memref<f32>, %operand_3: memref<f32>, %out: memref<f32>) -> () { + "xla_lhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( { + ^bb0(%arg0: memref<f32>): + "xla_lhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> () + "xla_lhlo.terminator"() : () -> () + }, { + ^bb0(%arg0: memref<f32>): + "xla_lhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> () + "xla_lhlo.terminator"() : () -> () + }, { + ^bb0(%arg0: memref<f32>): + "xla_lhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> () + "xla_lhlo.terminator"() : () -> () + } + ) : (memref<i32>, memref<f32>, memref<f32>, memref<f32>, memref<f32>) -> () + return +} diff --git a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir index 35a5ae549d5..81376761467 100644 --- a/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lower-complex.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt %s -test-xla-lower-complex | FileCheck %s +// RUN: xla-opt %s -test-xla-chlo-legalize-to-hlo -test-xla-lower-complex | FileCheck %s --dump-input-on-failure // CHECK-LABEL: @add func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { @@ -15,21 +15,6 @@ func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % return %5, %6 : tensor<2xf32>, tensor<2xf32> } -// CHECK-LABEL: @add_broadcast -func @add_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.add"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %4 = "xla_hlo.add"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>) - %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) - - // CHECK: return [[VAL0]], [[VAL1]] - return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> -} - // CHECK-LABEL: @add_unranked func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>) @@ -60,21 +45,6 @@ func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % return %5, %6 : tensor<2xf32>, tensor<2xf32> } -// CHECK-LABEL: @sub_broadcast -func @sub_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.subtract"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.subtract"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %4 = "xla_hlo.subtract"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>) - %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) - - // CHECK: return [[VAL0]], [[VAL1]] - return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> -} - // CHECK-LABEL: @sub_unranked func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>) @@ -109,25 +79,6 @@ func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % return %5, %6 : tensor<2xf32>, tensor<2xf32> } -// CHECK-LABEL: @mul_broadcast -func @mul_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.multiply"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.multiply"(%arg1, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]] - // CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.multiply"(%arg0, %arg3) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL4:%.+]] = "xla_hlo.multiply"(%arg1, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]] - %4 = "xla_hlo.multiply"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>) - %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) - - // CHECK: return %2, %5 : tensor<1x2xf32>, tensor<1x2xf32> - return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> -} - // CHECK-LABEL: @mul_unranked func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>) @@ -186,45 +137,6 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, % // ----- -// CHECK-LABEL: @div_broadcast -func @div_broadcast(%arg0 : tensor<1x2xf32>, %arg1 : tensor<1x2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { - %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>) - %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>) - - // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3) - - // Compute the numerator's real component: - // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag - // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.multiply"(%arg0, %arg2) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.multiply"(%arg1, [[VAL0]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]] - - // Compute the real valued denominator as rhs * con(rhs): - // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag - // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2 - // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]] - // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]] - - // Compute the numerator's imaginary component: - // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag - // CHECK-DAG: [[VAL7:%.+]] = "xla_hlo.multiply"(%arg1, %arg2) - // CHECK-DAG: [[VAL8:%.+]] = "xla_hlo.multiply"(%arg0, [[VAL0]]) - // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]] - - // Divide the numerator by the real valued denominator. - // CHECK-DAG: [[VAL10:%.+]] = "xla_hlo.divide"([[VAL3]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - // CHECK-DAG: [[VAL11:%.+]] = "xla_hlo.divide"([[VAL9]], [[VAL6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} - %4 = "xla_hlo.divide"(%2, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<1x2xcomplex<f32>>) - - %5 = "xla_hlo.real"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) - %6 = "xla_hlo.imag"(%4) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>) - - // CHECK: return [[VAL10]], [[VAL11]] - return %5, %6 : tensor<1x2xf32>, tensor<1x2xf32> -} - -// ----- - // CHECK-LABEL: @div_unranked func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>) diff --git a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir index 2340650dda8..55b55c7b4e2 100644 --- a/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir +++ b/tensorflow/compiler/mlir/xla/tests/materialize-broadcasts.mlir @@ -1,225 +1,5 @@ // RUN: xla-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck --dump-input=fail %s -// CHECK-LABEL: @addBroadcastRhs -func @addBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @addBroadcastLhs -func @addBroadcastLhs(%arg0: tensor<4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %arg1 : tensor<1x4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @addBroadcastEqual -func @addBroadcastEqual(%arg0: tensor<4x1xf32>, %arg1: tensor<1x4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x4xf32>) -> tensor<4x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<4x4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x1xf32>, tensor<1x4xf32>) -> tensor<4x4xf32> - return %0 : tensor<4x4xf32> -} - -// ----- - -// CHECK-LABEL: @addBroadcastMultidimension -func @addBroadcastMultidimension(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x4xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>) -> tensor<1x1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %arg1 : tensor<1x1x4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x1xf32>, tensor<1x1x4xf32>) -> tensor<1x1x4xf32> - return %0 : tensor<1x1x4xf32> -} - -// ----- - -// CHECK-LABEL: @addBroadcastBothArgs -func @addBroadcastBothArgs(%arg0: tensor<1x2xf32>, %arg1: tensor<3x2x1xf32>) -> tensor<3x2x2xf32> { - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<3x2x2xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x1xf32>) -> tensor<3x2x2xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<3x2x2xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>, tensor<3x2x1xf32>) -> tensor<3x2x2xf32> - return %0 : tensor<3x2x2xf32> -} - -// ----- - -// CHECK-LABEL: @addBroadcastScalar -func @addBroadcastScalar(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %[[BROADCAST1]] : tensor<4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: @addWithoutBroadcast -func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %arg1 : tensor<4xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - return %0 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: @addUnranked -func @addUnranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.add %arg0, %arg1 : tensor<*xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - return %0 : tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: @atan2BroadcastRhs -func @atan2BroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.atan2 %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.atan2"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @divBroadcastRhs -func @divBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.divide %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @maxBroadcastRhs -func @maxBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.maximum %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.maximum"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @minBroadcastRhs -func @minBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.minimum %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.minimum"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @mulBroadcastRhs -func @mulBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.multiply %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @powBroadcastRhs -func @powBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.power %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.power"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @remainderBroadcastRhs -func @remainderBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.remainder %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.remainder"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @shiftLeftBroadcastRhs -func @shiftLeftBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_left %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.shift_left"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @shiftRightArithmeticBroadcastRhs -func @shiftRightArithmeticBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_arithmetic %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @shiftRightLogicalBroadcastRhs -func @shiftRightLogicalBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.shift_right_logical %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.shift_right_logical"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @subBroadcastRhs -func @subBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.subtract %arg0, %[[BROADCAST1]] : tensor<1x4xf32> - %0 = "xla_hlo.subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> - return %0 : tensor<1x4xf32> -} - -// ----- - -// CHECK-LABEL: @andBroadcastRhs -func @andBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.and %arg0, %[[BROADCAST1]] : tensor<1x4xi32> - %0 = "xla_hlo.and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32> - return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: @orBroadcastRhs -func @orBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.or %arg0, %[[BROADCAST1]] : tensor<1x4xi32> - %0 = "xla_hlo.or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32> - return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: @xorBroadcastRhs -func @xorBroadcastRhs(%arg0: tensor<1x4xi32>, %arg1: tensor<4xi32>) -> tensor<1x4xi32> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x4xi32> - // CHECK-NEXT: %[[RESULT:.*]] = xla_hlo.xor %arg0, %[[BROADCAST1]] : tensor<1x4xi32> - %0 = "xla_hlo.xor"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xi32>, tensor<4xi32>) -> tensor<1x4xi32> - return %0 : tensor<1x4xi32> -} - -// ----- - // CHECK-LABEL: @clampBroadcast // CHECK-SAME: (%[[MIN:.+]]: tensor<f32>, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor<f32>) func @clampBroadcast(%min: tensor<f32>, %value: tensor<4xf32>, %max: tensor<f32>) -> tensor<4xf32> { @@ -229,69 +9,3 @@ func @clampBroadcast(%min: tensor<f32>, %value: tensor<4xf32>, %max: tensor<f32> %0 = "xla_hlo.clamp"(%min, %value, %max) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32> return %0 : tensor<4xf32> } - -// ----- - -// CHECK-LABEL: @compareBroadcastRhs -func @compareBroadcastRhs(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xi1> { - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x4xf32> - // CHECK-NEXT: %[[RESULT:.*]] = "xla_hlo.compare"(%arg0, %[[BROADCAST1]]) {comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1> - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1> - return %0 : tensor<1x4xi1> -} - -// ----- - -// CHECK-LABEL: @dynamicCompareBroadcastRhs -func @dynamicCompareBroadcastRhs(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?x?xi1> { - // CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor<?x?xf32> - // CHECK-NEXT: %[[DIM0C:.*]] = index_cast %[[DIM0]] : index to i32 - // CHECK-NEXT: %c1 = constant 1 : index - // CHECK-NEXT: %[[DIM1_0:.*]] = dim %arg0, 1 : tensor<?x?xf32> - // CHECK-NEXT: %[[DIM1_1:.*]] = dim %arg1, 0 : tensor<?xf32> - // CHECK-NEXT: %[[CMPI:.*]] = cmpi "eq", %[[DIM1_0]], %c1 : index - // CHECK-NEXT: %[[SEL:.*]] = select %[[CMPI]], %[[DIM1_0]], %[[DIM1_1]] : index - // CHECK-NEXT: %[[DIM1C:.*]] = index_cast %[[SEL]] : index to i32 - // CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0C]], %[[DIM1C]]) : (i32, i32) -> tensor<2xi32> - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32> - // CHECK-NEXT: "xla_hlo.compare"(%[[BROADCAST0]], %[[BROADCAST1]]) {comparison_direction = "NE"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1> - %0 = "xla_hlo.compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<?x?xf32>, tensor<?xf32>) -> tensor<?x?xi1> - return %0 : tensor<?x?xi1> -} - -// ----- - -// CHECK-LABEL: @dynamicBroadcastAdd -func @dynamicBroadcastAdd(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?x?xf32> { - // CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor<?x?xf32> - // CHECK-NEXT: %[[DIM0C:.*]] = index_cast %[[DIM0]] : index to i32 - // CHECK-NEXT: %c1 = constant 1 : index - // CHECK-NEXT: %[[DIM1_0:.*]] = dim %arg0, 1 : tensor<?x?xf32> - // CHECK-NEXT: %[[DIM1_1:.*]] = dim %arg1, 0 : tensor<?xf32> - // CHECK-NEXT: %[[CMPI:.*]] = cmpi "eq", %[[DIM1_0]], %c1 : index - // CHECK-NEXT: %[[SEL:.*]] = select %[[CMPI]], %[[DIM1_0]], %[[DIM1_1]] : index - // CHECK-NEXT: %[[DIM1C:.*]] = index_cast %[[SEL]] : index to i32 - // CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0C]], %[[DIM1C]]) : (i32, i32) -> tensor<2xi32> - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32> - // CHECK-NEXT: xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<?x?xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<?xf32>) -> tensor<?x?xf32> - return %0 : tensor<?x?xf32> -} - -// ----- - -// CHECK-LABEL: @dynamicBroadcastAddScalar -func @dynamicBroadcastAddScalar(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>) -> tensor<?x?xf32> { - // CHECK-NEXT: %[[DIM0:.*]] = dim %arg0, 0 : tensor<?x?xf32> - // CHECK-NEXT: %[[DIM0C:.*]] = index_cast %[[DIM0]] : index to i32 - // CHECK-NEXT: %[[DIM1:.*]] = dim %arg0, 1 : tensor<?x?xf32> - // CHECK-NEXT: %[[DIM1C:.*]] = index_cast %[[DIM1]] : index to i32 - // CHECK-NEXT: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM0C]], %[[DIM1C]]) : (i32, i32) -> tensor<2xi32> - // CHECK-NEXT: %[[BROADCAST0:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[SHAPE]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32> - // CHECK-NEXT: %[[BROADCAST1:.*]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[SHAPE]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<2xi32>) -> tensor<?x?xf32> - // CHECK-NEXT: xla_hlo.add %[[BROADCAST0]], %[[BROADCAST1]] : tensor<?x?xf32> - %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32> - return %0 : tensor<?x?xf32> -} diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index 8cb63311657..e6ae074f922 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -156,6 +156,98 @@ func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3x // ----- +func @case_mismatch_num_args(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> { + // expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor<f32>): + %1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1) : (tensor<f32>) -> () + }, { + ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): + %1 = "xla_hlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1) : (tensor<f32>) -> () + }, { + ^bb0(%arg0: tensor<f32>): + %1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1) : (tensor<f32>) -> () + } + ) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32> + return %0 : tensor<f32> +} + +// ----- + +func @case_mismatch_num_results(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> { + // expected-error@+1 {{branch 1 returned values do not match op result types}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor<f32>): + %1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1) : (tensor<f32>) -> () + }, { + ^bb0(%arg0: tensor<f32>): + %1 = "xla_hlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1, %arg0) : (tensor<f32>, tensor<f32>) -> () + }, { + ^bb0(%arg0: tensor<f32>): + %1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1) : (tensor<f32>) -> () + } + ) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32> + return %0 : tensor<f32> +} + +// ----- + +func @case_mismatch_arg_type(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> { + // expected-error@+1 {{expects operand 2 to be of type 'tensor<i32>', but found 'tensor<f32>'}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor<f32>): + %1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1) : (tensor<f32>) -> () + }, { + ^bb0(%arg0: tensor<i32>): + %1 = xla_hlo.constant dense<2.0> : tensor<f32> + "xla_hlo.return"(%1) : (tensor<f32>) -> () + }, { + ^bb0(%arg0: tensor<f32>): + %1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1) : (tensor<f32>) -> () + } + ) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32> + return %0 : tensor<f32> +} + +// ----- + +func @case_mismatch_return_type(%index: tensor<i32>, %operand_1: tensor<f32>, %operand_2: tensor<f32>, %operand_3: tensor<f32>) -> tensor<f32> { + // expected-error@+1 {{branch 1 returned values do not match op result types}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor<f32>): + %1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1) : (tensor<f32>) -> () + }, { + ^bb0(%arg0: tensor<f32>): + %1 = xla_hlo.constant dense<2> : tensor<i32> + "xla_hlo.return"(%1) : (tensor<i32>) -> () + }, { + ^bb0(%arg0: tensor<f32>): + %1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1) : (tensor<f32>) -> () + } + ) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32> + return %0 : tensor<f32> +} + +// ----- + +func @case_empty_region(%index: tensor<i32>, %operand_1: tensor<f32>) -> () { + // expected-error@+1 {{cannot have empty regions}} + "xla_hlo.case"(%index, %operand_1) ( {} ) : (tensor<i32>, tensor<f32>) -> tensor<f32> + return +} + +// ----- + // CHECK-LABEL: func @comp_eq func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> @@ -461,6 +553,14 @@ func @scalars_to_dimension_tensor(%arg0: i32, %arg1: i32) -> tensor<2xi32> { // ----- +// CHECK-LABEL: @scalars_to_dimension_tensor_index +func @scalars_to_dimension_tensor_index(%arg0: index, %arg1: index) -> tensor<2xindex> { + %0 = "xla_hlo.scalars_to_dimension_tensor"(%arg0, %arg1) : (index, index) -> tensor<2xindex> + return %0 : tensor<2xindex> +} + +// ----- + // CHECK-LABEL: func @select func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> diff --git a/tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir new file mode 100644 index 00000000000..9f54e40dcaa --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/sink-constants-to-control-flow.mlir @@ -0,0 +1,60 @@ +// RUN: xla-opt %s -xla-hlo-sink-constants-to-control-flow | FileCheck %s --dump-input=fail + +// Tests sinking constants to a while loop. + +// CHECK-LABEL: func @sink_const_to_while +func @sink_const_to_while(%arg0: tensor<i64>) -> tensor<i64> { + // CHECK-NEXT: xla_hlo.while + %c0 = xla_hlo.constant dense<1> : tensor<i64> + %c1 = xla_hlo.constant dense<2> : tensor<i64> + %0 = "xla_hlo.while"(%arg0) ( { + ^bb0(%arg1: tensor<i64>): + // CHECK: %[[ARG1A:.+]]: tensor<i64> + // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64> + // CHECK: "xla_hlo.compare"(%[[C0]], %[[ARG1A]]) + %1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1> + "xla_hlo.return"(%1) : (tensor<i1>) -> () + }, { + ^bb0(%arg1: tensor<i64>): + // CHECK: %[[ARG1B:.+]]: tensor<i64> + // CHECK-DAG: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64> + // CHECK-DAG: %[[ADD0:.+]] = xla_hlo.add %[[ARG1B]], %[[ARG1B]] + %2 = xla_hlo.add %arg1, %arg1 : tensor<i64> + // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], %[[ADD0]] + %3 = xla_hlo.add %c1, %2 : tensor<i64> + // CHECK: %[[ADD2:.+]] = xla_hlo.add %[[C1]], %[[ADD1]] + %4 = xla_hlo.add %c1, %3 : tensor<i64> + "xla_hlo.return"(%4) : (tensor<i64>) -> () + }) : (tensor<i64>) -> tensor<i64> + return %0 : tensor<i64> +} + +// Tests sinking constants to a conditional op. + +// CHECK-LABEL: func @sink_const_to_conditional +func @sink_const_to_conditional(%arg0: tensor<i64>) -> tensor<i64> { + %c0 = xla_hlo.constant dense<1> : tensor<i64> + %c1 = xla_hlo.constant dense<2> : tensor<i64> + %0 = "xla_hlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1> + %1 = "xla_hlo.tuple"(%arg0) : (tensor<i64>) -> tuple<tensor<i64>> + // CHECK: xla_hlo.if + %2 = "xla_hlo.if"(%0, %1, %1) ( { + ^bb0(%arg1: tuple<tensor<i64>>): + // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64> + %3 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64> + // CHECK: %[[ADD0:.+]] = xla_hlo.add %[[C0]], + %4 = xla_hlo.add %c0, %3 : tensor<i64> + %5 = "xla_hlo.tuple"(%4) : (tensor<i64>) -> tuple<tensor<i64>> + "xla_hlo.return"(%5) : (tuple<tensor<i64>>) -> () + }, { + ^bb0(%arg1: tuple<tensor<i64>>): + // CHECK: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64> + %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64> + // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], + %7 = xla_hlo.add %c1, %6 : tensor<i64> + %8 = "xla_hlo.tuple"(%7) : (tensor<i64>) -> tuple<tensor<i64>> + "xla_hlo.return"(%8) : (tuple<tensor<i64>>) -> () + }) : (tensor<i1>, tuple<tensor<i64>>, tuple<tensor<i64>>) -> tuple<tensor<i64>> + %9 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64> + return %9 : tensor<i64> +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case.mlir b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir new file mode 100644 index 00000000000..dba9e8b61ca --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir @@ -0,0 +1,99 @@ +// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s + +func @main() -> tensor<f32> { + %cst = constant {name = "constant"} dense<1> : tensor<i32> + %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor<f32> + %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor<f32> + %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor<f32> + %0 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { + ^bb0(%arg0: tensor<f32>): + %1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1) : (tensor<f32>) -> () + }, { + ^bb0(%arg0: tensor<f32>): + %1 = "xla_hlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1) : (tensor<f32>) -> () + }, { + ^bb0(%arg0: tensor<f32>): + %1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1) : (tensor<f32>) -> () + }) {name = "conditional"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32> + return %0 : tensor<f32> +} + +// CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[] negate(f32[] %[[ARG]]) +// CHECK: } + +// CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[] copy(f32[] %[[ARG]]) +// CHECK: } + +// CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[] floor(f32[] %[[ARG]]) +// CHECK: } + +// CHECK-LABEL: ENTRY +// CHECK-SAME: () -> f32[] + +// CHECK: %[[INDEX:.*]] = s32[] constant(1) +// CHECK: %[[OPERAND_1:.*]] = f32[] constant(56) +// CHECK: %[[OPERAND_2:.*]] = f32[] constant(12) +// CHECK: %[[OPERAND_3:.*]] = f32[] constant(13) +// CHECK: ROOT %[[RESULT:.*]] = f32[] conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]} + +// ----- + +func @main() -> (tensor<f32>, tensor<f32>) { + %cst = constant {name = "constant"} dense<1> : tensor<i32> + %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor<f32> + %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor<f32> + %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor<f32> + %0:2 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { + ^bb0(%arg0: tensor<f32>): + %1 = "xla_hlo.negate"(%arg0) {name = "negate"} : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> () + }, { + ^bb0(%arg0: tensor<f32>): + %1 = "xla_hlo.copy"(%arg0) {name = "copy"} : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> () + }, { + ^bb0(%arg0: tensor<f32>): + %1 = "xla_hlo.floor"(%arg0) {name = "floor"} : (tensor<f32>) -> tensor<f32> + "xla_hlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> () + }) {name = "conditional"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>) + return %0#0, %0#1 : tensor<f32>, tensor<f32> +} + +// CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: %[[NEGATE:.*]] = f32[] negate(f32[] %[[ARG]]) +// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[NEGATE]], f32[] %[[NEGATE]]) +// CHECK: } + +// CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: %[[COPY:.*]] = f32[] copy(f32[] %[[ARG]]) +// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[COPY]], f32[] %[[COPY]]) +// CHECK: } + +// CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) { +// CHECK: %[[ARG:.*]] = f32[] parameter(0) +// CHECK: %[[FLOOR:.*]] = f32[] floor(f32[] %[[ARG]]) +// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[FLOOR]], f32[] %[[FLOOR]]) +// CHECK: } + +// CHECK-LABEL: ENTRY +// CHECK-SAME: () -> (f32[], f32[]) + +// CHECK: %[[INDEX:.*]] = s32[] constant(1) +// CHECK: %[[OPERAND_1:.*]] = f32[] constant(56) +// CHECK: %[[OPERAND_2:.*]] = f32[] constant(12) +// CHECK: %[[OPERAND_3:.*]] = f32[] constant(13) +// CHECK: %[[TUPLE:.*]] = (f32[], f32[]) conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]} +// CHECK: %[[RES_1:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=0 +// CHECK: %[[RES_2:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=1 +// CHECK: ROOT %[[RESULT:.*]] = (f32[], f32[]) tuple(f32[] %[[RES_1]], f32[] %[[RES_2]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt new file mode 100644 index 00000000000..2ff223cd480 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt @@ -0,0 +1,46 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule Indexed_Conditional + +%Negate (x: f32[]) -> f32[] { + %x = f32[] parameter(0) + ROOT %negate = f32[] negate(f32[] %x) +} + +%Identity (y: f32[]) -> f32[] { + %y = f32[] parameter(0) + ROOT %copy = f32[] copy(f32[] %y) +} + +%Floor (z: f32[]) -> f32[] { + %z = f32[] parameter(0) + ROOT %floor = f32[] floor(f32[] %z) +} + +ENTRY %indexed_conditional () -> f32[] { + %constant = s32[] constant(1) + %constant.1 = f32[] constant(56) + %constant.2 = f32[] constant(12) + %constant.3 = f32[] constant(13) + ROOT %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor} +} + +// CHECK-LABEL: func @main() -> tensor<f32> +// CHECK: %[[INDEX:.*]] = constant {name = "constant"} dense<1> : tensor<i32> +// CHECK: %[[OPERAND_1:.*]] = constant {name = "{{.*}}"} dense<5.600000e+01> : tensor<f32> +// CHECK: %[[OPERAND_2:.*]] = constant {name = "{{.*}}"} dense<1.200000e+01> : tensor<f32> +// CHECK: %[[OPERAND_3:.*]] = constant {name = "{{.*}}"} dense<1.300000e+01> : tensor<f32> +// CHECK: %[[RESULT:.*]] = "xla_hlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( { +// CHECK: ^bb0(%[[ARG_1:.*]]: tensor<f32>): +// CHECK: %[[RES_1:.*]] = "xla_hlo.negate"(%[[ARG_1]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32> +// CHECK: "xla_hlo.return"(%[[RES_1]]) : (tensor<f32>) -> () +// CHECK: }, { +// CHECK: ^bb0(%[[ARG_2:.*]]: tensor<f32>): +// CHECK: %[[RES_2:.*]] = "xla_hlo.copy"(%[[ARG_2]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32> +// CHECK: "xla_hlo.return"(%[[RES_2]]) : (tensor<f32>) -> () +// CHECK: }, { +// CHECK: ^bb0(%[[ARG_3:.*]]: tensor<f32>): +// CHECK: %[[RES_3:.*]] = "xla_hlo.floor"(%[[ARG_3]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32> +// CHECK: "xla_hlo.return"(%[[RES_3]]) : (tensor<f32>) -> () +// CHECK: }) {name = "{{.*}}"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32> +// CHECK: return %[[RESULT]] : tensor<f32> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 15fa91588a5..20b43e8633d 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -1,4 +1,4 @@ -// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s +// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input-on-failure // CHECK: HloModule func @main(%arg0: !xla_hlo.token, %arg1: !xla_hlo.token) -> !xla_hlo.token { @@ -96,34 +96,6 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor // ----- -// CHECK: HloModule -func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> { - // Same rank degenerate broadcast - // CHECK: [[ARG_0:%.*]] = s32[1,4] parameter(0) - // CHECK-NEXT: [[RESHAPE_1:%.*]] = s32[4] reshape(s32[1,4] [[ARG_0]]) - // CHECK-NEXT: [[BROADCAST_1:%.*]] = s32[2,4] broadcast(s32[4] [[RESHAPE_1]]) - // CHECK-NEXT: [[ARG_1:%.*]] = s32[2,4] parameter(1) - // CHECK-NEXT: s32[2,4] add(s32[2,4] [[BROADCAST_1]], s32[2,4] [[ARG_1]]) - %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> - - // Broadcast up rank - // CHECK-NEXT: [[BROADCAST_2:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[ARG_1]]), dimensions={0,2} - // CHECK-NEXT: [[ARG_2:%.*]] = s32[2,3,4] parameter(2) - // CHECK-NEXT: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_2]], s32[2,3,4] [[ARG_2]]) - %1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> - - // Broadcast up rank + degenerate broadcast - // CHECK-NEXT: [[BROADCAST_3:%.*]] = s32[2,1,4] broadcast(s32[1,4] [[ARG_0]]), dimensions={1,2} - // CHECK-NEXT: [[RESHAPE_2:%.*]] = s32[2,4] reshape(s32[2,1,4] [[BROADCAST_3]]) - // CHECK-NEXT: [[BROADCAST_4:%.*]] = s32[2,3,4] broadcast(s32[2,4] [[RESHAPE_2]]), dimensions={0,2} - // CHECK: ROOT - // CHECK-SAME: s32[2,3,4] add(s32[2,3,4] [[BROADCAST_4]], s32[2,3,4] [[ARG_2]]) - %2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> - return %2 : tensor<2x3x4xi32> -} - -// ----- - // CHECK: HloModule func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> { %0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir b/tensorflow/compiler/mlir/xla/tests/translate/if.mlir similarity index 98% rename from tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir rename to tensorflow/compiler/mlir/xla/tests/translate/if.mlir index e510a2aa35f..6542966fc7c 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/conditional.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/if.mlir @@ -41,7 +41,7 @@ func @main(%arg0: tensor<f32>) -> tuple<tensor<f32>> { %1 = "xla_hlo.tuple"(%arg0) : (tensor<f32>) -> tuple<tensor<f32>> // CHECK: %[[VAL3:.+]] = (f32[]) conditional(pred[] %[[VAL1]], (f32[]) %[[VAL2]], (f32[]) %[[VAL2]]), true_computation=[[R0]], false_computation=[[R1]] - %2 = "xla_hlo.conditional"(%0, %1, %1) ( { + %2 = "xla_hlo.if"(%0, %1, %1) ( { ^bb0(%arg1: tuple<tensor<f32>>): %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32> %7 = "xla_hlo.log"(%6) : (tensor<f32>) -> tensor<f32> diff --git a/tensorflow/compiler/mlir/xla/tests/translate/conditional.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt similarity index 97% rename from tensorflow/compiler/mlir/xla/tests/translate/conditional.hlotxt rename to tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt index 00f6ec2d308..d2c6e669e9b 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/conditional.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/if_conditional.hlotxt @@ -29,7 +29,7 @@ ENTRY %tfcompile.20 { // CHECK: [[R2:%.+]] = "xla_hlo.tuple"([[A0]]) %tuple.5 = (f32[]) tuple(%arg0.1), metadata={op_type="If" op_name="cond/Merge_if"} - // CHECK: [[R3:%.+]] = "xla_hlo.conditional"([[R1]], [[R2]], [[R2]]) ( { + // CHECK: [[R3:%.+]] = "xla_hlo.if"([[R1]], [[R2]], [[R2]]) ( { // CHECK: ^bb0([[A1:%.+]]: tuple<tensor<f32>>): // CHECK: [[R7:%.+]] = "xla_hlo.get_tuple_element"([[A1]]) // CHECK: [[R8:%.+]] = "xla_hlo.log"([[R7]]) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 207a8f2eabc..af45f84b34d 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -1,4 +1,4 @@ -// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s --dump-input-on-failure HloModule main @@ -20,29 +20,6 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} } -// This test is more thorough than those of the the other binary ops to test -// their shared functionality. - -// CHECK-LABEL: func @test_add -%test_add (Arg_0.1: f32[4], Arg_1.2: f32[4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4] { - %Arg_0.1 = f32[4] parameter(0) - %Arg_1.2 = f32[4] parameter(1) - %Arg_2.3 = f32[] parameter(2) - %Arg_3.4 = f32[] parameter(3) - - // Add two tensors - // CHECK-NEXT: xla_hlo.add %arg0, %arg1 {name = "{{.*}}"} - %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) - - // Add two scalars - // CHECK-NEXT: xla_hlo.add %arg2, %arg3 - %add.4 = f32[] add(f32[] %Arg_2.3, f32[] %Arg_3.4) - - // Add a tensor and scalar - // CHECK-NEXT: "xla_hlo.add"(%0, %1) - ROOT %add.5 = f32[4] add(f32[4] %add.3, f32[] %add.4) -} - // CHECK-LABEL: func @test_after_all // CHECK-SAME: ([[VAL_0:%.*]]: !xla_hlo.token, [[VAL_1:%.*]]: !xla_hlo.token) -> !xla_hlo.token %test_after_all (token0: token[], token1: token[] ) -> token[] { @@ -159,11 +136,11 @@ add { } -// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<1xf32>) -> tensor<3xi1> { -%test_compare (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[1]) -> pred[3] { +// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> tensor<3xi1> { +%test_compare (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[3]) -> pred[3] { %Arg_0.1 = f32[3] parameter(0) %Arg_1.2 = f32[3] parameter(1) - %Arg_2.3 = f32[1] parameter(2) + %Arg_2.3 = f32[3] parameter(2) // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> %compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ @@ -172,7 +149,7 @@ add { %compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE // Requires broadcast of compatible tensors. - // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xi1> + // CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT } @@ -280,19 +257,19 @@ add { ROOT %convolution = f32[1,5,1] convolution(f32[1,2,1] %input, f32[1,1,1] %filter), feature_group_count=1, dim_labels=b0f_0io->b0f, window={pad=1_2 size=1} } -// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf64> { -%test_convert (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f64[4] { +// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf64> { +%test_convert (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f64[4] { %Arg_0.1 = f32[4] parameter(0) - %Arg_1.2 = f32[] parameter(1) + %Arg_1.2 = f32[4] parameter(1) // CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> %convert.3 = f64[4] convert(f32[4] %Arg_0.1) - // CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f64> - %convert.4 = f64[] convert(f32[] %Arg_1.2) + // CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> + %convert.4 = f64[4] convert(f32[4] %Arg_1.2) - // CHECK-NEXT: "xla_hlo.add"(%0, %1) - ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[] %convert.4) + // CHECK-NEXT: xla_hlo.add %0, %1 + ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[4] %convert.4) } // CHECK-LABEL: func @test_cosine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> { diff --git a/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir b/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir index 9778772e250..7a54de73db7 100644 --- a/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir +++ b/tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir @@ -106,24 +106,19 @@ func @batchNormInference_dynamic_shape( -> tensor<?x?x?x?xf32> { // CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32> // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], 0 : tensor<?xf32> - // CHECK-DAG: %[[INDEX_CAST:.+]] = index_cast %[[DIM]] : index to i32 - // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[INDEX_CAST]]) : (i32) -> tensor<1xi32> - // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32> + // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[DIM]]) : (index) -> tensor<1xindex> + // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32> // CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32> // CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32> // CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], 0 : tensor<?x?x?x?xf32> - // CHECK-DAG: %[[INPUT_INDEX_CAST_0:.+]] = index_cast %[[INPUT_DIM_0]] : index to i32 // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], 1 : tensor<?x?x?x?xf32> - // CHECK-DAG: %[[INPUT_INDEX_CAST_1:.+]] = index_cast %[[INPUT_DIM_1]] : index to i32 // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], 2 : tensor<?x?x?x?xf32> - // CHECK-DAG: %[[INPUT_INDEX_CAST_2:.+]] = index_cast %[[INPUT_DIM_2]] : index to i32 // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], 3 : tensor<?x?x?x?xf32> - // CHECK-DAG: %[[INPUT_INDEX_CAST_3:.+]] = index_cast %[[INPUT_DIM_3]] : index to i32 - // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[INPUT_INDEX_CAST_0]], %[[INPUT_INDEX_CAST_1]], %[[INPUT_INDEX_CAST_2]], %[[INPUT_INDEX_CAST_3]]) : (i32, i32, i32, i32) -> tensor<4xi32> - // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32> - // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32> - // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32> - // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32> + // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : (index, index, index, index) -> tensor<4xindex> + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> // CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32> // CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32> // CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32> diff --git a/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc index 0c9585a817f..e5a79616d5b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/chlo_legalize_to_hlo.cc @@ -163,8 +163,7 @@ struct HloBinaryElementwiseAdaptor { Value broadcasted_lhs, Value broadcasted_rhs, OpBuilder &builder) { return builder.create<ToOpTy>(from_op.getLoc(), result_type, - broadcasted_lhs, broadcasted_rhs, - /*broadcast_dimensions=*/nullptr); + broadcasted_lhs, broadcasted_rhs); } }; @@ -183,9 +182,9 @@ struct HloCompareAdaptor { Type result_type, Value broadcasted_lhs, Value broadcasted_rhs, OpBuilder &builder) { - return builder.create<xla_hlo::CompareOp>( - from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs, - /*broadcast_dimensions=*/nullptr, from_op.comparison_direction()); + return builder.create<xla_hlo::CompareOp>(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs, + from_op.comparison_direction()); } }; diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index 10f35768bbd..5851bad4565 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -43,8 +43,8 @@ constexpr StringRef kTempBufferAttr = "temp"; template <typename T> using BaseOpConversion = BufferAssignmentOpConversionPattern<T>; using StdReturnOpConverter = - NonVoidToVoidReturnOpConverter<mlir::ReturnOp, xla_lhlo::TerminatorOp, - xla_lhlo::CopyOp>; + NoBufferOperandsReturnOpConverter<mlir::ReturnOp, xla_lhlo::TerminatorOp, + xla_lhlo::CopyOp>; Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value shape_operand, @@ -362,6 +362,7 @@ void populateHLOToLHLOConversionPattern( HloToLhloOpConverter<xla_hlo::CopyOp>, HloToLhloOpConverter<xla_hlo::CosOp>, HloToLhloOpConverter<xla_hlo::DivOp>, + HloToLhloOpConverter<xla_hlo::DotOp>, HloToLhloOpConverter<xla_hlo::ExpOp>, HloToLhloOpConverter<xla_hlo::ImagOp>, HloToLhloOpConverter<xla_hlo::IotaOp>, diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc index 129a24600a2..bb1169a57d6 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc @@ -61,47 +61,46 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block, return success(); } -LogicalResult LowerConditionalOp(mlir::xla_hlo::ConditionalOp conditional_op) { - Operation* op_inst = conditional_op.getOperation(); - mlir::OpBuilder builder(conditional_op); +LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) { + Operation* op_inst = if_op.getOperation(); + mlir::OpBuilder builder(if_op); auto orig_block = op_inst->getBlock(); auto* tail_block = orig_block->splitBlock(op_inst); - auto loc = conditional_op.getLoc(); + auto loc = if_op.getLoc(); // Duplicate the true and false regions in the block between the sections // before and after the conditional. BlockAndValueMapping mapper; - conditional_op.true_branch().cloneInto(orig_block->getParent(), - Region::iterator(tail_block), mapper); - conditional_op.false_branch().cloneInto(orig_block->getParent(), - Region::iterator(tail_block), mapper); + if_op.true_branch().cloneInto(orig_block->getParent(), + Region::iterator(tail_block), mapper); + if_op.false_branch().cloneInto(orig_block->getParent(), + Region::iterator(tail_block), mapper); // Determine the blocks for the start of the true and false regions. - Block* true_block = mapper.lookup(&conditional_op.true_branch().front()); - Block* false_block = mapper.lookup(&conditional_op.false_branch().front()); + Block* true_block = mapper.lookup(&if_op.true_branch().front()); + Block* false_block = mapper.lookup(&if_op.false_branch().front()); // Perform the conditional branch into the true/false cases. builder.setInsertionPointToEnd(orig_block); // Extract the predicate for checking branching, then branch to the true and // false regions appropriately. - auto cond_value = - builder.create<mlir::ExtractElementOp>(loc, conditional_op.pred()); + auto cond_value = builder.create<mlir::ExtractElementOp>(loc, if_op.pred()); builder.create<mlir::CondBranchOp>(loc, cond_value, true_block, - conditional_op.true_arg(), false_block, - conditional_op.false_arg()); + if_op.true_arg(), false_block, + if_op.false_arg()); // Replace the true case's return operations with a branch to the tail of // the condition. - if (failed(ReplaceTerminators(&conditional_op.true_branch(), tail_block, loc, - mapper, &builder))) + if (failed(ReplaceTerminators(&if_op.true_branch(), tail_block, loc, mapper, + &builder))) return failure(); - if (failed(ReplaceTerminators(&conditional_op.false_branch(), tail_block, loc, - mapper, &builder))) + if (failed(ReplaceTerminators(&if_op.false_branch(), tail_block, loc, mapper, + &builder))) return failure(); - tail_block->addArguments(conditional_op.getResult().getType()); - conditional_op.getResult().replaceAllUsesWith(tail_block->getArgument(0)); + tail_block->addArguments(if_op.getResult().getType()); + if_op.getResult().replaceAllUsesWith(tail_block->getArgument(0)); op_inst->erase(); return success(); @@ -210,11 +209,11 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { void LegalizeControlFlow::runOnFunction() { auto func = getFunction(); - llvm::SmallVector<ConditionalOp, 4> conditional_ops; - func.walk([&](ConditionalOp op) { conditional_ops.push_back(op); }); + llvm::SmallVector<IfOp, 4> if_ops; + func.walk([&](IfOp op) { if_ops.push_back(op); }); - for (auto& op : conditional_ops) { - if (failed(LowerConditionalOp(op))) return signalPassFailure(); + for (auto& op : if_ops) { + if (failed(LowerIfOp(op))) return signalPassFailure(); } llvm::SmallVector<WhileOp, 4> while_ops; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 10bac232b0f..2d6da67fc60 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -67,8 +67,9 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> { public: LegalizeTF() = default; LegalizeTF(const LegalizeTF &) {} - explicit LegalizeTF(bool allow_partial_conversion) { + explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo) { allow_partial_conversion_ = allow_partial_conversion; + legalize_chlo_ = legalize_chlo; } /// Performs the lowering to XLA dialect. @@ -79,6 +80,11 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> { *this, "allow-partial-conversion", llvm::cl::desc("Allow operations that can't be legalized."), llvm::cl::init(false)}; + Option<bool> legalize_chlo_{ + *this, "legalize-chlo", + llvm::cl::desc( + "Also legalizes intermediate chlo ops to hlo (default true)"), + llvm::cl::init(true)}; }; /// Returns if the given TF data format string is the default format. @@ -362,6 +368,174 @@ static Value UpdateSliceInMinorDims(Location loc, Value v, Value update, return DynamicUpdateSliceInMinorDims(loc, v, update, dus_starts, builder); } +// Deprecated: This is maintained to aid in porting old code that is not yet +// dynamic shape aware and uses broadcasting modes that CHLO does not support. +// Gets the resulting type from a broadcast between two types for statically +// shaped types. This is to be used for legacy lowerings that both use non +// left-padded broadcasting and static shapes. Its use should not be permitted +// in new code. +// May return nullptr on invalid static broadcast dimensions. +// ABSL_DEPRECATED() +static RankedTensorType GetStaticBroadcastType( + RankedTensorType x, RankedTensorType y, + DenseIntElementsAttr broadcast_dimensions_attr) { + auto element_type = x.getElementType(); + auto shape_x = x.getShape(); + auto shape_y = y.getShape(); + + if (shape_x.size() == shape_y.size()) { + llvm::SmallVector<int64_t, 4> out_shape(shape_x.size()); + for (int i = 0; i < shape_x.size(); i++) { + auto x_val = shape_x[i]; + auto y_val = shape_y[i]; + out_shape[i] = std::max(x_val, y_val); + } + return RankedTensorType::get(out_shape, element_type); + } + + auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y; + auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y; + + llvm::SmallVector<int64_t, 4> broadcast_dimensions; + // Explicit broadcast dimensions. + for (const APInt &int_value : broadcast_dimensions_attr) { + broadcast_dimensions.push_back(int_value.getSExtValue()); + } + if (broadcast_dimensions.size() != shape_small.size()) { + return nullptr; + } + llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(), + shape_large.end()); + + // Update according to the broadcast dimensions. + for (auto index_pair : llvm::enumerate(broadcast_dimensions)) { + auto old_value = out_shape[index_pair.value()]; + auto new_value = shape_small[index_pair.index()]; + out_shape[index_pair.value()] = std::max(old_value, new_value); + } + return RankedTensorType::get(out_shape, element_type); +} + +// Deprecated: This is maintained to aid in porting old code that is not yet +// dynamic shape aware and uses broadcasting modes that CHLO does not support. +// Applies static binary broadcasting to a binary elementwise op. +// This is a legacy helper to provide general broadcasting support in legacy, +// static shaped code that relies on non-left-padded broadcasting semantics. +template <typename BinaryOp> +static Value StaticBinaryBroadcast(Location loc, Value x, Value y, + DenseIntElementsAttr broadcast_dims, + OpBuilder &builder) { + auto x_type = x.getType().cast<RankedTensorType>(); + auto y_type = y.getType().cast<RankedTensorType>(); + auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims); + if (!result_type) { + emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type + << " with broadcast_dims = " << broadcast_dims; + return nullptr; + } + auto larger_broadcast_dims = + GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder); + if (x_type.getRank() < y_type.getRank()) { + if (x_type != result_type) { + x = builder.create<BroadcastInDimOp>(loc, result_type, x, broadcast_dims); + } + if (y_type != result_type) { + y = builder.create<BroadcastInDimOp>(loc, result_type, y, + larger_broadcast_dims); + } + } else { + if (x_type != result_type) { + x = builder.create<BroadcastInDimOp>(loc, result_type, x, + larger_broadcast_dims); + } + if (y_type != result_type) { + y = builder.create<BroadcastInDimOp>(loc, result_type, y, broadcast_dims); + } + } + return builder.create<BinaryOp>(loc, x, y); +} + +// Gets a 1D tensor type suitable for expressing extents of the given tensor +// value type. If the value type is ranked, the result will be statically +// shaped. Otherwise, it will have a dynamic dimension. +static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) { + Builder b(value_type.getContext()); + int64_t dim = value_type.hasRank() ? value_type.getRank() : -1; + return RankedTensorType::get({dim}, b.getIndexType()); +} + +// Broadcasts a 'lower_rank_value' to the shape of a 'higher_rank_value' +// by assuming that the shape of the lower ranked is a broadcast compatible +// prefix of the higher ranked. +// Values must be RankedTensorType (this restriction derives from the +// broadcast_dimensions attribute on DynamicBroadcastInDim). +// +// Example: +// CommonPrefixBroadcast(tensor<4x3x256>, tensor<4, 3>) will broadcast the +// lower rank value to [4, 3, 256] (i.e. the opposite of numpy-style +// implicit broadcasting). +static Value CommonPrefixBroadcast(Location loc, Value higher_rank_value, + Value lower_rank_value, OpBuilder &builder) { + Value higher_rank_shape = + builder.create<shape::ShapeOfOp>(loc, higher_rank_value); + auto result_extents_type = + GetExtentsTensorTypeFor(higher_rank_value.getType().cast<TensorType>()); + Value result_extents = builder.create<shape::ToExtentTensorOp>( + loc, result_extents_type, higher_rank_shape); + + auto lower_rank_type = lower_rank_value.getType().cast<RankedTensorType>(); + auto lower_rank = lower_rank_type.getRank(); + auto prefix_dims = GetI64ElementsAttrForSeq(0, lower_rank, &builder); + return builder.create<DynamicBroadcastInDimOp>( + loc, higher_rank_value.getType(), lower_rank_value, result_extents, + prefix_dims); +} + +// Given a value (broadcast_to) and a feature dimension, broadcasts a 1D +// value (broadcast_from) along that feature dimension. This is a shortcut +// for the cases where a 1D tensor must be broadcast along a specific feature +// dimension, which can vary based on data layout, etc. +// +// The extent of `broadcast_from` dim0 must be equal to the extent of the +// feature_dim of `broadcast_to`. +// +// Example: +// [1x2x3x4], [2], 1 -> [1x2x3x4] +// TODO(laurenzo): Swap the order of broadcast_to and broadcast_from for +// consistency. Possibly also rename for clarity. +static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to, + Value broadcast_from, int64_t feature_dim, + OpBuilder &builder) { + auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder); + auto to_type = broadcast_to.getType().cast<RankedTensorType>(); + auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to); + auto result_extents_type = GetExtentsTensorTypeFor(to_type); + auto result_extents = builder.create<shape::ToExtentTensorOp>( + loc, result_extents_type, result_shape); + return builder.create<DynamicBroadcastInDimOp>( + loc, to_type, broadcast_from, result_extents, broadcast_dims); +} + +// Broadcasts `input` to the shape of `broadcast_to` value following +// TF::BroadcastTo semantics. +// +// Requires that input is a ranked tensor. +// +// TODO(hinsu): Utilize TF::ShapeOp followed by TF::BroadcastTo once ShapeOp +// supports unranked inputs in the lowering. +static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to, + OpBuilder &builder) { + auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to); + auto to_type = broadcast_to.getType().cast<TensorType>(); + auto result_extents_type = GetExtentsTensorTypeFor(to_type); + auto result_extents = builder.create<shape::ToExtentTensorOp>( + loc, result_extents_type, result_shape); + int64_t rank = input.getType().cast<RankedTensorType>().getRank(); + auto broadcast_dims = GetI64ElementsAttrForSeq(0, rank, &builder); + return builder.create<DynamicBroadcastInDimOp>( + loc, to_type, input, result_extents, broadcast_dims); +} + // Creates a batch dot using xla_hlo::DotGeneralOp. Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs, bool transpose_rhs, int64_t num_batch_dims, @@ -407,8 +581,7 @@ static void BuildReduceBody(Type element_type, Region *body, Location loc = body->getLoc(); auto reducer = - builder->create<Op>(loc, block->getArgument(0), block->getArgument(1), - /*broadcast_dimensions=*/nullptr); + builder->create<Op>(loc, block->getArgument(0), block->getArgument(1)); builder->create<ReturnOp>(loc, reducer.getResult()); } @@ -508,8 +681,7 @@ static void CreateWhile32(Location loc, int num_iterations, loc, builder->getI32IntegerAttr(num_iterations)); StringAttr compare_direction = StringAttr::get("LT", builder->getContext()); Value compare = builder->create<xla_hlo::CompareOp>( - loc, loop_iv, upper_limit, - /*broadcast_dimensions=*/nullptr, compare_direction); + loc, loop_iv, upper_limit, compare_direction); builder->create<xla_hlo::ReturnOp>(loc, compare); } @@ -539,9 +711,9 @@ static void CreateWhile32(Location loc, int num_iterations, // Increment the loop induction variable by one. auto one = builder->create<xla_hlo::ConstOp>(loc, builder->getI32IntegerAttr(1)); - auto no_broadcast_dims = GetI64ElementsAttr({}, builder); - auto plus_one = builder->create<xla_hlo::AddOp>(loc, old_values[0], one, - no_broadcast_dims); + auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder); + auto plus_one = builder->create<xla_chlo::BroadcastAddOp>( + loc, old_values[0], one, scalar_broadcast_dims); // Prepend with the updated loop induction variable. new_values.insert(new_values.begin(), plus_one); @@ -566,21 +738,6 @@ static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format, GetFeatureDimension(format, input.getType().cast<RankedTensorType>())); } -//===----------------------------------------------------------------------===// -// Bias op utilities. -//===----------------------------------------------------------------------===// - -// Return a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd. -// Requires input to have ranked tensor. -static DenseIntElementsAttr getBiasFeatureDimension(Builder &b, - StringAttr format, - Value input) { - auto inputType = input.getType().cast<RankedTensorType>(); - size_t featureDim = GetFeatureDimension(format, inputType); - RankedTensorType type = RankedTensorType::get(1, b.getIntegerType(64)); - return DenseIntElementsAttr::get(type, featureDim); -} - //===----------------------------------------------------------------------===// // MatMul op utilities. //===----------------------------------------------------------------------===// @@ -743,8 +900,7 @@ static void BuildArgMinMaxReductionBody(Type input_element_type, StringAttr compare_direction = StringAttr::get(direction, builder->getContext()); Value compare = builder->create<CompareOp>( - loc, block->getArgument(0), block->getArgument(2), - /*broadcast_dimensions=*/nullptr, compare_direction); + loc, block->getArgument(0), block->getArgument(2), compare_direction); Value selected_input = builder->create<SelectOp>( loc, input_type, compare, block->getArgument(0), block->getArgument(2)); @@ -860,8 +1016,7 @@ static void BuildSortComparisonBody(llvm::ArrayRef<Type> element_types, StringAttr compare_direction = StringAttr::get(direction, builder->getContext()); Value compare = builder->create<xla_hlo::CompareOp>( - loc, block->getArgument(0), block->getArgument(1), - /*broadcast_dimensions=*/nullptr, compare_direction); + loc, block->getArgument(0), block->getArgument(1), compare_direction); builder->create<xla_hlo::ReturnOp>(loc, compare); } @@ -900,6 +1055,27 @@ NamedAttribute GetConvDimensionNumbersAttr( feature_dim, spatial_dims, builder->getContext())); } +// Converts a TF::BiasAddOp to HLO. +// This differs from a normal TF::AddOp with respect to how the data_format +// is handled, which can optionally require a general broadcast of the +// 'bias' term in a way that is not compatible with the standard left-padded +// broadcast semantics (i.e. NCHW will broadcast into dimension 1). +// The correct 'bias' broadcast will be synthesized manually. +class ConvertBiasAddOp : public OpRewritePattern<TF::BiasAddOp> { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::BiasAddOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto feature_dim = GetFeatureDimension( + op.data_formatAttr(), op.value().getType().cast<RankedTensorType>()); + auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(), + feature_dim, rewriter); + rewriter.replaceOpWithNewOp<AddOp>(op, op.value(), bias_broadcast); + return success(); + } +}; + // Converts the TensorFlow conv op in template to the generic HLO conv op by // converting TensorFlow op attributes to HLO op attributes. // @@ -1161,7 +1337,6 @@ class ConvertDiagPartOp : public OpRewritePattern<TF::DiagPartOp> { rewriter.getI64IntegerAttr(1)); Value compare = rewriter.create<CompareOp>( op.getLoc(), iota0, iota1, - /*broadcast_dimensions=*/nullptr, StringAttr::get("EQ", rewriter.getContext())); Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(), 0, &rewriter); @@ -1274,33 +1449,35 @@ class ConvertFusedBatchNormGradBase non_feature_dims.push_back(i); } auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter); - auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &rewriter); - auto no_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); // scratch1 = rsqrt(var + epsilon) RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type); auto epsilon = rewriter.create<ConstOp>( loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()})); - auto add_op = rewriter.create<AddOp>(loc, var, epsilon.getResult(), - no_broadcast_dims); + auto add_op = rewriter.create<xla_chlo::BroadcastAddOp>( + loc, var, epsilon.getResult(), scalar_broadcast_dims); + Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op); // scratch2 = sum(y_backprop * (x - mean)) - auto sub_op = rewriter.create<SubOp>(loc, act, mean, broadcast_dims); - auto weighted_grad = - rewriter.create<MulOp>(loc, grad, sub_op, no_broadcast_dims); + auto sub_op = rewriter.create<xla_hlo::SubOp>( + loc, act, + Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter)); + auto weighted_grad = rewriter.create<xla_hlo::MulOp>(loc, grad, sub_op); Value scratch2 = ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter); // x_backprop = y_backprop * (scale * scratch1) auto scaled_grad = - rewriter.create<MulOp>(loc, op.scale(), scratch1, no_broadcast_dims); - x_backprop = - rewriter.create<MulOp>(loc, grad, scaled_grad, broadcast_dims); + rewriter.create<xla_hlo::MulOp>(loc, op.scale(), scratch1); + x_backprop = rewriter.create<xla_hlo::MulOp>( + loc, grad, + Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim, + rewriter)); // scale_backprop = scratch2 * scratch1 - scale_backprop = - rewriter.create<MulOp>(loc, scratch1, scratch2, no_broadcast_dims); + scale_backprop = rewriter.create<xla_hlo::MulOp>(loc, scratch1, scratch2); // offset_backprop = sum(y_backprop) offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter); @@ -1396,7 +1573,7 @@ class ConvertFusedBatchNormV3Op auto factor_const_op = rewriter.create<xla_hlo::ConstOp>( op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor)); - Value corrected_variance = rewriter.create<xla_hlo::MulOp>( + Value corrected_variance = rewriter.create<xla_chlo::BroadcastMulOp>( op.getLoc(), batch_variance.getType(), batch_variance, factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr()); @@ -1416,24 +1593,26 @@ class ConvertFusedBatchNormV3Op rewriter.getFloatAttr(mean_element_type, exponential_avg_factor)); // new_running_mean = alpha * old_mean + beta * batch_mean. - auto alpha_mul_old_mean = rewriter.create<MulOp>( + auto alpha_mul_old_mean = rewriter.create<xla_chlo::BroadcastMulOp>( op.getLoc(), op.mean().getType(), alpha, op.mean(), /*broadcast_dimensions=*/DenseIntElementsAttr()); - auto beta_mul_batch_mean = rewriter.create<MulOp>( + auto beta_mul_batch_mean = rewriter.create<xla_chlo::BroadcastMulOp>( op.getLoc(), batch_mean.getType(), beta, batch_mean, /*broadcast_dimensions=*/DenseIntElementsAttr()); - batch_mean = rewriter.create<AddOp>( + batch_mean = rewriter.create<xla_chlo::BroadcastAddOp>( op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean, /*broadcast_dimensions=*/DenseIntElementsAttr()); // new_running_variance = alpha * old_variance + beta * batch_variance. - auto alpha_mul_old_variance = rewriter.create<MulOp>( + auto alpha_mul_old_variance = rewriter.create<xla_chlo::BroadcastMulOp>( op.getLoc(), op.variance().getType(), alpha, op.variance(), /*broadcast_dimensions=*/DenseIntElementsAttr()); - auto beta_mul_batch_variance = rewriter.create<MulOp>( - op.getLoc(), corrected_variance.getType(), beta, corrected_variance, - /*broadcast_dimensions=*/DenseIntElementsAttr()); - corrected_variance = rewriter.create<AddOp>( + auto beta_mul_batch_variance = + rewriter.create<xla_chlo::BroadcastMulOp>( + op.getLoc(), corrected_variance.getType(), beta, + corrected_variance, + /*broadcast_dimensions=*/DenseIntElementsAttr()); + corrected_variance = rewriter.create<xla_chlo::BroadcastAddOp>( op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance, /*broadcast_dimensions=*/DenseIntElementsAttr()); } @@ -1586,10 +1765,9 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> { // Divide by the number of elements in the window. Value divisor = GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter); - auto batch_dims = - GetI64ElementsAttrForSeq(0, input_type.getRank(), &rewriter); - Value result = rewriter.create<DivOp>(op.getLoc(), result_type, reduce, - divisor, batch_dims); + auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + Value result = rewriter.create<xla_chlo::BroadcastDivOp>( + op.getLoc(), result_type, reduce, divisor, scalar_broadcast_dims); // Convert back if we enlarged the element type's bitwidth. if (input_element_type != sum_element_type) @@ -1746,29 +1924,20 @@ class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> { LogicalResult matchAndRewrite(TF::SigmoidOp op, PatternRewriter &rewriter) const override { - auto operand = op.getOperand(); + Location loc = op.getLoc(); - auto scalar_one = rewriter.create<ConstOp>( - op.getLoc(), - rewriter.getFloatAttr(getElementTypeOrSelf(operand.getType()), 0.5)); + // Create constant half with shape and element type same as the operand. + Value operand = op.getOperand(); + auto operand_ty = operand.getType().cast<TensorType>(); + auto scalar_ty = RankedTensorType::get({}, operand_ty.getElementType()); + ElementsAttr attr = mlir::xla::getSplat(&rewriter, scalar_ty, 0.5); + auto scalar_half = rewriter.create<ConstOp>(loc, attr); + auto half = BroadcastToShapeOf(loc, scalar_half, operand, rewriter); - auto type = operand.getType().dyn_cast<RankedTensorType>(); - if (!type) - return rewriter.notifyMatchFailure(op, "requires ranked tensor type"); - auto constant_ones = rewriter.create<BroadcastOp>( - op.getLoc(), type, scalar_one, - GetI64ElementsAttr(type.getShape(), &rewriter)); - - auto scaled_input = rewriter.create<MulOp>( - op.getLoc(), operand, constant_ones, DenseIntElementsAttr()); - auto tanh_op = - rewriter.create<TanhOp>(op.getLoc(), operand.getType(), scaled_input); - auto mul_op = - rewriter.create<MulOp>(op.getLoc(), tanh_op, constant_ones, - /*DenseIntElementsAttr=*/DenseIntElementsAttr()); - auto add_op = - rewriter.create<AddOp>(op.getLoc(), mul_op, constant_ones, - /*DenseIntElementsAttr=*/DenseIntElementsAttr()); + auto scaled_input = rewriter.create<MulOp>(loc, operand, half); + auto tanh_op = rewriter.create<TanhOp>(loc, scaled_input); + auto mul_op = rewriter.create<MulOp>(loc, tanh_op, half); + auto add_op = rewriter.create<AddOp>(loc, mul_op, half); rewriter.replaceOp(op, add_op.getResult()); return success(); @@ -1807,20 +1976,18 @@ class ConvertSoftmaxOp : public OpRewritePattern<OpTy> { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - Value logits = op.logits(); - // Softmax converter requires ranked type because the XLA reduce ops used // while lowering requires dimensions attribute to reduce along. + // Note that the input and output shape is equivalent, so we use 'logits' + // and its type for shape calculations. + Value logits = op.logits(); RankedTensorType type = logits.getType().dyn_cast<RankedTensorType>(); if (!type) return failure(); - auto loc = op.getLoc(); int rank = type.getRank(); // Note that the TensorFlow Softmax op verifies that the input rank is - // greater than or equal to one so both of the following sequences are - // valid. - auto batch_dims = GetI64ElementsAttrForSeq(0, rank - 1, &rewriter); + // greater than or equal to one so the following sequence is valid. auto reduce_dim = rewriter.create<TF::ConstOp>( loc, GetI64ElementsAttr({rank - 1}, &rewriter)); @@ -1833,8 +2000,10 @@ class ConvertSoftmaxOp : public OpRewritePattern<OpTy> { auto max_logits = rewriter.create<TF::MaxOp>(loc, logits, reduce_dim, /*keep_dims=*/rewriter.getBoolAttr(false)); - auto shifted_logits = - rewriter.create<SubOp>(loc, type, logits, max_logits, batch_dims); + auto max_logits_broadcast = + CommonPrefixBroadcast(loc, logits, max_logits, rewriter); + auto shifted_logits = rewriter.create<xla_hlo::SubOp>(loc, type, logits, + max_logits_broadcast); // Exponentiate the inputs. Value exp = rewriter.create<ExpOp>(loc, type, shifted_logits); @@ -1847,9 +2016,12 @@ class ConvertSoftmaxOp : public OpRewritePattern<OpTy> { if (use_log) { Value log = rewriter.create<LogOp>(loc, sum); - rewriter.replaceOpWithNewOp<SubOp>(op, shifted_logits, log, batch_dims); + auto log_broadcast = CommonPrefixBroadcast(loc, logits, log, rewriter); + rewriter.replaceOpWithNewOp<xla_hlo::SubOp>(op, shifted_logits, + log_broadcast); } else { - rewriter.replaceOpWithNewOp<DivOp>(op, exp, sum, batch_dims); + auto sum_broadcast = CommonPrefixBroadcast(loc, logits, sum, rewriter); + rewriter.replaceOpWithNewOp<xla_hlo::DivOp>(op, exp, sum_broadcast); } return success(); } @@ -1896,7 +2068,7 @@ class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> { auto dim = rewriter.create<GetDimensionSizeOp>( op.getLoc(), result_type, input, rewriter.getIntegerAttr(rewriter.getIntegerType(32), i)); - size = rewriter.create<MulOp>( + size = rewriter.create<xla_chlo::BroadcastMulOp>( op.getLoc(), size->getResult(0), dim.getResult(), /*DenseIntElementsAttr=*/DenseIntElementsAttr()); } @@ -2582,10 +2754,10 @@ class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> { auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create<MulOp>( + auto scaled = rewriter.create<xla_chlo::BroadcastMulOp>( op.getLoc(), result_type, iota, op.delta(), xla::getBroadcastDimensionsAttr(&rewriter, iota, op.delta())); - rewriter.replaceOpWithNewOp<AddOp>( + rewriter.replaceOpWithNewOp<xla_chlo::BroadcastAddOp>( op, result_type, scaled, op.start(), xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); return success(); @@ -2633,7 +2805,7 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> { int64_t num = (*num_attr.begin()).getSExtValue(); // Calculate the scaling that needs to be applied to the iota. - auto step_numerator = rewriter.create<SubOp>( + auto step_numerator = rewriter.create<xla_chlo::BroadcastSubOp>( op.getLoc(), op.start().getType(), op.stop(), op.start(), xla::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start())); Value step_denominator = rewriter.create<ConvertOp>( @@ -2641,11 +2813,11 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> { if (num > 1) { Value one = GetScalarConstOfType(result_type.getElementType(), op.getLoc(), 1, &rewriter); - step_denominator = rewriter.create<SubOp>( + step_denominator = rewriter.create<xla_chlo::BroadcastSubOp>( op.getLoc(), step_denominator.getType(), step_denominator, one, xla::getBroadcastDimensionsAttr(&rewriter, step_denominator, one)); } - auto step = rewriter.create<DivOp>( + auto step = rewriter.create<xla_chlo::BroadcastDivOp>( op.getLoc(), step_numerator.getType(), step_numerator, step_denominator, xla::getBroadcastDimensionsAttr(&rewriter, step_numerator, step_denominator)); @@ -2653,10 +2825,10 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> { // Scale the iota and add the offset. auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create<MulOp>( + auto scaled = rewriter.create<xla_chlo::BroadcastMulOp>( op.getLoc(), result_type, iota, step, xla::getBroadcastDimensionsAttr(&rewriter, iota, step)); - rewriter.replaceOpWithNewOp<AddOp>( + rewriter.replaceOpWithNewOp<xla_chlo::BroadcastAddOp>( op, result_type, scaled, op.start(), xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); return success(); @@ -2732,8 +2904,8 @@ class GenericConvertReductionOp : public OpRewritePattern<OpTy> { auto divisor = GetScalarConstOfType(reduce_element_type, loc, divisor_count, &rewriter); auto broadcast_dims = GetI64ElementsAttr({}, &rewriter); - result = rewriter.create<DivOp>(loc, result, divisor.getResult(), - broadcast_dims); + result = rewriter.create<xla_chlo::BroadcastDivOp>( + loc, result, divisor.getResult(), broadcast_dims); } result = rewriter.create<ConvertOp>(loc, result, element_type); @@ -3118,7 +3290,6 @@ class ConvertMaxPoolGradOp : public OpRewritePattern<OpTy> { auto reducer = rewriter.create<CompareOp>( loc, block->getArgument(0), block->getArgument(1), - /*broadcast_dimensions=*/nullptr, StringAttr::get("GE", rewriter.getContext())); rewriter.create<ReturnOp>(loc, reducer.getResult()); } @@ -3544,13 +3715,20 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> { output_dims.insert(output_dims.begin() + axis, depth); Location loc = op.getLoc(); + + // The iota result is the effective output shape of the computation, + // and indices must be broadcast into it. At this point, this computation + // would need to be reworked quite a bit to support dynamic shapes, so + // just using static broadcasting. auto index_type = RankedTensorType::get(output_dims, element_type); - Value compare = rewriter.create<CompareOp>( - loc, op.indices(), - rewriter.create<IotaOp>( - loc, index_type, - IntegerAttr::get(rewriter.getIntegerType(64), axis)), - GetI64ElementsAttr(broadcast_dims, &rewriter), + auto iota = rewriter.create<IotaOp>( + loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis)); + auto broadcast_indices = rewriter.create<BroadcastInDimOp>( + loc, index_type, op.indices(), + GetI64ElementsAttr(broadcast_dims, &rewriter)); + + Value compare = rewriter.create<xla_hlo::CompareOp>( + loc, broadcast_indices, iota, StringAttr::get("EQ", rewriter.getContext())); Value on_value = rewriter.create<BroadcastOp>( loc, op.getType(), op.on_value(), @@ -4396,7 +4574,6 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { rewriter.getI64IntegerAttr(1)); Value compare = rewriter.create<CompareOp>( op.getLoc(), iota0, iota1, - /*broadcast_dimensions=*/nullptr, StringAttr::get("EQ", rewriter.getContext())); Value identity_matrix = rewriter.create<ConvertOp>(op.getLoc(), compare, type.getElementType()); @@ -4430,8 +4607,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { batch_dims.size(), precision_config, &rewriter); a_update = BatchDot(op.getLoc(), y, false, a_update, false, batch_dims.size(), precision_config, &rewriter); - a_panel = rewriter.create<AddOp>(op.getLoc(), a_panel, a_update, - /*broadcast_dimensions=*/nullptr); + a_panel = rewriter.create<AddOp>(op.getLoc(), a_panel, a_update); a = UpdateSliceInMinorDims(op.getLoc(), a, a_panel, {i, i + k}, &rewriter); @@ -4442,8 +4618,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { batch_dims.size(), precision_config, &rewriter); q_update = BatchDot(op.getLoc(), q_update, false, y, true, batch_dims.size(), precision_config, &rewriter); - q_panel = rewriter.create<AddOp>(op.getLoc(), q_panel, q_update, - /*broadcast_dimensions=*/nullptr); + q_panel = rewriter.create<AddOp>(op.getLoc(), q_panel, q_update); q = UpdateSliceInMinorDims(op.getLoc(), q, q_panel, {i}, &rewriter); } // full_matrices is false when only a partial result in needed. Slice to the @@ -4505,34 +4680,31 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { Value iota = builder->create<IotaOp>( loc, RankedTensorType::get({m}, builder->getIntegerType(32)), builder->getI64IntegerAttr(0)); - Value gtk = builder->create<CompareOp>( + Value gtk = builder->create<xla_chlo::BroadcastCompareOp>( loc, iota, k, GetI64ElementsAttr({}, builder), StringAttr::get("GT", builder->getContext())); gtk = builder->create<ConvertOp>(loc, gtk, x_type.getElementType()); - Value x_after_k = builder->create<MulOp>( + Value x_after_k = builder->create<xla_chlo::BroadcastMulOp>( loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder)); - Value x_after_k_sq = builder->create<MulOp>( - loc, x_after_k, x_after_k, /*broadcast_dimensions=*/nullptr); + Value x_after_k_sq = builder->create<MulOp>(loc, x_after_k, x_after_k); // sigma = np.dot(x[k+1:], x[k+1:]) auto sigma = builder->create<ReduceOp>( loc, x_after_k_sq, zero, GetI64ElementsAttr({minor_dim}, builder)); BuildReduceBody<AddOp>(x_type.getElementType(), &sigma.body(), builder); // mu = np.sqrt(x[k]*x[k] + sigma) - Value alpha_sq = builder->create<MulOp>(loc, alpha, alpha, - /*broadcast_dimensions=*/nullptr); + Value alpha_sq = builder->create<MulOp>(loc, alpha, alpha); Value mu = builder->create<SqrtOp>( - loc, builder->create<AddOp>(loc, alpha_sq, sigma.getResult(0), - /*broadcast_dimensions=*/nullptr)); + loc, builder->create<AddOp>(loc, alpha_sq, sigma.getResult(0))); - Value sigma_is_zero = builder->create<CompareOp>( + Value sigma_is_zero = builder->create<xla_chlo::BroadcastCompareOp>( loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); - Value alpha_is_negative = builder->create<CompareOp>( + Value alpha_is_negative = builder->create<xla_chlo::BroadcastCompareOp>( loc, alpha, zero, GetI64ElementsAttr({}, builder), StringAttr::get("LT", builder->getContext())); auto batch_size_one = builder->create<BroadcastOp>( loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder)); - Value signed_mu = builder->create<MulOp>( + Value signed_mu = builder->create<xla_chlo::BroadcastMulOp>( loc, builder->create<SelectOp>(loc, mu.getType(), alpha_is_negative, batch_size_one, @@ -4541,21 +4713,16 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { *beta = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero, alpha, signed_mu); *tau = builder->create<DivOp>( - loc, - builder->create<SubOp>(loc, *beta, alpha, - /*broadcast_dimensions=*/nullptr), - *beta, - /*broadcast_dimensions=*/nullptr); + loc, builder->create<SubOp>(loc, *beta, alpha), *beta); Value zero_tau = builder->create<BroadcastOp>( loc, alpha.getType(), zero, GetI64ElementsAttr(batch_dims, builder)); *tau = builder->create<SelectOp>(loc, alpha.getType(), sigma_is_zero, zero_tau, *tau); - Value divisor = builder->create<SubOp>(loc, alpha, *beta, - /*broadcast_dimensions=*/nullptr); + Value divisor = builder->create<SubOp>(loc, alpha, *beta); divisor = builder->create<SelectOp>(loc, divisor.getType(), sigma_is_zero, batch_size_one, divisor); - Value eqk = builder->create<CompareOp>( + Value eqk = builder->create<xla_chlo::BroadcastCompareOp>( loc, iota, k, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); eqk = builder->create<ConvertOp>(loc, eqk, x_type.getElementType()); @@ -4568,10 +4735,12 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. - *v = builder->create<AddOp>( + // Note that the add performs a degenerate broadcast. + *v = builder->create<xla_chlo::BroadcastAddOp>( loc, e_k, - builder->create<DivOp>(loc, x_after_k, divisor, - GetI64ElementsAttr(batch_dim_ids, builder)), + StaticBinaryBroadcast<DivOp>(loc, x_after_k, divisor, + GetI64ElementsAttr(batch_dim_ids, builder), + *builder), /*broadcast_dimensions=*/nullptr); } @@ -4645,10 +4814,10 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { precision, builder); vva = BatchDot(loc, v_broadcast, true, vva, false, num_batch_dims, precision, builder); - auto tau_x_vva = builder->create<MulOp>( - loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder)); - a = builder->create<SubOp>(loc, a, tau_x_vva, - /*broadcast_dimensions=*/nullptr); + auto tau_x_vva = StaticBinaryBroadcast<xla_hlo::MulOp>( + loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder), + *builder); + a = builder->create<SubOp>(loc, a, tau_x_vva); // It is more precise to populate column 'k' explicitly, rather than // computing it implicitly by applying the Householder transformation. @@ -4657,12 +4826,12 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { auto iota = builder->create<IotaOp>( loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)), builder->getI64IntegerAttr(0)); - Value predecessor_mask = builder->create<CompareOp>( + Value predecessor_mask = builder->create<xla_chlo::BroadcastCompareOp>( loc, iota, j, GetI64ElementsAttr({}, builder), StringAttr::get("LT", builder->getContext())); predecessor_mask = builder->create<ConvertOp>(loc, predecessor_mask, a_type.getElementType()); - Value mask = builder->create<CompareOp>( + Value mask = builder->create<xla_chlo::BroadcastCompareOp>( loc, iota, j, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); mask = builder->create<ConvertOp>(loc, mask, a_type.getElementType()); @@ -4674,14 +4843,14 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { mask, GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(num_batch_dims, 1), builder)); - Value predecessor_masked_x = builder->create<MulOp>( + Value predecessor_masked_x = StaticBinaryBroadcast<MulOp>( loc, x, predecessor_mask, - GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder)); - Value masked_beta = builder->create<MulOp>( - loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder)); + GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder), *builder); + Value masked_beta = StaticBinaryBroadcast<MulOp>( + loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder), + *builder); Value new_x = - builder->create<AddOp>(loc, predecessor_masked_x, masked_beta, - /*broadcast_dimensions=*/nullptr); + builder->create<AddOp>(loc, predecessor_masked_x, masked_beta); // Update a[:,j] llvm::SmallVector<int64_t, 4> dim_ids(num_dims); std::iota(dim_ids.begin(), dim_ids.end(), 0); @@ -4692,7 +4861,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { loc, RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)), builder->getI64IntegerAttr(minor_dim + 1)); - Value xa_mask = builder->create<CompareOp>( + Value xa_mask = builder->create<xla_chlo::BroadcastCompareOp>( loc, iota_mn, j, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); a = builder->create<SelectOp>(loc, a_type, xa_mask, new_x, a); @@ -4708,11 +4877,11 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { builder)); auto vs_update = builder->create<SelectOp>( loc, vs.getType(), xa_mask, - builder->create<AddOp>( - loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder)), + StaticBinaryBroadcast<AddOp>( + loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder), + *builder), vs_zeros); - vs = builder->create<AddOp>(loc, vs, vs_update, - /*broadcast_dimensions=*/nullptr); + vs = builder->create<AddOp>(loc, vs, vs_update); // taus[j] = tau llvm::SmallVector<int64_t, 4> tau_broadcast_dims(batch_dims.size()); @@ -4729,17 +4898,16 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { loc, taus.getType(), taus_zeros, GetI64ElementsAttr(taus.getType().cast<RankedTensorType>().getShape(), builder)); - Value taus_mask = builder->create<CompareOp>( + Value taus_mask = builder->create<xla_chlo::BroadcastCompareOp>( loc, iota_n, j, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); auto taus_update = builder->create<SelectOp>( loc, taus.getType(), taus_mask, - builder->create<AddOp>( + StaticBinaryBroadcast<AddOp>( loc, taus_zeros, tau, - GetI64ElementsAttr(tau_broadcast_dims, builder)), + GetI64ElementsAttr(tau_broadcast_dims, builder), *builder), taus_zeros); - taus = builder->create<AddOp>(loc, taus, taus_update, - /*broadcast_dimensions=*/nullptr); + taus = builder->create<AddOp>(loc, taus, taus_update); new_values->assign({a, vs, taus}); }; @@ -4796,8 +4964,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { j = builder->create<AddOp>( loc, j, GetScalarConstOfType(getElementTypeOrSelf(j.getType()), loc, 1, - builder), - /*broadcast_dimensions=*/nullptr); + builder)); // vs has shape [..., m, 1] auto v = DynamicSliceInMinorDims(loc, vs, {j}, {1}, builder); // beta has shape [..., 1] @@ -4816,7 +4983,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { loc, vs.getType(), zero, GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(), builder)); - auto compare = builder->create<CompareOp>( + auto compare = builder->create<xla_chlo::BroadcastCompareOp>( loc, iota_mn, j, GetI64ElementsAttr({}, builder), StringAttr::get("GE", builder->getContext())); auto y = builder->create<SelectOp>(loc, vs.getType(), compare, zero, vs); @@ -4831,13 +4998,12 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { // z = -beta * (v + wyv) auto neg_beta = builder->create<NegOp>(loc, beta); - auto v_wyv = builder->create<AddOp>(loc, v, wyv, - /*broadcast_dimensions=*/nullptr); + auto v_wyv = builder->create<AddOp>(loc, v, wyv); auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices); beta_broadcast_dims.push_back(n_index); - auto z = builder->create<MulOp>( + auto z = StaticBinaryBroadcast<MulOp>( loc, neg_beta, v_wyv, - GetI64ElementsAttr(beta_broadcast_dims, builder)); + GetI64ElementsAttr(beta_broadcast_dims, builder), *rewriter); w = DynamicUpdateSliceInMinorDims(loc, w, z, {j}, builder); new_values->assign({w, vs, taus}); @@ -4855,8 +5021,9 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> { auto neg_beta = rewriter->create<NegOp>(loc, beta); auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices); beta_broadcast_dims.push_back(n_index); - auto bv = rewriter->create<MulOp>( - loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter)); + auto bv = StaticBinaryBroadcast<MulOp>( + loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter), + *rewriter); w = UpdateSliceInMinorDims(loc, w, bv, {0}, rewriter); SmallVector<Value, 4> while_output; @@ -4912,7 +5079,8 @@ void EmitLegalizationErrors(Operation *op, // Performs the lowering to XLA dialect. void LegalizeTF::runOnFunction() { - if (failed(legalizeTF(getFunction(), allow_partial_conversion_))) + if (failed( + legalizeTF(getFunction(), allow_partial_conversion_, legalize_chlo_))) signalPassFailure(); } @@ -4923,7 +5091,8 @@ static PassRegistration<LegalizeTF> pass( #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" -LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { +LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion, + bool legalize_chlo) { MLIRContext *context = op->getContext(); // Add lowering patterns to the list. @@ -4936,19 +5105,19 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { TF::PopulateLoweringTFPatterns(context, &patterns); patterns.insert< ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op, - ConvertBroadcastToOp, ConvertBF16FloorDivOp, ConvertConv2DOp, - ConvertConv3DOp, ConvertDepthConv2DOp, ConvertConv2DBackpropFilterOp, - ConvertConv3DBackpropFilterOp, ConvertConv2DBackpropInputOp, - ConvertConv3DBackpropInputOp, ConvertCumsumOp, ConvertDiagPartOp, - ConvertEinsumOp, ConvertFusedBatchNormGradOp, - ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op, - ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, - ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp, - ConvertAvgPoolOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp, - ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp, - ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp, - ConvertRangeOp, ConvertSelectV2Op, ConvertSigmoidOp, ConvertSizeOp, - ConvertSoftmaxOp<TF::LogSoftmaxOp, true>, + ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp, + ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp, + ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp, + ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp, + ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp, + ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, + ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op, + ConvertInfeedDequeueTupleOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp, + ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertMaxPool2DOp, + ConvertMaxPool3DOp, ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, + ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, + ConvertProdOp, ConvertQrOp, ConvertRangeOp, ConvertSelectV2Op, + ConvertSigmoidOp, ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>, ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op, @@ -4959,10 +5128,16 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { // Populate with CHLO->HLO lowerings to account for TF ops legalized to // CHLO first. - xla_chlo::PopulateLegalizeChloToHloPatterns(context, &patterns); + if (legalize_chlo) { + xla_chlo::PopulateLegalizeChloToHloPatterns(context, &patterns); + } ConversionTarget target(*context); - target.addIllegalDialect<xla_chlo::XlaHloClientDialect>(); + if (legalize_chlo) { + target.addIllegalDialect<xla_chlo::XlaHloClientDialect>(); + } else { + target.addLegalDialect<xla_chlo::XlaHloClientDialect>(); + } target.addLegalDialect<XlaHloDialect>(); target.addLegalDialect<StandardOpsDialect>(); target.addLegalDialect<shape::ShapeDialect>(); @@ -4988,8 +5163,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { } std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass( - bool allow_partial_conversion) { - return std::make_unique<LegalizeTF>(allow_partial_conversion); + bool allow_partial_conversion, bool legalize_chlo) { + return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo); } } // end namespace xla_hlo diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc index 86927fe0e07..d5e5b6f5a71 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc @@ -66,7 +66,7 @@ createLegalizeTFControlFlowPass() { namespace { void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) { - // De-tuple the results of the xla hlo conditional result. + // De-tuple the results of the xla hlo if result. for (auto result_it : llvm::enumerate(replace)) { auto get_tuple_value = builder->create<xla_hlo::GetTupleElementOp>( result_it.value().getLoc(), tuple, result_it.index()); @@ -74,14 +74,13 @@ void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) { } } -// Imports the source region into the destination region. The XLA conditional +// Imports the source region into the destination region. The XLA if // operation only supports one argument per branch. Therefore any branch that // requires additional arguments requires their values be tupled together. Then, // to support multiple returns (as XLA only supports a single return value) the -// results of the conditional are tupled together. +// results of the if operation are tupled together. void ImportXlaRegion(mlir::FuncOp func, Region* dest_region, Location loc, bool tuple_return = true) { - BlockAndValueMapping mapper; OpBuilder builder(dest_region); auto entry_block = builder.createBlock(dest_region); @@ -111,27 +110,52 @@ void LowerIf(TF::IfOp op, ModuleOp module) { // XLA prefers tuple arguments for control flow due to XLA not supporting // multiple return values. SmallVector<Value, 3> inputs(op.input()); - builder.setInsertionPoint(op); auto tuple_input = builder.create<xla_hlo::TupleOp>(loc, inputs); - // Create the new conditional op with tuple inputs. - SmallVector<Value, 3> operands(op.getOperands()); + // Create the new if op with tuple inputs. auto result_type = builder.getTupleType(op.getResultTypes()); - auto conditional = builder.create<xla_hlo::ConditionalOp>( - loc, result_type, op.cond(), tuple_input, tuple_input); + auto if_op = builder.create<xla_hlo::IfOp>(loc, result_type, op.cond(), + tuple_input, tuple_input); // Import the regions for both the true and false cases. These regions // must be updated to tuple the return results together and use the xla hlo // return op. - BlockAndValueMapping mapper; auto then_branch = module.lookupSymbol<mlir::FuncOp>(op.then_branch()); auto else_branch = module.lookupSymbol<mlir::FuncOp>(op.else_branch()); - ImportXlaRegion(then_branch, &conditional.true_branch(), loc); - ImportXlaRegion(else_branch, &conditional.false_branch(), loc); + ImportXlaRegion(then_branch, &if_op.true_branch(), loc); + ImportXlaRegion(else_branch, &if_op.false_branch(), loc); - // De-tuple the results of the xla hlo conditional result. - builder.setInsertionPointAfter(op); - Detuple(conditional.getResult(), op.getResults(), &builder); + // De-tuple the results of the xla hlo if result. + Detuple(if_op.getResult(), op.getResults(), &builder); + op.erase(); +} + +void LowerCase(TF::CaseOp op, ModuleOp module) { + Location loc = op.getLoc(); + OpBuilder builder(op); + + // XLA requires one argument per branch so we create a tuple of inputs to pass + // to each branch. + SmallVector<Value, 4> inputs(op.input()); + auto tuple_input = builder.create<xla_hlo::TupleOp>(loc, inputs); + + // Create replica of input tuple for each branch + SmallVector<Value, 4> n_tuple_inputs(op.branches().size(), tuple_input); + + // Create the new case op with tuple inputs. + auto case_op = builder.create<xla_hlo::CaseOp>( + loc, op.getResultTypes(), op.branch_index(), n_tuple_inputs, + op.branches().size()); + + // Import the regions for all branches. + for (unsigned i = 0; i < op.branches().size(); ++i) { + mlir::FuncOp branch_func = module.lookupSymbol<mlir::FuncOp>( + op.branches()[i].cast<SymbolRefAttr>()); + ImportXlaRegion(branch_func, &case_op.branches()[i], loc, + /*tuple_return=*/false); + } + + op.replaceAllUsesWith(case_op.getResults()); op.erase(); } @@ -146,7 +170,6 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) { Value tuple_input = builder.create<xla_hlo::TupleOp>(loc, inputs); // Create the new while op with tuple inputs. - SmallVector<Value, 3> operands(op.getOperands()); auto while_op = builder.create<xla_hlo::WhileOp>( loc, builder.getTupleType(op.getResultTypes()), tuple_input); @@ -159,7 +182,6 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) { ImportXlaRegion(cond_branch, &while_op.cond(), loc, /*tuple_return=*/false); // De-tuple the results of the xla hlo while. - builder.setInsertionPointAfter(op); Detuple(while_op.getResult(), op.getResults(), &builder); op.erase(); } @@ -168,8 +190,20 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) { void LegalizeTFControlFlow::runOnOperation() { auto module = getOperation(); - module.walk([&](TF::WhileOp op) -> void { LowerWhile(op, module); }); - module.walk([&](TF::IfOp op) -> void { LowerIf(op, module); }); + module.walk([&](Operation* op) { + if (auto while_op = dyn_cast<TF::WhileOp>(op)) { + LowerWhile(while_op, module); + return; + } + if (auto if_op = dyn_cast<TF::IfOp>(op)) { + LowerIf(if_op, module); + return; + } + if (auto case_op = dyn_cast<TF::CaseOp>(op)) { + LowerCase(case_op, module); + return; + } + }); } } // namespace xla_hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 959902692dc..ef5a8356a32 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -73,21 +73,6 @@ def : Pattern< // HLO and XLA doesn't support Assertions. def LowerAssert : Pattern<(TF_AssertOp $condition, $data, $summarize), []>; -//===----------------------------------------------------------------------===// -// Bias op patterns. -//===----------------------------------------------------------------------===// -def BiasAddFeatureDimension : NativeCodeCall< - "getBiasFeatureDimension($_builder, $0, $1)">; - -// $input needs to be a ranked tensor to identify index of the feature -// dimension depending on the data_format 'NHWC' or 'NCHW'. -// TODO(laurenzo): This should be converted to do explicit broadcasting since -// it can generate broadcast dimensions that are not compatible with the simple -// xla_chlo.add broadcast_dims. -def : Pat<(TF_BiasAddOp AnyRankedTensor:$input, $bias, $data_format), - (HLO_AddOp $input, $bias, - (BiasAddFeatureDimension $data_format, $input))>; - //===----------------------------------------------------------------------===// // Binary op patterns. //===----------------------------------------------------------------------===// @@ -114,7 +99,8 @@ foreach fromToBinPair = [[TF_AddOp, HLOClient_BroadcastAddOp], def LowerRightShiftSigned : Pat<(TF_RightShiftOp AnyRankedTensor:$l, AnyRankedTensor:$r), - (HLO_ShiftRightArithmeticOp $l, $r, (BinBroadcastDimensions $l, $r)), + (HLOClient_BroadcastShiftRightArithmeticOp $l, $r, + (BinBroadcastDimensions $l, $r)), [(SignedIntTensor $r)]>; // TODO(hinsu): Lower unsigned types to HLO_ShiftRightLogical once the HLO op @@ -126,10 +112,11 @@ def : Pat<(TF_ComplexOp $r, $i), (HLO_ComplexOp $r, $i)>; // // return floor(div(x, y)) def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), - (HLO_FloorOp (HLO_DivOp $l, $r, (BinBroadcastDimensions $l, $r))), + (HLO_FloorOp + (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))), [(IEEEFloatTensor $l)]>; -// Performs a substitution of FloorDir for integer tensors, which required +// Performs a substitution of FloorDiv for integer tensors, which required // additional correction for a negative numerator / denominator. Equivalent // pseudocode is shown below: // @@ -148,19 +135,19 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), // NOTE: This should be optimized for unsigned integers. // Requires static shaped inputs to create constant splats and computation of // broadcast attributes. -def : Pat<(TF_FloorDivOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r), +def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), (HLO_SelectOp - (HLO_CompareOp - (HLO_CompareOp $l, (HLO_ConstOp (ConstantSplat<"0"> $l)), + (HLOClient_BroadcastCompareOp + (HLOClient_BroadcastCompareOp $l, (HLO_ConstOp (GetScalarOfType<0> $l)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), - (HLO_CompareOp $r, (HLO_ConstOp (ConstantSplat<"0"> $r)), + (HLOClient_BroadcastCompareOp $r, (HLO_ConstOp (GetScalarOfType<0> $r)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), (BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ), - (HLO_DivOp $l, $r, (BinBroadcastDimensions $l, $r)), - (HLO_DivOp - (HLO_NegOp:$neg (HLO_AddOp (HLO_AbsOp $l), - (HLO_SubOp (HLO_AbsOp $r), - (HLO_ConstOp (ConstantSplat<"1"> $r)), + (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)), + (HLOClient_BroadcastDivOp + (HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l), + (HLOClient_BroadcastSubOp (HLO_AbsOp $r), + (HLO_ConstOp (GetScalarOfType<1> $r)), (NullDenseIntElementsAttr)), (BinBroadcastDimensions $l, $r))), (HLO_AbsOp:$abs $r), (BinBroadcastDimensions $neg, $abs))), @@ -173,22 +160,22 @@ def : Pat<(TF_FloorDivOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r), // return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y // Requires static shaped inputs to create constant splats and computation of // broadcast attributes. -def : Pat<(TF_FloorModOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r), +def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r), (HLO_SelectOp - (HLO_AndOp - (HLO_CompareOp - (HLO_RemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), - (HLO_ConstOp:$l_zeros (ConstantSplat<"0"> $l)), + (HLOClient_BroadcastAndOp + (HLOClient_BroadcastCompareOp + (HLOClient_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), + (HLO_ConstOp:$l_zeros (GetScalarOfType<0> $l)), (BinBroadcastDimensions $l, $rem), HLO_COMPARISON_DIRECTION_NE), - (HLO_CompareOp - (HLO_CompareOp:$r_cmp $r, - (HLO_ConstOp:$r_zeros (ConstantSplat<"0"> $r)), + (HLOClient_BroadcastCompareOp + (HLOClient_BroadcastCompareOp:$r_cmp $r, + (HLO_ConstOp:$r_zeros (GetScalarOfType<0> $r)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), - (HLO_CompareOp:$rem_cmp $rem, $r_zeros, + (HLOClient_BroadcastCompareOp:$rem_cmp $rem, $r_zeros, (BinBroadcastDimensions $rem, $r_zeros), HLO_COMPARISON_DIRECTION_LT), (BinBroadcastDimensions $r_cmp, $rem_cmp), HLO_COMPARISON_DIRECTION_NE), (NullDenseIntElementsAttr)), - (HLO_AddOp $r, + (HLOClient_BroadcastAddOp $r, $rem, (BinBroadcastDimensions $r, $rem)), $rem)>; //===----------------------------------------------------------------------===// @@ -406,39 +393,36 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_ (HLO_SelectOp:$num_lower_or_m (HLO_CompareOp $num_lower, (HLO_ConstOp:$zero (ConstantSplat<"0"> $num_lower)), - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT + HLO_COMPARISON_DIRECTION_LT ), $m_dim, $num_lower ), (HLO_SelectOp:$num_upper_or_n (HLO_CompareOp - $num_upper, $zero, - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT + $num_upper, $zero, HLO_COMPARISON_DIRECTION_LT ), $n_dim, $num_upper ), (HLO_SelectOp (HLO_AndOp - (HLO_CompareOp + (HLOClient_BroadcastCompareOp (HLO_NegOp (createConvertOp $op, $num_lower_or_m, $input) ), (HLO_SubOp:$offset - (createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input), - (NullDenseIntElementsAttr) + (createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input) ), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE ), - (HLO_CompareOp + (HLOClient_BroadcastCompareOp $offset, (createConvertOp $op, $num_upper_or_n, $input ), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE - ), - (BinBroadcastDimensions $offset, $input) + ) ), $input, (HLO_ConstOp (ConstantSplat<"0"> $input)) @@ -462,8 +446,9 @@ def : Pat<(TF_ConstOp:$res ElementsAttr:$value), // TODO(hinsu): Lower unsigned and quantized types after supporting // them in GetScalarOfType. def : Pat<(TF_ReluOp AnyRankedTensor:$input), - (HLO_MaxOp (HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input, - (BinBroadcastDimensions $zero, $input)), + (HLOClient_BroadcastMaxOp + (HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input, + (BinBroadcastDimensions $zero, $input)), [(TF_SintOrFpTensor $input)]>; // TODO(hinsu): Lower unsigned and quantized types after supporting @@ -485,7 +470,7 @@ def : Pat<(TF_Relu6Op AnyRankedTensor:$input), // to create splat tensor of dynamic shape in HLO. def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features), (HLO_SelectOp - (HLO_CompareOp $features, + (HLOClient_BroadcastCompareOp $features, (HLO_ConstOp (GetScalarOfType<0> $features)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT), $gradients, (HLO_ConstOp (ConstantSplat<"0"> $gradients)))>; @@ -536,18 +521,6 @@ def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1, &$_builder)">; def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (TF_ConstOp $axis)), (HLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; -//===----------------------------------------------------------------------===// -// Ternary op patterns. -//===----------------------------------------------------------------------===// - -def BothTypesMatch : Constraint<CPred<"$0.getType() == $1.getType()">, - "types must be equal">; - -def : Pat<(TF_SelectOp $cond, $t, $e), (HLO_SelectOp $cond, $t, $e), - // TODO(jpienaar): This restriction is to avoid creating a currently - // unsupported HLO select. - [(BothTypesMatch $t, $e)]>; - //===----------------------------------------------------------------------===// // Unary op patterns. //===----------------------------------------------------------------------===// @@ -598,7 +571,6 @@ def : Pat<(TF_SignOp $x), (HLO_CompareOp $x, $x, - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_NE ), (HLO_ConstOp (ConstantSplat<"0"> $x)), @@ -639,10 +611,10 @@ def : Pat<(srcDstOpPair[0]:$old $shape, $seed, $seed2), //===----------------------------------------------------------------------===// // Sigmoid grad op. //===----------------------------------------------------------------------===// + +// TODO(hinsu): Handle unranked inputs by broadcasting constant one to the +// shape of $l instead of having it as a constant. def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r), (HLO_MulOp - (HLO_MulOp $r, $l, (NullDenseIntElementsAttr)), - (HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l, - (NullDenseIntElementsAttr)), - (NullDenseIntElementsAttr)), - [(IEEEFloatTensor $l)]>; + (HLO_MulOp $r, $l), + (HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l))>; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 76657bd5e20..b15974979c9 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -87,6 +87,7 @@ static bool IsOpWhitelisted(Operation* op) { TypeID::get<TF::AcosOp>(), TypeID::get<TF::AddNOp>(), TypeID::get<TF::AddV2Op>(), + TypeID::get<TF::AngleOp>(), TypeID::get<TF::ApproximateEqualOp>(), TypeID::get<TF::AsinhOp>(), TypeID::get<TF::AsinOp>(), diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td index c0f6c2c3541..21e39db018b 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td @@ -36,47 +36,36 @@ def IsSameSizePred : CPred< def IsSameSizeConstraint : Constraint<IsSameSizePred, "inputs are same size">; -def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r), (AndOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r), (AddFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r), (SubFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r), (MulFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r), (DivFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r), (RemFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r), (AddIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r), (SubIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r), (MulIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r), (SignedDivIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r, - IsNullAttr:$broadcast_dimensions), +def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r), (SignedRemIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc index 43c0911a4a6..ddbb672c70a 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc @@ -57,8 +57,9 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> { for (auto func_arg : func.getArguments()) { func_args.insert(func_arg); } + MLIRContext* ctx = func.getContext(); OpBuilder b(func); - OperationFolder folder(func.getContext()); + OperationFolder folder(ctx); func.walk([&](linalg::GenericOp generic_op) { SmallVector<int64_t, 2> tile_sizes(tile_sizes_.begin(), tile_sizes_.end()); @@ -68,12 +69,14 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> { auto op = cast<LinalgOp>(generic_op.getOperation()); for (const Value result : op.getOutputBuffers()) { if (!func_args.count(result)) continue; - if (tileGenericOp(op, tile_sizes, &b, &folder)) { + if (tileGenericOp(op, tile_sizes, &b)) { generic_op.erase(); return; } } }); + auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx); + applyPatternsAndFoldGreedily(func, patterns); // Fuse producers of tiled linalg ops. llvm::SmallDenseSet<Operation*> erase_set; @@ -92,19 +95,22 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> { *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); } } + + auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx); + applyPatternsAndFoldGreedily(func, patterns); } for (auto* e : erase_set) e->erase(); } private: - bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b, - OperationFolder* folder) { - auto tiled_generic_op = - use_parallel_loops_ - ? linalg::tileLinalgOpToParallelLoops(*b, op, tile_sizes, - /*permutation=*/{}, folder) - : linalg::tileLinalgOp(*b, op, tile_sizes, - /*permutation=*/{}, folder); + bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b) { + auto loopType = use_parallel_loops_ + ? linalg::LinalgTilingLoopType::ParallelLoops + : linalg::LinalgTilingLoopType::Loops; + auto tiled_generic_op = linalg::tileLinalgOp(*b, op, + linalg::LinalgTilingOptions() + .setTileSizes(tile_sizes) + .setLoopType(loopType)); return tiled_generic_op.hasValue(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td b/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td index dcb0ab20e9e..e1ae5ef6abf 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/lower_complex_patterns.td @@ -28,70 +28,62 @@ include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" // and imaginary components. foreach elementwiseOp = [HLO_AddOp, HLO_SubOp] in def : Pat<(elementwiseOp HLO_ComplexTensor:$lhs, - HLO_ComplexTensor:$rhs, $broadcast_dimensions), + HLO_ComplexTensor:$rhs), (HLO_ComplexOp - (elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs), - $broadcast_dimensions), - (elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs), - $broadcast_dimensions))>; + (elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs)), + (elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs)))>; // Complex multiplication results in a cross product multiplication between the // real and imaginary components such that: // result.real = lhs.real * rhs.real - lhs.imag * rhs.imag // result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, - HLO_ComplexTensor:$rhs, $broadcast_dimensions), + HLO_ComplexTensor:$rhs), (HLO_ComplexOp (HLO_SubOp (HLO_MulOp (HLO_RealOp:$lhs_real $lhs), - (HLO_RealOp:$rhs_real $rhs), - $broadcast_dimensions), + (HLO_RealOp:$rhs_real $rhs)), (HLO_MulOp (HLO_ImagOp:$lhs_imag $lhs), - (HLO_ImagOp:$rhs_imag $rhs), - $broadcast_dimensions), - (NullDenseIntElementsAttr)), + (HLO_ImagOp:$rhs_imag $rhs))), (HLO_AddOp - (HLO_MulOp $lhs_real, $rhs_imag, $broadcast_dimensions), - (HLO_MulOp $lhs_imag, $rhs_real, $broadcast_dimensions), - (NullDenseIntElementsAttr)))>; + (HLO_MulOp $lhs_real, $rhs_imag), + (HLO_MulOp $lhs_imag, $rhs_real)))>; // Multiplication between a complex and real tensor can be distributed by // applying the real multiplicant to both the real and complex component. // // Note that the sourcep pattern is not legal according to the HLO dialect but // instead handle intermediates generated by other patterns. -def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), (HLO_ComplexOp - (HLO_MulOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions), - (HLO_MulOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>; + (HLO_MulOp (HLO_RealOp $lhs), $rhs), + (HLO_MulOp (HLO_ImagOp $lhs), $rhs))>; -def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs), (HLO_ComplexOp - (HLO_MulOp $lhs, (HLO_RealOp $rhs), $broadcast_dimensions), - (HLO_MulOp $lhs, (HLO_ImagOp $rhs), $broadcast_dimensions))>; + (HLO_MulOp $lhs, (HLO_RealOp $rhs)), + (HLO_MulOp $lhs, (HLO_ImagOp $rhs)))>; // Division is performed by normalizing the denominator by multiplying by the // conjugate of the rhs. // numerator = lhs * conj(rhs) // denominator = rhs * conj(rhs) -def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs), (HLO_DivOp (HLO_MulOp:$num $lhs, (HLO_ComplexOp:$conj (HLO_RealOp $rhs), - (HLO_NegOp (HLO_ImagOp $rhs))), - $broadcast_dimensions), - (HLO_RealOp:$den (HLO_MulOp $rhs, $conj, $broadcast_dimensions)), - (BinBroadcastDimensions $num, $den))>; + (HLO_NegOp (HLO_ImagOp $rhs)))), + (HLO_RealOp:$den (HLO_MulOp $rhs, $conj)))>; -def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs, $broadcast_dimensions), +def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), (HLO_ComplexOp - (HLO_DivOp (HLO_RealOp $lhs), $rhs, $broadcast_dimensions), - (HLO_DivOp (HLO_ImagOp $lhs), $rhs, $broadcast_dimensions))>; + (HLO_DivOp (HLO_RealOp $lhs), $rhs), + (HLO_DivOp (HLO_ImagOp $lhs), $rhs))>; // Absolute value is evaluated as: @@ -100,11 +92,8 @@ def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val), (HLO_ComplexOp (HLO_SqrtOp (HLO_AddOp - (HLO_MulOp (HLO_RealOp:$real $val), $real, - (NullDenseIntElementsAttr)), - (HLO_MulOp (HLO_ImagOp:$imag $val), $imag, - (NullDenseIntElementsAttr)), - (NullDenseIntElementsAttr))), + (HLO_MulOp (HLO_RealOp:$real $val), $real), + (HLO_MulOp (HLO_ImagOp:$imag $val), $imag))), (HLO_ConstOp (ConstantSplat<"0"> $real)))>; // Exponential can be lowered to an exponential on the real component and a @@ -117,5 +106,4 @@ def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val), (HLO_ExpOp (HLO_RealOp $val)), (HLO_ComplexOp (HLO_CosOp (HLO_ImagOp:$imag $val)), - (HLO_SinOp $imag)), - (NullDenseIntElementsAttr))>; + (HLO_SinOp $imag)))>; diff --git a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h index fed21e9bafc..21b954a3eb4 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h @@ -49,6 +49,7 @@ MAP_HLO_TO_LHLO(ConvertOp); MAP_HLO_TO_LHLO(CopyOp); MAP_HLO_TO_LHLO(CosOp); MAP_HLO_TO_LHLO(DivOp); +MAP_HLO_TO_LHLO(DotOp); MAP_HLO_TO_LHLO(ExpOp); MAP_HLO_TO_LHLO(ImagOp); MAP_HLO_TO_LHLO(IotaOp); diff --git a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc index bf666400900..c56f5adc12d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc +++ b/tensorflow/compiler/mlir/xla/transforms/materialize_broadcasts.cc @@ -28,264 +28,6 @@ namespace xla_hlo { namespace { -// Returns a 1-d i64 elements attribute populated with numbers from start to -// end, excluding. -static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, - Builder *builder) { - int size = end - start; - - SmallVector<int64_t, 4> vals; - vals.resize(size); - std::iota(vals.begin(), vals.end(), start); - - TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, vals); -} - -// Helper function for OpRewritePattern classes to materialize broadcasts on -// LHS and RHS arguments to a binary op. -// -// Returns true and sets out_lhs and out_rhs to BroadcastInDimOps if successful, -// returns false otherwise. -template <typename SrcOp> -bool CreateStaticBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter, - Value *out_lhs, Value *out_rhs) { - // Insert BroadcastInDimOps for the left-hand-side and right-hand-side args, - // replacing the original LHS and RHS args in the source op with the results - // of the broadcasts. - // - // If the higher dimensional argument does not actually need the broadcast, - // a canonicalization pass should be able to remove that op later. - Value lhs = op.lhs(); - Value rhs = op.rhs(); - - auto op_ranked_type = op.getType().template dyn_cast<RankedTensorType>(); - auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>(); - auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>(); - if (!op_ranked_type || !lhs_ranked_type || !rhs_ranked_type) { - // Unranked, can't determine at this point how to perform the broadcast. - return false; - } - - // Dynamic result shape, can't use BroadcastInDimOp. - assert(op_ranked_type.hasStaticShape() && - "dynamic shape requires DynamicBroadcastInDim"); - - auto lhs_rank = lhs_ranked_type.getRank(); - auto rhs_rank = rhs_ranked_type.getRank(); - ArrayRef<int64_t> op_shape = op_ranked_type.getShape(); - - // BroadcastInDimOp must have the same element type for operands and results, - // so preserve the original output shape and the original input element type. - // For example, `SrcOp (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xi1>`: - // broadcast_in_dim (tensor<1x4xf32>) -> tensor<1x4xf32> - // broadcast_in_dim (tensor<4xf32>) -> tensor<1x4xf32> - // SrcOp (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xi1> - if (lhs_ranked_type.getShape() != op_ranked_type.getShape()) { - auto type = - RankedTensorType::get(op_shape, lhs_ranked_type.getElementType()); - DenseIntElementsAttr attr = GetI64ElementsAttrForSeq(0, lhs_rank, rewriter); - if (lhs_rank < rhs_rank) { - attr = op.broadcast_dimensions().getValue(); - } - - lhs = - rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), type, lhs, attr); - } - - if (rhs_ranked_type.getShape() != op_ranked_type.getShape()) { - auto type = - RankedTensorType::get(op_shape, rhs_ranked_type.getElementType()); - DenseIntElementsAttr attr = GetI64ElementsAttrForSeq(0, rhs_rank, rewriter); - if (rhs_rank < lhs_rank) { - attr = op.broadcast_dimensions().getValue(); - } - - rhs = - rewriter->createOrFold<BroadcastInDimOp>(op.getLoc(), type, rhs, attr); - } - - *out_lhs = lhs; - *out_rhs = rhs; - return true; -} - -// Helper template to generate code for computing the result shape of a -// broadcasted operation. This ultimately should be subsumed by functions -// from the shape dialect. -// Assumes that large and small are the operand values of `op` and that they -// have a ranked tensory type with rank(large) >= rank(small). -template <typename SrcOp> -std::vector<Value> ComputeBroadcastedShape(SrcOp op, Value small, Value large, - PatternRewriter *rewriter) { - auto loc = op.getLoc(); - auto larger_ranked_type = large.getType().cast<RankedTensorType>(); - auto output_rank = larger_ranked_type.getRank(); - - constexpr int kExpandShape = -1; - - std::vector<Value> shape_values; - shape_values.reserve(output_rank); - std::vector<int> indexes(output_rank, kExpandShape); - DenseIntElementsAttr broadcast_dimensions = - op.broadcast_dimensions().getValue(); - // Compute a mapping from output dimensions to their corresponding input - // dimensions in the smaller ranked operand. - for (auto pair : llvm::enumerate(broadcast_dimensions.getIntValues())) { - indexes.at(pair.value().getLimitedValue()) = pair.index(); - } - - // Compute the broadcasted shape of the result using numpy style broadcasting - // semantics. The result shape at a position is the shape of the larger - // operand at that position if the no dimension of the smaller operand is - // mapped to it. - // If both operands contribute to an output dimension, their shape has to - // either be the same in that dimension or it can be 1, in which case the - // shape of the other operand is used. - for (int i = 0; i < output_rank; ++i) { - Value index_value; - if (indexes[i] == kExpandShape) { - // The smaller shape gets expanded to the larger one in this case. - index_value = rewriter->create<mlir::DimOp>(loc, large, i); - } else { - // Compute the result shape depending on whether the rank of smaller is 1. - // This does not check that the broadcast operation actualy is correct. - // In particular, we do not check that both shapes are the same if the - // smaller ranked shape is not 1. - ConstantOp one = rewriter->create<mlir::ConstantOp>( - loc, rewriter->getIntegerAttr(rewriter->getIndexType(), 1)); - DimOp lrg_dim = rewriter->create<mlir::DimOp>(loc, large, i); - DimOp sml_dim = rewriter->create<mlir::DimOp>(loc, small, indexes[i]); - CmpIOp compare = - rewriter->create<mlir::CmpIOp>(loc, CmpIPredicate::eq, lrg_dim, one); - index_value = - rewriter->create<mlir::SelectOp>(loc, compare, lrg_dim, sml_dim); - } - // Ideally, we would like to keep this on index but MLIR does not allow - // this. - shape_values.push_back(rewriter->create<mlir::IndexCastOp>( - loc, index_value, rewriter->getIntegerType(32))); - } - - return shape_values; -} - -// Helper function for OpRewritePattern classes to materialize dynamic -// broadcasts on LHS and RHS arguments to a binary op. -// -// Returns true and set out_lhs and out_rhs for materialized dynamic broadcasts -// for LHS and RHS arguments, else returns false. -template <typename SrcOp> -bool CreateDynamicBroadcastsForBinaryOp(SrcOp op, PatternRewriter *rewriter, - Value *out_lhs, Value *out_rhs) { - if (!op.broadcast_dimensions().hasValue()) { - // Note: the op may still have an implicit broadcast on it, such as - // for (tensor<1xf32>, tensor<4xf32>). - return false; - } - - // Insert BroadcastInDimOps for the left-hand-side and right-hand-side args, - // replacing the original LHS and RHS args in the source op with the results - // of the broadcasts. - Value lhs = op.lhs(); - Value rhs = op.rhs(); - - auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>(); - auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>(); - if (!lhs_ranked_type || !rhs_ranked_type) { - // Unranked, can't determine at this point how to perform the broadcast. - return false; - } - - auto lhs_rank = lhs_ranked_type.getRank(); - auto rhs_rank = rhs_ranked_type.getRank(); - - // Set broadcast_dimensions to [0, ..., rank] for the higher rank arg. - // Use the original op.broadcast_dimensions for the lower rank arg. - auto higher_rank_broadcast_dims = - GetI64ElementsAttrForSeq(0, std::max(lhs_rank, rhs_rank), rewriter); - DenseIntElementsAttr lhs_broadcast_dims; - DenseIntElementsAttr rhs_broadcast_dims; - std::vector<Value> shape_elements; - if (lhs_rank > rhs_rank) { - lhs_broadcast_dims = higher_rank_broadcast_dims; - rhs_broadcast_dims = op.broadcast_dimensions().getValue(); - shape_elements = ComputeBroadcastedShape<SrcOp>(op, rhs, lhs, rewriter); - } else if (lhs_rank < rhs_rank) { - lhs_broadcast_dims = op.broadcast_dimensions().getValue(); - rhs_broadcast_dims = higher_rank_broadcast_dims; - shape_elements = ComputeBroadcastedShape<SrcOp>(op, lhs, rhs, rewriter); - } else { - // This shouldn't happen for legal ops. If the broadcast_dimensions - // attribute is set, the ranks should be different. - // TODO(scotttodd): Add a custom verification for ops and assert here. - return false; - } - - // DynamicBroadcastInDimOp preserves the element type but produces a tensor - // with unranked shape. The rank of the output is the length of the - // output shape argument. - SmallVector<int64_t, 4> op_shape(shape_elements.size(), - RankedTensorType::kDynamicSize); - auto lhs_type = - RankedTensorType::get(op_shape, lhs_ranked_type.getElementType()); - auto rhs_type = - RankedTensorType::get(op_shape, rhs_ranked_type.getElementType()); - - // We need a way to turn a list of scalars into a vector. While Standard - // dialect does not have one, use the XLA_HLO variant. - int shape_size = shape_elements.size(); - Type shape_element_type = shape_elements.front().getType(); - Value shape_value = rewriter->create<ScalarsToDimensionTensorOp>( - op.getLoc(), RankedTensorType::get({shape_size}, shape_element_type), - shape_elements); - - *out_lhs = rewriter->createOrFold<DynamicBroadcastInDimOp>( - op.getLoc(), lhs_type, lhs, shape_value, lhs_broadcast_dims); - *out_rhs = rewriter->createOrFold<DynamicBroadcastInDimOp>( - op.getLoc(), rhs_type, rhs, shape_value, rhs_broadcast_dims); - return true; -} - -template <typename SrcOp> -bool CreateBroadcastForBinaryOp(SrcOp op, PatternRewriter *rewriter, - Value *out_lhs, Value *out_rhs) { - auto op_ranked_type = op.getType().template dyn_cast<RankedTensorType>(); - if (!op_ranked_type) return false; - - if (op_ranked_type.hasStaticShape()) { - if (!CreateStaticBroadcastsForBinaryOp(op, rewriter, out_lhs, out_rhs)) { - return false; - } - } else { - if (!CreateDynamicBroadcastsForBinaryOp(op, rewriter, out_lhs, out_rhs)) { - return false; - } - } - return true; -} - -template <typename SrcOp> -struct BinaryOpWithBroadcastConvert : public OpRewritePattern<SrcOp> { - explicit BinaryOpWithBroadcastConvert(MLIRContext *context) - : OpRewritePattern<SrcOp>(context) {} - - LogicalResult matchAndRewrite(SrcOp op, - PatternRewriter &rewriter) const override { - Value new_lhs; - Value new_rhs; - - if (!CreateBroadcastForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) - return failure(); - - // Replace the original op with a new one that uses the new args. - // New args are broadcasts, so no dims are needed on the replacement op. - rewriter.replaceOpWithNewOp<SrcOp>(op, op.getType(), new_lhs, new_rhs, - /*broadcast_dims=*/nullptr); - return success(); - } -}; - // Converts ClampOp with broadcast semantics. ClampOp requires "all three arrays // must be the same shape. Alternatively, as a restricted form of broadcasting, // min and/or max can be a scalar of type T." @@ -327,63 +69,10 @@ struct ClampWithBroadcastConvert : public OpRewritePattern<ClampOp> { } }; -// Specialized class for CompareOp, as it has an additional builder argument. -struct CompareWithBroadcastConvert : public OpRewritePattern<CompareOp> { - explicit CompareWithBroadcastConvert(MLIRContext *context) - : OpRewritePattern<CompareOp>(context) {} - - LogicalResult matchAndRewrite(CompareOp op, - PatternRewriter &rewriter) const override { - Value new_lhs; - Value new_rhs; - - if (!CreateBroadcastForBinaryOp(op, &rewriter, &new_lhs, &new_rhs)) - return failure(); - - rewriter.replaceOpWithNewOp<CompareOp>(op, op.getType(), new_lhs, new_rhs, - /*broadcast_dims=*/nullptr, - op.comparison_direction()); - return success(); - } -}; - } // namespace void SetupMaterializeBroadcastsLegality(MLIRContext *context, ConversionTarget *conversionTarget) { -#define ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OpType) \ - conversionTarget->addDynamicallyLegalOp<OpType>([](OpType op) { \ - if (op.broadcast_dimensions().hasValue()) return false; \ - auto l = op.lhs().getType().cast<ShapedType>(); \ - auto r = op.rhs().getType().cast<ShapedType>(); \ - if (!l.hasRank() || !r.hasRank()) return false; \ - return l.getShape() == r.getShape(); \ - }); - - // Binary elementwise ops. - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AddOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(Atan2Op); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(DivOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MaxOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MinOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(MulOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(PowOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(RemOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftLeftOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftRightArithmeticOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(ShiftRightLogicalOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(SubOp); - - // Binary logical elementwise ops. - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(AndOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(OrOp); - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(XorOp); - - // CompareOp. - ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST(CompareOp); - -#undef ADD_DYNAMICALLY_LEGAL_OP_WITH_BROADCAST - conversionTarget->addDynamicallyLegalOp<ClampOp>([](ClampOp op) { return op.max().getType() == op.operand().getType() && op.min().getType() == op.operand().getType(); @@ -392,30 +81,10 @@ void SetupMaterializeBroadcastsLegality(MLIRContext *context, void PopulateMaterializeBroadcastsPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { - // Binary elementwise ops. - patterns->insert<BinaryOpWithBroadcastConvert<AddOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<Atan2Op>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<DivOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<MaxOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<MinOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<MulOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<PowOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<RemOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<ShiftLeftOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<ShiftRightArithmeticOp>>( - context); - patterns->insert<BinaryOpWithBroadcastConvert<ShiftRightLogicalOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<SubOp>>(context); - - // Binary logical elementwise ops. - patterns->insert<BinaryOpWithBroadcastConvert<AndOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<OrOp>>(context); - patterns->insert<BinaryOpWithBroadcastConvert<XorOp>>(context); - - // ClampOp. It can have a restricted form of broadcasting. + // ClampOp. This op has a special case where it accepts either same-shaped + // inputs or scalars (a restricted form of broadcasting). This makes the + // broadcast explicit. patterns->insert<ClampWithBroadcastConvert>(context); - // CompareOp. Note the specialized class instead of using the template. - patterns->insert<CompareWithBroadcastConvert>(context); } } // namespace xla_hlo diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 39375e210d5..a1dd6c5ce1e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -36,7 +36,7 @@ namespace xla_hlo { /// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is /// false, emits an error if there is any operation that can't be legalized. std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass( - bool allow_partial_conversion = false); + bool allow_partial_conversion = false, bool legalize_chlo = true); /// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the /// specified device type. @@ -50,7 +50,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeTFControlFlowPass(); /// dialect using the conversion patterns registered by the HLO dialect. When /// allow_partial_conversion is false, emits an error if there is any operation /// that can't be legalized. -LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false); +LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false, + bool legalize_chlo = true); /// Lowers HLO control flow ops to the Standard dialect. std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass(); @@ -65,6 +66,10 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(); // Lowers from HLO dialect to Linalg dialect. std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass(); +// Sinks constants implicitly captured in control flow regions. This is +// necessary to export to XLA. +std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass(); + } // namespace xla_hlo namespace xla_lhlo { diff --git a/tensorflow/compiler/mlir/xla/transforms/sink_constants_to_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/sink_constants_to_control_flow.cc new file mode 100644 index 00000000000..5a45e0f3b18 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/sink_constants_to_control_flow.cc @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace xla_hlo { + +namespace { + +// A pass that sinks constants implicitly captured in control flow regions. This +// is necessary to export to XLA. +class SinkConstantsToControlFlow + : public mlir::PassWrapper<SinkConstantsToControlFlow, FunctionPass> { + void runOnFunction() override { + getFunction().walk([](Operation* op) { + if (auto while_op = llvm::dyn_cast<WhileOp>(op)) { + SinkToRegion(&while_op.body()); + SinkToRegion(&while_op.cond()); + } else if (auto if_op = llvm::dyn_cast<IfOp>(op)) { + SinkToRegion(&if_op.true_branch()); + SinkToRegion(&if_op.false_branch()); + } + }); + } + + private: + // Performs constant sinking into a region. + static void SinkToRegion(Region* region) { + llvm::DenseMap<Value, ConstOp> sunk_constant; + visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) { + Value constant = use->get(); + auto const_op = dyn_cast_or_null<ConstOp>(constant.getDefiningOp()); + if (!const_op) return; + auto map_entry = sunk_constant.try_emplace(constant, nullptr); + if (!map_entry.second) { + // This constant has already been cloned into the region, reuse it. + use->set(map_entry.first->getSecond().getResult()); + if (constant.use_empty()) const_op.erase(); + return; + } + if (constant.hasOneUse()) { + const_op.getOperation()->moveBefore(®ion->front().front()); + return; + } + map_entry.first->getSecond() = const_op.clone(); + region->front().getOperations().insert(region->front().begin(), + map_entry.first->getSecond()); + use->set(map_entry.first->getSecond().getResult()); + }); + } +}; + +static mlir::PassRegistration<SinkConstantsToControlFlow> pass( + "xla-hlo-sink-constants-to-control-flow", + "Sink constants implicitly captured in control flow regions. This is " + "necessary to export to XLA."); + +} // anonymous namespace + +std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() { + return std::make_unique<SinkConstantsToControlFlow>(); +} + +} // namespace xla_hlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc index 32d8b079c89..98eb404e4d4 100644 --- a/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc +++ b/tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc @@ -58,9 +58,7 @@ Value CalculateShapeValue(Location loc, Value operand, int64_t rank = result_type.getRank(); shape_values.reserve(rank); for (int64_t i = 0; i < rank; ++i) { - auto index_value = rewriter.create<mlir::DimOp>(loc, operand, i); - shape_values.push_back(rewriter.create<mlir::IndexCastOp>( - loc, index_value, rewriter.getIntegerType(32))); + shape_values.push_back(rewriter.create<mlir::DimOp>(loc, operand, i)); } Type shape_element_type = shape_values.front().getType(); return rewriter.create<ScalarsToDimensionTensorOp>( @@ -137,8 +135,8 @@ class UnfuseBatchNormInferencePattern if (!epsilon) { return failure(); } - Value stddev = rewriter.create<xla_hlo::AddOp>( - bn_op.getLoc(), bn_op.variance(), epsilon, /*broadcast_dims=*/nullptr); + Value stddev = rewriter.create<xla_hlo::AddOp>(bn_op.getLoc(), + bn_op.variance(), epsilon); stddev = rewriter.create<xla_hlo::SqrtOp>(bn_op.getLoc(), stddev); // Broadcast all terms. @@ -162,13 +160,13 @@ class UnfuseBatchNormInferencePattern // Compute: // scale * (input - mean) / stddev + offset Value result = rewriter.create<xla_hlo::SubOp>( - bn_op.getLoc(), bn_op.operand(), broadcast_mean, nullptr); + bn_op.getLoc(), bn_op.operand(), broadcast_mean); result = rewriter.create<xla_hlo::MulOp>(bn_op.getLoc(), result, - broadcast_scale, nullptr); + broadcast_scale); result = rewriter.create<xla_hlo::DivOp>(bn_op.getLoc(), result, - broadcast_stddev, nullptr); - rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(bn_op, result, broadcast_offset, - nullptr); + broadcast_stddev); + rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(bn_op, result, + broadcast_offset); return success(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index 799a20aa693..2b496677d62 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -573,6 +573,34 @@ class ConstConverter : public OpConversionPattern<xla_lhlo::ConstOp> { } }; +// TODO(b/156787842): Support the lowering for dynamic shapes. +template <typename OpTy, bool isLHLO = true> +class ReverseConverter + : public DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy, + isLHLO> { + public: + using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy, + isLHLO>::DataMovementOpConverter; + static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) { + auto resultType = + getXLAOpResultType<isLHLO>(op).template cast<ShapedType>(); + auto nloops = resultType.getRank(); + SmallVector<AffineExpr, 2> inputExprs; + inputExprs.reserve(nloops); + for (int i = 0; i < nloops; ++i) + inputExprs.push_back(b->getAffineDimExpr(i)); + for (auto dim : op.dimensions()) { + int i = dim.getZExtValue(); + if (resultType.isDynamicDim(i)) return {}; + int n = resultType.getShape()[i]; + inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i]; + } + return b->getAffineMapArrayAttr( + {AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), + b->getMultiDimIdentityMap(nloops)}); + } +}; + class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> { public: using OpConversionPattern<xla_lhlo::SliceOp>::OpConversionPattern; @@ -642,6 +670,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter<xla_lhlo::SubOp>, PointwiseToLinalgConverter<xla_lhlo::TanhOp>, ReshapeAddRemoveDimConverter<xla_lhlo::ReshapeOp>, + ReverseConverter<xla_lhlo::ReverseOp>, ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>, SliceConverter >(context); @@ -742,6 +771,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter<xla_hlo::TanhOp, false>, ReshapeAddRemoveDimConverter<xla_hlo::ReshapeOp, false>, ReshapeOpConverter<xla_hlo::ReshapeOp, false>, + ReverseConverter<xla_hlo::ReverseOp, false>, TransposeConverter<xla_hlo::TransposeOp, false>>(context); } diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 00ed6d83e2e..c7be2c55de7 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -1579,8 +1579,6 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([4, 5, 6], dtype=np.int32), expected=None) - @test_util.disable_mlir_bridge( - "Requires BroadcastInDim method in MlirHloBuilder") def testBroadcastTo(self): for dtype in self.all_types: x = np.random.randint(0, high=100, size=[2, 3]) @@ -1591,29 +1589,16 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=x) self._testBinary( array_ops.broadcast_to, - x, - np.array([6, 6], dtype=np.int32), - expected=np.tile(x, [3, 2])) + np.zeros([2, 3], dtype=dtype), + np.array([2, 2, 3], dtype=np.int32), + expected=np.zeros([2, 2, 3], dtype=dtype)) + + x = np.arange(2).reshape((2, 1)).astype(dtype) self._testBinary( array_ops.broadcast_to, x, - np.array([7, 4, 3], dtype=np.int32), - expected=np.tile(x, [7, 2, 1])) - self._testBinary( - array_ops.broadcast_to, - x, - np.array([7, 0, 3], dtype=np.int32), - expected=np.zeros([7, 0, 3], dtype=dtype)) - self._testBinary( - array_ops.broadcast_to, - x, - np.array([7, 1, 2, 9], dtype=np.int32), - expected=np.tile(x, [7, 1, 1, 3])) - self._testBinary( - array_ops.broadcast_to, - np.zeros([2, 0], dtype=dtype), - np.array([4, 0], dtype=np.int32), - expected=np.zeros([4, 0], dtype=dtype)) + np.array([2, 2, 3], dtype=np.int32), + expected=np.tile(x, (2, 1, 3))) x = np.arange(3).reshape((3, 1, 1, 1)).astype(dtype) self._testBinary( diff --git a/tensorflow/compiler/tests/data_format_ops_test.py b/tensorflow/compiler/tests/data_format_ops_test.py index 681c1f3499e..08d44256b50 100644 --- a/tensorflow/compiler/tests/data_format_ops_test.py +++ b/tensorflow/compiler/tests/data_format_ops_test.py @@ -81,11 +81,21 @@ class XlaPermuteOpTest(xla_test.XLATestCase): x = np.array([7, 4, 9, 3], dtype=dtype) self._runPermuteAndCompare(x, "NHWC", "NCHW", [7, 3, 4, 9]) + def testNHWCToNCHW_Size2(self): + for dtype in {np.int32, np.int64}: + x = np.array([4, 9], dtype=dtype) + self._runPermuteAndCompare(x, "NHWC", "NCHW", [4, 9]) + def testNCHWToNHWC(self): for dtype in {np.int32, np.int64}: x = np.array([7, 4, 9, 3], dtype=dtype) self._runPermuteAndCompare(x, "NCHW", "NHWC", [7, 9, 3, 4]) + def testNCHWToNHWC_Size2(self): + for dtype in {np.int32, np.int64}: + x = np.array([9, 3], dtype=dtype) + self._runPermuteAndCompare(x, "NCHW", "NHWC", [9, 3]) + def testNHWCToHWNC(self): for dtype in {np.int32, np.int64}: x = np.array([7, 4, 9, 3], dtype=dtype) diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index a1bb64eb88d..7bbfecff403 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -77,7 +77,6 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): np.int32(2), expected=np.array([1, 3, 5], dtype=np.int32)) - @test_util.disable_mlir_bridge('TODO(b/155949336)') def testSelect(self): for dtype in self.numeric_types: self._testTernary( diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 3e36f67615b..d0e928a5ce6 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -601,7 +601,6 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([-1, -0.5, 0, 0.3], dtype=dtype), expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype)) - @test_util.disable_mlir_bridge("TODO(b/156135423): Fix ConvertSigmoidOp") def testComplexOps(self): for dtype in self.complex_types: diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index 806d930b76f..aed422a5627 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -41,14 +41,11 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/devices.h" -#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/meta_optimizer.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/config.pb.h" // NOLINT #include "tensorflow/core/protobuf/device_properties.pb.h" // NOLINT #include "tensorflow/core/protobuf/rewriter_config.pb.h" // NOLINT @@ -90,8 +87,6 @@ bool AllowDynamicNonBatchDimension(const ConversionParams& params) { GetEngineType(params) == EngineInfo::EngineType::TRTDynamic; } -} // namespace - struct EdgePtrCompare { bool operator()(const Edge* lhs, const Edge* rhs) const { return lhs->id() < rhs->id(); @@ -555,6 +550,13 @@ Status CreateTRTNode(const ConversionParams& params, return Status::OK(); } +int64 GetNextGraphSequenceNumber() { + static std::atomic<int64> graph_sequence_num; + return graph_sequence_num++; +} + +} // namespace + Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def, Graph* graph, const string& engine_name) { Graph segment_graph(graph->flib_def()); @@ -629,11 +631,6 @@ std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params, return std::make_pair(cuda_device_id, dev_allocator); } -int64 GetNextGraphSequenceNumber() { - static std::atomic<int64> graph_sequence_num; - return graph_sequence_num++; -} - // Entry function from optimization pass. Status ConvertAfterShapes(const ConversionParams& params) { // Sanity checks. @@ -643,12 +640,15 @@ Status ConvertAfterShapes(const ConversionParams& params) { "Calibration with FP32 or FP16 is not supported."); } + grappler::GraphProperties static_graph_properties(*params.grappler_item); + TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); + + const GraphDef& graph_def = params.grappler_item->graph; // Convert graphdef to graph. - FunctionLibraryDefinition flib(OpRegistry::Global(), - params.input_graph_def->library()); + FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library()); Graph graph(flib); - TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), - *params.input_graph_def, &graph)); + TF_RETURN_IF_ERROR( + ConvertGraphDefToGraph(GraphConstructorOptions(), graph_def, &graph)); // Segment the graph into subgraphs that can be converted to TensorRT segment::SegmentOptions segment_options; @@ -662,10 +662,10 @@ Status ConvertAfterShapes(const ConversionParams& params) { AllowDynamicNonBatchDimension(params); segment::SegmentNodesVector initial_segments; - TrtNodeValidator validator(*params.graph_properties, params.precision_mode, + TrtNodeValidator validator(static_graph_properties, params.precision_mode, params.use_calibration, params.use_implicit_batch); TF_RETURN_IF_ERROR(segment::SegmentGraph( - &graph, params.graph_properties, + &graph, &static_graph_properties, std::bind(&TrtNodeValidator::IsTensorRTCandidate, &validator, std::placeholders::_1), // Input validation is already done by TrtNodeValidator, so we don't @@ -693,9 +693,8 @@ Status ConvertAfterShapes(const ConversionParams& params) { auto& curr_segment = initial_segments.at(t); EngineInfo curr_engine; curr_engine.engine_name = StrCat(engine_name_prefix, t); - Status status = - GetEngineInfo(&graph, *params.graph_properties, curr_segment, node_map, - reverse_topo_order, &curr_engine); + Status status = GetEngineInfo(&graph, static_graph_properties, curr_segment, + node_map, reverse_topo_order, &curr_engine); if (!status.ok()) { LOG(WARNING) << "Failed to get engine info for segment " << t << ": " << status; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h index 2bfaa2a786c..53ab84a6fa9 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -18,10 +18,9 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" -#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" -#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -33,7 +32,7 @@ namespace tensorrt { namespace convert { struct ConversionParams { - const GraphDef* input_graph_def = nullptr; + const grappler::GrapplerItem* grappler_item = nullptr; const std::vector<string>* output_names = nullptr; string trt_logger_name; size_t max_batch_size = 1; @@ -41,7 +40,6 @@ struct ConversionParams { GraphDef* output_graph_def = nullptr; TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32; int minimum_segment_size = 3; - const grappler::GraphProperties* graph_properties = nullptr; const grappler::Cluster* cluster = nullptr; // Whether to create engine on conversion or execution time bool is_dyn_op = false; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc index 2cfefd27a67..a1f523d6bfa 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc @@ -162,12 +162,11 @@ class ConvertAfterShapesTest : public ::testing::Test { // Construct ConversionParams. const std::vector<string> output_names{"output"}; ConversionParams params; - params.input_graph_def = &item.graph; params.output_names = &output_names; params.max_workspace_size_bytes = 8 << 20; params.output_graph_def = output_graph_def; params.minimum_segment_size = 1; - params.graph_properties = &graph_properties; + params.grappler_item = &item; params.use_calibration = false; params.trt_logger_name = "DefaultLogger"; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index a43b16e9e6a..132c4d6dd68 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" @@ -795,6 +796,19 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const { } } +Status TRT_TensorOrWeights::GetTfType(DataType* tf_type) const { + if (is_tensor()) { + nvinfer1::DataType trt_type = tensor()->getType(); + return TrtTypeToTfType(trt_type, tf_type); + } + + if (is_weights()) { + *tf_type = weights().GetTensor().dtype(); + return Status::OK(); + } + return errors::Internal("The object is probably not initialized"); +} + string TRT_TensorOrWeights::DebugString() const { string output = "TRT_TensorOrWeights(type="; if (is_tensor()) { @@ -1900,27 +1914,48 @@ Status CheckInputsWeights( return Status::OK(); } -Status AllowDataTypes(const OpConverterParams& params, - const std::set<DataType>& allowed_dtypes, - const char* dtype_attr_name = "T") { - const auto& node_def = params.node_def; +Status GetNodeDefTfType(const NodeDef& node_def, DataType* tf_type, + const char* type_attr_name) { TFAttrs attrs(node_def); - if (!attrs.count(dtype_attr_name)) { - return errors::InvalidArgument("Attribute with name ", dtype_attr_name, + if (!attrs.count(type_attr_name)) { + return errors::InvalidArgument("Attribute with name ", type_attr_name, " not found."); } - const auto op_dtype = attrs.get<DataType>(dtype_attr_name); - if (!allowed_dtypes.count(op_dtype)) { - // Build string list of allowed types. - std::ostringstream ss; - for (auto it = allowed_dtypes.begin(); it != allowed_dtypes.end(); ++it) { - if (it != allowed_dtypes.begin()) ss << ", "; - ss << DataTypeString(*it); - } - return errors::Unimplemented("Data type ", DataTypeString(op_dtype), + *tf_type = attrs.get<DataType>(type_attr_name); + return Status::OK(); +} + +Status GetInputTfType(const OpConverterParams& params, DataType* tf_type, + int pos) { + const std::vector<TRT_TensorOrWeights>& inputs = params.inputs; + if (inputs.size() <= pos) { + return errors::Internal("Invalid input position"); + } + + return inputs[pos].GetTfType(tf_type); +} + +constexpr const char kOutputTypeAttrName[] = "T"; + +Status GetOutputTfType(const OpConverterParams& params, DataType* tf_type) { + return GetNodeDefTfType(params.node_def, tf_type, kOutputTypeAttrName); +} + +Status AllowDataTypes(const OpConverterParams& params, + const std::set<DataType>& allowed_types, + const char* type_attr_name = kOutputTypeAttrName) { + const auto& node_def = params.node_def; + DataType tf_type; + TF_RETURN_IF_ERROR(GetNodeDefTfType(node_def, &tf_type, type_attr_name)); + if (!allowed_types.count(tf_type)) { + string allowed_types_string = absl::StrJoin( + allowed_types, ", ", [](string* out, const DataType& type) { + absl::StrAppendFormat(out, "%s", DataTypeString(type)); + }); + return errors::Unimplemented("Data type ", DataTypeString(tf_type), " is not supported for ", node_def.op(), - ", must be one of [", ss.str(), "], at ", - node_def.name()); + ", must be one of [", allowed_types_string, + "], at ", node_def.name()); } return Status::OK(); } @@ -2111,6 +2146,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, "Stride must be 1 for batch and channel dimensions, at ", node_def.name()); } + // Channel dim must be static for DepthwiseConv2dNative since we use that + // value for num_groups at build time. + if (!params->use_implicit_batch && tensor->getDimensions().d[c_index] == -1) { + return errors::InvalidArgument("Channel dimension must be static, at ", + node_def.name()); + } const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]); if (params->validation_only) return Status::OK(); @@ -2122,11 +2163,12 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group, } // Dimensions of transposed tensor. const auto tensor_dim = tensor->getDimensions(); + const int c_dim_size = tensor_dim.d[params->use_implicit_batch ? 0 : 1]; // group == 0 signifies that this is a depthwise convolution, so set // num_groups to size of input's channel dim. For a non-depthwise conv, // num_groups will be 1. - const int num_groups = (group == 0) ? tensor_dim.d[0] : group; + const int num_groups = (group == 0) ? c_dim_size : group; // For conv, TF weights are RSCK, and TRT expects KCRS. // For backprop, TF weights are RSKC, and TRT expects CKRS. @@ -4598,6 +4640,42 @@ Status ConvertUnpack(OpConverterParams* params) { return ConvertSplitHelper(params, inputs.at(0), tf_axis, num, true); } +// Supports cast fp16=>fp32 through IIdentityLayer. +Status ConvertCast(OpConverterParams* params) { + const NodeDef& node_def = params->node_def; + TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}})); + auto unsupport_cast_error = [&]() { + return errors::Unimplemented("Cast op: ", node_def.op(), + " not supported at: ", node_def.name()); + }; + + DataType input_type; + TF_RETURN_IF_ERROR(GetInputTfType(*params, &input_type, 0)); + if (input_type != DataType::DT_HALF) { + return unsupport_cast_error(); + } + + DataType output_type; + TF_RETURN_IF_ERROR(GetOutputTfType(*params, &output_type)); + if (output_type != DataType::DT_FLOAT) { + return unsupport_cast_error(); + } + + if (params->validation_only) return Status::OK(); + + nvinfer1::ITensor* input = params->inputs.at(0).tensor(); + nvinfer1::IIdentityLayer* layer = + params->converter->network()->addIdentity(*input); + layer->setPrecision(nvinfer1::DataType::kFLOAT); + + if (layer->getOutput(0)->getType() != nvinfer1::DataType::kFLOAT) { + return errors::Internal("IIdentityLayer doesn't work as expected"); + } + + params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0))); + return Status::OK(); +} + Status ConvertConcat(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; @@ -5675,6 +5753,7 @@ static void RegisterValidatableOpConverters( (*registration)["CombinedNonMaxSuppression"] = ConvertCombinedNMS; #endif (*registration)["AddN"] = ConvertAddN; + (*registration)["Cast"] = ConvertCast; (*registration)["ConcatV2"] = ConvertConcat; (*registration)["Const"] = ConvertConst; (*registration)["Conv2D"] = ConvertConv2D; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 2092aecd657..2fe8eec9675 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -294,6 +294,8 @@ class TRT_TensorOrWeights { nvinfer1::Dims GetTrtDims() const; + Status GetTfType(DataType* tf_type) const; + int batch_size() const { return batch_size_; } string DebugString() const; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 884ed7a5771..57b2e13fad0 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -67,9 +67,7 @@ namespace convert { using absl::StrCat; using ::testing::ElementsAre; using ::testing::ElementsAreArray; -using ::testing::FloatNear; using ::testing::Matcher; -using ::testing::NanSensitiveFloatNear; // TensorRT modes for testing. We define the following three modes: // 1. Implicit batch mode: The tensors have static (known) input shape and the @@ -137,30 +135,18 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) { return os; } -nvinfer1::DataType TfDataTypeToTrt(DataType tf_dtype) { - switch (tf_dtype) { - case DT_FLOAT: - return nvinfer1::DataType::kFLOAT; - case DT_HALF: - return nvinfer1::DataType::kHALF; - case DT_INT32: - return nvinfer1::DataType::kINT32; - default: - QCHECK(false) << "Unexpected data type " << DataTypeString(tf_dtype); - } +nvinfer1::DataType TfDataTypeToTrt(DataType tf_type) { + nvinfer1::DataType trt_type; + Status status = TfTypeToTrtType(tf_type, &trt_type); + EXPECT_EQ(status, Status::OK()); + return trt_type; } -DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) { - switch (trt_dtype) { - case nvinfer1::DataType::kFLOAT: - return DT_FLOAT; - case nvinfer1::DataType::kHALF: - return DT_HALF; - case nvinfer1::DataType::kINT32: - return DT_INT32; - default: - QCHECK(false) << "Unexpected data type " << static_cast<int>(trt_dtype); - } +DataType TrtDataTypeToTf(nvinfer1::DataType trt_type) { + DataType tf_type; + Status status = TrtTypeToTfType(trt_type, &tf_type); + EXPECT_EQ(status, Status::OK()); + return tf_type; } NodeDef MakeNodeDef(const string& name, const string& op, @@ -225,9 +211,12 @@ Matcher<std::vector<float>> ArrayFloatNear(const std::vector<float>& values, matchers.reserve(values.size()); for (const float& v : values) { if (nan_sensitive) { - matchers.emplace_back(NanSensitiveFloatNear(v, max_abs_error)); + matchers.emplace_back(::testing::NanSensitiveFloatNear(v, max_abs_error)); + } else if (max_abs_error == 0) { + matchers.emplace_back(::testing::FloatEq(v)); } else { - matchers.emplace_back(FloatNear(v, max_abs_error)); + EXPECT_GE(max_abs_error, 0); + matchers.emplace_back(::testing::FloatNear(v, max_abs_error)); } } return ElementsAreArray(matchers); @@ -310,7 +299,8 @@ struct StaticCaster { }; template <typename InCType, typename OutCType> -std::vector<OutCType> CastTestVector(const std::vector<InCType>& vals) { +std::vector<OutCType> CastTestVector( + const gtl::ArraySlice<InCType>& vals) { // non-absl ok std::vector<OutCType> res(vals.size()); std::transform(vals.begin(), vals.end(), res.begin(), StaticCaster<InCType, OutCType>()); @@ -1300,6 +1290,21 @@ inline absl::Span<const T> GetSpanForData(const InputOutputData& data) { return absl::Span<const T>(tensor_map.data(), tensor_map.size()); } +std::vector<float> GetDataAsFloat(InputOutputData& data) { + if (data.tensor.dtype() == DT_FLOAT) { + auto span = GetSpanForData<float>(data); + return std::vector<float>(span.begin(), span.end()); + } + if (data.tensor.dtype() == DT_HALF) { + return CastTestVector<Eigen::half, float>( + GetSpanForData<Eigen::half>(data)); + } + if (data.tensor.dtype() == DT_INT32) { + return CastTestVector<int32, float>(GetSpanForData<int32>(data)); + } + LOG(FATAL) << "DataType not supported for testing " + << DataTypeString(data.tensor.dtype()); +} // Class to test various op converters, using both a TrtNodeValidator and // Converter. class OpConverterTest : public ::testing::Test { @@ -1353,6 +1358,33 @@ class OpConverterTest : public ::testing::Test { return ret; } + // Constructs a tensor with given values (vals). The tensor type is defined by + // the tf_dtype argument, its shape is given by input_dims. The tensor is + // constructed using the allocator of OpConverterTest in Unified Memory. + template <typename T> + Tensor AsTensor(std::vector<T> vals, const std::vector<int> input_dims, + DataType tf_dtype) { + Tensor ret(allocator_.get(), tf_dtype, {static_cast<int64>(vals.size())}); + if (tf_dtype == DT_FLOAT) { + auto conv_vals = CastTestVector<T, float>(vals); + std::copy_n(conv_vals.data(), conv_vals.size(), ret.flat<float>().data()); + } else if (tf_dtype == DT_HALF) { + auto conv_vals = CastTestVector<T, Eigen::half>(vals); + std::copy_n(conv_vals.data(), conv_vals.size(), + ret.flat<Eigen::half>().data()); + } else if (tf_dtype == DT_INT32) { + auto conv_vals = CastTestVector<T, int32>(vals); + std::copy_n(conv_vals.data(), conv_vals.size(), ret.flat<int32>().data()); + } else { + LOG(FATAL) << "Cannot create tensor with type " + << DataTypeString(tf_dtype); + } + TensorShape shape; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(input_dims, &shape)); + CHECK(ret.CopyFrom(ret, shape)); + return ret; + } + // Constructs a flat tensor in Unified Memory. template <typename T> Tensor ConstructTensor(int data_size, const T& value = T()) { @@ -1360,6 +1392,13 @@ class OpConverterTest : public ::testing::Test { return AsTensor<T>(values); } + // Constructs a flat tensor in Unified Memory. + template <typename T> + Tensor ConstructTensor(int data_size, const T& value, DataType tf_dtype) { + std::vector<T> values(data_size, value); + return AsTensor<T>(values, {data_size}, tf_dtype); + } + void CheckDataTypeMatches(const DataVec& datas) { for (const auto& data : datas) { const int input_index = engine_->getBindingIndex(data.name.c_str()); @@ -1373,27 +1412,29 @@ class OpConverterTest : public ::testing::Test { } } - void BuildAndRun(const DataVec& input_data, DataVec* output_data, - const int batch_size = 1) { + Status BuildAndRun(const DataVec& input_data, DataVec* output_data, + const int batch_size = 1) { // Mark the output tensor as TRT engine output. std::vector<Converter::EngineOutputInfo> output_info; for (const auto& data : *output_data) { output_info.push_back( {data.name, data.name, TfDataTypeToTrt(data.tensor.dtype())}); } - TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info)); + TF_RETURN_IF_ERROR(converter_->RenameAndMarkOutputTensors(output_info)); // Build the TRT engine. - ASSERT_EQ(nullptr, engine_.get()); + if (engine_.get() != nullptr) { + return errors::Internal("Engine already exists"); + } TrtShapeOptimizationProfile profiles; if (!converter_->use_implicit_batch()) { // Create a single optimization profile for explicit batch mode std::vector<TensorShape> input_shapes; - TF_ASSERT_OK(GetShapeFromDataVec(input_data, &input_shapes)); + TF_RETURN_IF_ERROR(GetShapeFromDataVec(input_data, &input_shapes)); profiles.AddShape(input_shapes); profiles.InitProfiles(); } - TF_ASSERT_OK( + TF_RETURN_IF_ERROR( converter_->BuildCudaEngine(&engine_, /*max_batch_size=*/batch_size, /*max_workspace_size_bytes=*/1 << 26, @@ -1407,7 +1448,9 @@ class OpConverterTest : public ::testing::Test { const int num_bindings = input_data.size() + output_data->size(); std::vector<void*> buffers(num_bindings); - ASSERT_EQ(engine_->getNbBindings(), num_bindings); + if (engine_->getNbBindings() != num_bindings) { + return errors::Internal("Number of bindings do not match"); + } // Since we have only 1 optimization profile (which is enabled by default) // it is fine to create execution context directly, instead of calling // profiles.CreateExecutionContexts() @@ -1415,19 +1458,19 @@ class OpConverterTest : public ::testing::Test { engine_->createExecutionContext()); // Prepare input bindings. - TF_ASSERT_OK(SetTrtEngineInputs(engine_.get(), execution_context.get(), 0, - buffers, converter_->use_implicit_batch(), - batch_size, nullptr, &input_data)); - + TF_RETURN_IF_ERROR(SetTrtEngineInputs( + engine_.get(), execution_context.get(), 0, buffers, + converter_->use_implicit_batch(), batch_size, nullptr, &input_data)); // Prepare output bindings. - TF_ASSERT_OK(SetTrtEngineOutputs(engine_.get(), execution_context.get(), 0, - buffers, converter_->use_implicit_batch(), - batch_size, nullptr, output_data)); - + TF_RETURN_IF_ERROR(SetTrtEngineOutputs( + engine_.get(), execution_context.get(), 0, buffers, + converter_->use_implicit_batch(), batch_size, nullptr, output_data)); // Execute the TRT engine. - TF_ASSERT_OK(TrtEnqueue(execution_context.get(), buffers, stream_, - converter_->use_implicit_batch(), batch_size)); + TF_RETURN_IF_ERROR(TrtEnqueue(execution_context.get(), buffers, stream_, + converter_->use_implicit_batch(), + batch_size)); cudaStreamSynchronize(stream_); + return Status::OK(); } bool HasStaticShape(const nvinfer1::Dims& dims) const { @@ -1444,7 +1487,7 @@ class OpConverterTest : public ::testing::Test { // Adds ITensor for both validation and conversion, assuming explicit batch // dimension is included in dims (ie for an NCHW tensor dims = {N, C, H, W}). - void AddTestTensorWithExplicitBatchDim( + void AddTestTensorWithTFDims( const string& name, const std::vector<int32>& dims, nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) { DataType tf_dtype = TrtDataTypeToTf(trt_dtype); @@ -1464,54 +1507,19 @@ class OpConverterTest : public ::testing::Test { } } - // Adds ITensor for both validation and conversion. The tensor can have - // partial input shape. This function defines static or dynamic shape input - // tensor for the network based on the trt_mode attribute. This is done - // automatically, unless the user overrides it with an explicit - // partial_input_shape_dims argument. - // - // Parameters: - // - dims actual dimensions of the tensor that we will use during the test - // (including explicit batch dim). This is not used if partial_input_shape - // is defined. - // - partial_input_shape dimensions which can incude unknown shapes. This can - // be empty, in that case the partial_input_shape will be set automatically - // depending on the trt_mode argument. (This also includse explicit batch - // dim). - // - // On return skip_test is false if trt_mode is not compatible with the - // partial input shape. - void AddTestTensor( - const string& name, const std::vector<int32>& dims, - nvinfer1::DataType trt_dtype, TrtTestMode trt_mode, - const std::vector<int32>* partial_input_shape_dims = nullptr) { - std::vector<int32> partial_shape; - if (partial_input_shape_dims && !partial_input_shape_dims->empty()) { - partial_shape = *partial_input_shape_dims; - } else { - if (trt_mode == TrtTestMode::kDynamicShape) { - // In dynamic shape mode we set the all dims unknown. - partial_shape = std::vector<int32>(dims.size(), -1); - } else { - // Use static (known) input shapes. - partial_shape = dims; - } - } - AddTestTensorWithExplicitBatchDim(name, partial_shape, trt_dtype); - } - // Adds ITensor for both validation and conversion. The difference compared to - // AddTestTensorWithExplicitBatchDim is in the meaning of the dims parameter. - // To define a tensor with NCHW shape, here we set dims = {C,H,W} and - // batch_size = N. TODO(tfeher) remove this function once all test are updated - // to use the other version of AddTestTensor which has the trt_mode arg. + // AddTestTensorWithTFDims is in the meaning of the dims parameter. To define + // a tensor with NCHW shape, here we set dims = {C,H,W} and batch_size = N. + // TODO(tfeher) remove this function once all test are updated to use the + // other version of AddTestTensor (defined by + // ParameterizedOpConverterTestBase). void AddTestTensor( const string& name, const std::vector<int32>& dims, int batch_size = 1, nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) { std::vector<int32> dims_with_batch(dims.size() + 1); dims_with_batch[0] = batch_size; std::copy(dims.begin(), dims.end(), dims_with_batch.begin() + 1); - AddTestTensorWithExplicitBatchDim(name, dims_with_batch, trt_dtype); + AddTestTensorWithTFDims(name, dims_with_batch, trt_dtype); if (HasStaticShape(dims)) { ASSERT_EQ(batch_size, converter_->batch_size_); } @@ -1544,6 +1552,21 @@ class OpConverterTest : public ::testing::Test { converter_->AddTensorOrWeights(name, TRT_TensorOrWeights{weights})); } + template <typename T> + void AddTestWeights(const string& name, const std::vector<int>& dims, + const std::vector<T>& values, DataType tf_dtype) { + if (tf_dtype == DT_FLOAT) { + AddTestWeights(name, dims, CastTestVector<T, float>(values)); + } else if (tf_dtype == DT_HALF) { + AddTestWeights(name, dims, CastTestVector<T, Eigen::half>(values)); + } else if (tf_dtype == DT_INT32) { + AddTestWeights(name, dims, CastTestVector<T, int32>(values)); + } else { + FAIL() << "Cannot create test weights with type " + << DataTypeString(tf_dtype); + } + } + // Test validation in validation-only mode. void RunValidation(const Node* node, error::Code expected_code = error::OK, const char* expected_msg_substr = nullptr) { @@ -1681,20 +1704,146 @@ std::ostream& operator<<(std::ostream& os, const TestParamBase& p) { return os; } -// Parameterized version of OpConverterTest. This class will be instantiated -// to test all the TrtTestModes but only in FP32 precision. This means that we -// will use the following combinations of test parameters: +// Parameterized version of OpConverterTest. We have the following parameters: // 1. TrtTestMode: implicit batch, explicit batch, dynamic shape modes -// 2. DataType of the input TF tensors: DT_FLOAT -// 3. TrtPrecisionMode argument for the Converter: FP32 -class ParameterizedOpConverterTest +// 2. DataType of the input TF tensors: DT_FLOAT, DT_HALF, DT_INT32 +// 3. TrtPrecisionMode argument for the Converter: FP32, FP16, INT8 +// We will introduce subclasses that will be instantiated using different +// combinations of the DataType and TrtPrecisionMode parameters. +class ParameterizedOpConverterTestBase : public OpConverterTest, public ::testing::WithParamInterface< - std::tuple<TrtTestMode, DataType, TrtPrecisionMode>> {}; + std::tuple<TrtTestMode, DataType, TrtPrecisionMode>> { + public: + ParameterizedOpConverterTestBase() + : trt_mode(std::get<0>(GetParam())), + tf_dtype(std::get<1>(GetParam())), + converter_precision(std::get<2>(GetParam())) {} -// Instantiate parameter combinations to test. For debugging purposes it might -// make sense to run over all possible combinations, but normally a subset of -// them would be sufficient: + void Reset() { + OpConverterTest::Reset(converter_precision, trt_mode); + input_data_.clear(); + } + + // Adds an input ITensor for TRT network. Also creates the corresponding TF + // tensor, and stores it in the list of inputs (input_data_). + // + // The TF tensor is always created with concrete static input shape given by + // dims. The ITensor can have static or dynamic shape based on the trt_mode + // attribute. The ITensor shape is set automatically according to the trt_mode + // parameter, unless the user overrides it with an explicit + // partial_input_shape_dims argument. + // + // Parameters: + // - name of the input node + // - dims actual dimensions of the tensor that we will use during the test + // (including explicit batch dim) + // - values initial values for the TF tensor + // - dtype data type of the tensor + // - partial_input_shape dimensions which can incude unknown shapes. This can + // be empty, in that case the partial_input_shape will be set automatically + // depending on the trt_mode argument. (This argument also includes explicit + // batch dim). + // + template <typename T> + void AddTestTensor(const string& name, const std::vector<int32>& dims, + DataType tf_dtype, const std::vector<T>& values, + const std::vector<int32>& partial_input_shape_dims = {}) { + std::vector<int32> partial_shape; + if (!partial_input_shape_dims.empty()) { + partial_shape = partial_input_shape_dims; + } else { + if (trt_mode == TrtTestMode::kDynamicShape) { + // In dynamic shape mode we make all dims unknown. + partial_shape = std::vector<int32>(dims.size(), -1); + } else { + // Use static (known) input shapes. + partial_shape = dims; + } + } + AddTestTensorWithTFDims(name, partial_shape, TfDataTypeToTrt(tf_dtype)); + if (!values.empty()) { + VLOG(2) << "Adding test tensor: " << name << " " + << DataTypeString(tf_dtype); + InputOutputData data{name, AsTensor(values, dims, tf_dtype)}; + VLOG(2) << "Added tensor: " << data.name + << DataTypeString(data.tensor.dtype()); + input_data_.push_back(data); + } + } + + // Adds test tensor (same as above) but with the default tf_dtype defined by + // the test params. + void AddTestTensor(const string& name, const std::vector<int32>& dims, + const std::vector<float>& values = {}, + const std::vector<int32>& partial_input_shape_dims = {}) { + AddTestTensor<float>(name, dims, tf_dtype, values, + partial_input_shape_dims); + } + + // Builds and runs the converted network. Checks output tensor shape. Tests + // output values using a matcher. The network can have multiple input and + // output tensors. The inputs are defined by the input_data_ member variable. + void BuildAndRun(const string& name, + const std::vector<std::vector<int>>& expected_output_dims, + const Status& expected_runtime_status, + const std::vector<Matcher<std::vector<float>>>& matcher) { + TensorShape shape; + const int n_output = expected_output_dims.size(); + ASSERT_EQ(n_output, matcher.size()); + DataVec output_data; + for (int i = 0; i < n_output; i++) { + TF_EXPECT_OK( + TensorShapeUtils::MakeShape(expected_output_dims[i], &shape)); + string out_name = (n_output == 1) ? name : StrCat(name, ":", i); + InputOutputData data{out_name, + ConstructTensor(shape.num_elements(), 0, tf_dtype)}; + output_data.push_back(data); + } + ASSERT_FALSE(input_data_.empty()); + const int batch_size = input_data_[0].tensor.shape().dim_size(0); + Status stat = + OpConverterTest::BuildAndRun(input_data_, &output_data, batch_size); + ASSERT_EQ(expected_runtime_status, stat); + if (expected_runtime_status.ok() && stat.ok()) { + for (int i = 0; i < n_output; i++) { + // Check the shape of the actual output tensors + TF_EXPECT_OK( + TensorShapeUtils::MakeShape(expected_output_dims[i], &shape)); + EXPECT_TRUE(output_data[i].tensor.shape() == shape) + << "Expected shape: " << shape.DebugString() << ", actual shape" + << output_data[i].tensor.shape().DebugString(); + EXPECT_THAT(GetDataAsFloat(output_data[i]), matcher[i]); + } + } + } + + // Runs validation and conversion. If conversion is successfull then builds + // the TRT network, executes it and checks the output. + void TestOpConverter(const string& name, const NodeDef node_def, + const std::vector<int>& expected_output_dims, + const Status& expected_conversion_status, + const Status& expected_runtime_status, + const Matcher<std::vector<float>>& matcher) { + RunValidationAndConversion(node_def, expected_conversion_status, + name.c_str(), expected_output_dims); + if (expected_conversion_status.ok()) { + BuildAndRun(name, std::vector<std::vector<int>>({expected_output_dims}), + expected_runtime_status, + std::vector<Matcher<std::vector<float>>>({matcher})); + } + } + + protected: + const TrtTestMode trt_mode; + const DataType tf_dtype; + const TrtPrecisionMode converter_precision; + DataVec input_data_; +}; + +// Op converter test in FP32 mode. While for debugging purposes it might make +// sense to run over all possible combinations, normally a subset of them +// would be sufficient: // - All valid options to TrtTestMode (implicit, explicit, dynamic shape) // - DataType: is the TF data type of the input tensors. This usually only // influences the data type added by Converter::AddInputTensor. We test the @@ -1704,66 +1853,15 @@ class ParameterizedOpConverterTest // how TRT handles the precision inside the TRT network, but should not matter // for the TF -> TRT conversion. Therefore it should be sufficient to test // for FP32. +class OpConverterTest1 : public ParameterizedOpConverterTestBase {}; + +// Instantiate parameter combinations to OpConverterTest1 INSTANTIATE_TEST_CASE_P( - OpConvTestInstantiation, ParameterizedOpConverterTest, + OpConvTestInstantiation, OpConverterTest1, ::testing::Combine(::testing::ValuesIn(ValidTrtModes), ::testing::Values(DT_FLOAT), ::testing::Values(TrtPrecisionMode::FP32))); -// Builds and runs the converted network. Checks output tensor shape. Tests -// output values using a matcher. -template <DataType dtype> -void BuildAndRunConvertedNetwork(const string& name, OpConverterTest* test, - const TestParamBase& p, - const std::vector<float>& input_vec, - const Matcher<std::vector<float>>& matcher) { - if (!p.status.ok()) { - // conversion was not successful, we cannot run the network - return; - } - if (!p.runtime_status.ok()) { - // Runtime error is expected. This can happen if the operation is invalid - // for the actual input shape. Usually we catch these errors during - // conversion. If the network was defined with dynamic input shape than we - // have to postpone these steps until runtime. - // - // TODO(tfeher) Instead of early return, modify BuildAndRun to handle - // runtime errors. - return; - } - typedef typename EnumToDataType<dtype>::Type T; - TensorShape shape; - TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.input_dims, &shape)); - const DataVec input_data{ - {"input", test->AsTensor<T>(CastTestVector<float, T>(input_vec), shape)}}; - DataVec output_data{{name, test->ConstructTensor<T>(6)}}; - test->BuildAndRun(input_data, &output_data); - // Check the shape of the actual output tensor - TF_EXPECT_OK(TensorShapeUtils::MakeShape(p.expected_output_dims, &shape)); - EXPECT_TRUE(output_data[0].tensor.shape() == shape) - << "Expected shape: " << shape.DebugString() << ", actual shape" - << output_data[0].tensor.shape().DebugString(); - // Cast the output to float and compare to expected output - auto out_span = GetSpanForData<T>(output_data[0]); - std::vector<float> casted_output(out_span.begin(), out_span.end()); - EXPECT_THAT(casted_output, matcher); -} - -void InstantiateBuildAndRun(DataType tf_dtype, const string& name, - OpConverterTest* test, const TestParamBase& p, - const std::vector<float>& input_vec, - const Matcher<std::vector<float>>& matcher) { - if (tf_dtype == DT_FLOAT) { - BuildAndRunConvertedNetwork<DT_FLOAT>(name, test, p, input_vec, matcher); - } else if (tf_dtype == DT_HALF) { - BuildAndRunConvertedNetwork<DT_HALF>(name, test, p, input_vec, matcher); - } else if (tf_dtype == DT_INT32) { - BuildAndRunConvertedNetwork<DT_INT32>(name, test, p, input_vec, matcher); - } else { - FAIL() << "Test not supported for " << tf_dtype; - } -} - template <typename T> void CopyTensorElements(const Tensor& tensor, protobuf::RepeatedField<T>* out) { out->Clear(); @@ -1901,14 +1999,7 @@ TEST_F(OpConverterTest, ConvertConst) { TestConvertConst<DT_UINT64, uint64, int32>(this); } -TEST_P(ParameterizedOpConverterTest, ConvertTranspose) { - const auto& spec = GetParam(); - const TrtTestMode trt_mode = std::get<0>(spec); - // Data type of TF input tensors - const DataType tf_dtype = std::get<1>(spec); - // Precision mode used for TensorRT engine - TrtPrecisionMode converter_precision = std::get<2>(spec); - +TEST_P(OpConverterTest1, ConvertTranspose) { // Get the NodeDef for Transpose. Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype); @@ -1919,7 +2010,7 @@ TEST_P(ParameterizedOpConverterTest, ConvertTranspose) { std::vector<TestParamBase> test_params = { // For the first test we leave param empty. This signals to use a // input as weight which will be invalid - TestParamBase{{1, 1, 2, 3}, + TestParamBase{{3, 1, 2, 1}, {}, {}, {}, @@ -1953,20 +2044,17 @@ TEST_P(ParameterizedOpConverterTest, ConvertTranspose) { std::vector<float> expected_values{1, 4, 2, 5, 3, 6}; for (auto p : test_params) { SCOPED_TRACE(p); - Reset(converter_precision, trt_mode); - AddTestTensor("input", p.input_dims, TfDataTypeToTrt(tf_dtype), trt_mode, - &p.partial_input_dims); + Reset(); + AddTestTensor("input", p.input_dims, {1, 2, 3, 4, 5, 6}, + p.partial_input_dims); if (p.param.empty()) { AddTestTensor("weights", {3}); } else { AddTestWeights<int32>("weights", {static_cast<int>(p.param.size())}, p.param); } - RunValidationAndConversion(node_def, p.status, "my_transpose", - p.expected_output_dims); - InstantiateBuildAndRun(tf_dtype, "my_transpose", this, p, - {1, 2, 3, 4, 5, 6}, - ElementsAreArray(expected_values)); + TestOpConverter("my_transpose", node_def, p.expected_output_dims, p.status, + p.runtime_status, ElementsAreArray(expected_values)); } } @@ -2063,7 +2151,7 @@ TEST_F(OpConverterTest, ConvertReshape) { const DataVec input_data{{"input", AsTensor<float>(input_vec)}}; DataVec output_data{ {"my_reshape", ConstructTensor<float>(input_vec.size())}}; - BuildAndRun(input_data, &output_data, batch_size); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data, batch_size)); EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAreArray(input_vec)); } @@ -2118,7 +2206,7 @@ void TestMatMulHelper( const DataVec input_data{{"input", test->AsTensor<float>({0, 1})}}; DataVec output_data{{"my_matmul", test->ConstructTensor<float>(2)}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); if (transpose_b) { EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3)); } else { @@ -2145,7 +2233,7 @@ void TestMatMulHelper( ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions()); const DataVec input_data{{"input", test->AsTensor<float>({0, 1})}}; DataVec output_data{{"my_matmul", test->ConstructTensor<float>(2)}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); if (transpose_b) { EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3)); } else { @@ -2288,7 +2376,7 @@ TEST_F(OpConverterTest, ConvertBatchMatMul) { ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); const DataVec input_data{{"input", AsTensor<float>({0, 1, 2, 3})}}; DataVec output_data{{"my_matmul", ConstructTensor<float>(4)}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); if (!transpose_a && !transpose_b) { EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(3, 4, 11, 16)); @@ -2362,7 +2450,7 @@ void TestConvertBiasAdd(OpConverterTest* test) { {"input", test->ConstructTensor<CType>(num_input, CType(0))}}; DataVec output_data{ {"my_biasadd", test->ConstructTensor<CType>(num_input)}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); if (trt_input_rank == 1) { if (data_format == "NHWC") { EXPECT_THAT(GetSpanForData<CType>(output_data[0]), @@ -2445,7 +2533,7 @@ void TestBinaryOp(OpConverterTest* test, bool operand_1_is_tensor, ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions()); // After broadcasting first input becomes {3, 6, 3, 6} and second input // becomes {2, 3, 2, 3}. - test->BuildAndRun(input_data, &output_data, /*batch_size=*/2); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data, /*batch_size=*/2)); if (node_def.op() == "Add") { EXPECT_THAT( GetSpanForData<CType>(output_data[0]), @@ -2579,7 +2667,7 @@ void TestAddN(OpConverterTest* test) { ExpectTrtDimsEqualsArray({1, 2}, output.tensor()->getDimensions()); DataVec output_data{{"my_addn", test->ConstructTensor<CType>(4)}}; - test->BuildAndRun(input_data, &output_data, /*batch_size=*/2); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data, /*batch_size=*/2)); EXPECT_THAT(GetSpanForData<CType>(output_data[0]), ElementsAreArray(CastTestVector<int, CType>({3, 6, 9, 12}))); } @@ -2603,7 +2691,7 @@ void TestAddN(OpConverterTest* test) { ExpectTrtDimsEqualsArray({1, 2}, output.tensor()->getDimensions()); DataVec output_data{{"my_addn", test->ConstructTensor<CType>(2)}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData<CType>(output_data[0]), ElementsAreArray(CastTestVector<int, CType>({5, 8}))); } @@ -2769,7 +2857,7 @@ void TestConvertSquare(OpConverterTest* test) { // Engine outputs are converted to FP16 automatically if we set FP16 mode in // the builder. DataVec output_data{{"my_square", test->ConstructTensor<CType>(num_inputs)}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); ExpectArrayNear(expected_outputs, GetSpanForData<CType>(output_data[0])); } @@ -2881,7 +2969,7 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) { }; const DataVec input_data{{"boxes", AsTensor<float>({0, 0, 0.3, 0.4})}, {"scores", AsTensor<float>({0.4, 0.7, 0.3})}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(0, 0, 0.3, 0.4, 0, 0, 0.3, 0.4)); EXPECT_THAT(GetSpanForData<float>(output_data[1]), ElementsAre(0.7, 0.4)); @@ -2891,90 +2979,67 @@ TEST_F(OpConverterTest, ConvertCombinedNMS) { } #endif // IS_TRT_VERSION_GE(5, 1, 0, 0) -TEST_F(OpConverterTest, ConvertActivation) { +template <typename T> +NodeDef CreateUnaryOp(DataType tf_dtype) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype); + return T(s.WithOpName("my_unary"), input).operation.node()->def(); +} + +constexpr float kLeakyReluAlpha = 0.2f; +template <> +NodeDef CreateUnaryOp<ops::internal::LeakyRelu>(DataType tf_dtype) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype); + return ops::internal::LeakyRelu( + s.WithOpName("my_unary"), input, + ops::internal::LeakyRelu::Alpha(kLeakyReluAlpha)) + .operation.node() + ->def(); +} + +TEST_P(OpConverterTest1, ConvertActivation) { { // Input is weights, should fail. Reset(); - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - auto relu = ops::Relu(s.WithOpName("my_act"), input); - const NodeDef& node_def = relu.operation.node()->def(); + const NodeDef& node_def = CreateUnaryOp<ops::Relu>(tf_dtype); AddTestWeights<int32>("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, - "The input \"input\" for Relu must be a tensor, at my_act"); + "The input \"input\" for Relu must be a tensor, at my_unary"); } - constexpr float kLeakyReluAlpha = 0.2f; constexpr float kSeluAlpha = 1.7580993408473768599402175208123f; constexpr float kSeluScale = 1.0507009873554804934193349852946f; + using OpFunc = std::function<NodeDef(DataType)>; + using ValFunc = float (*)(float); + std::map<std::string, std::pair<OpFunc, ValFunc>> op_map; - // Get nodedef for activation layer. - auto get_act_nodedef = [](string op_name) -> NodeDef { - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - if (op_name == "LeakyRelu") { - auto act = ops::internal::LeakyRelu( - s.WithOpName("my_act"), input, - ops::internal::LeakyRelu::Alpha(kLeakyReluAlpha)); - return act.operation.node()->def(); - } else if (op_name == "Relu") { - auto act = ops::Relu(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Relu6") { - auto act = ops::Relu6(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Sigmoid") { - auto act = ops::Sigmoid(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Tanh") { - auto act = ops::Tanh(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Elu") { - auto act = ops::Elu(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Selu") { - auto act = ops::Selu(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Softsign") { - auto act = ops::Softsign(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } else if (op_name == "Softplus") { - auto act = ops::Softplus(s.WithOpName("my_act"), input); - return act.operation.node()->def(); - } - EXPECT_TRUE(false); - return NodeDef(); - }; - // Get expected output for activation layer. - auto get_act_output = [](string op_name, float input) -> float { - if (op_name == "LeakyRelu") { - return (input > 0.0f) ? input : input * kLeakyReluAlpha; - } else if (op_name == "Relu") { - return (input > 0.0f) ? input : 0.0f; - } else if (op_name == "Relu6") { - return std::min(std::max(input, 0.0f), 6.0f); - } else if (op_name == "Sigmoid") { - return 1.0f / (1.0f + std::exp(-input)); - } else if (op_name == "Tanh") { - return std::tanh(input); - } else if (op_name == "Elu") { - return (input > 0.0f) ? input : std::exp(input) - 1; - } else if (op_name == "Selu") { - return (input > 0.0f) ? kSeluScale * input - : kSeluScale * kSeluAlpha * (std::exp(input) - 1); - } else if (op_name == "Softsign") { - return input / (std::abs(input) + 1); - } else if (op_name == "Softplus") { - return std::log(std::exp(input) + 1); - } - EXPECT_TRUE(false); - return 0; - }; +#define ADD_OP(name, op, compute) \ + op_map[name] = std::make_pair(CreateUnaryOp<op>, compute) + ADD_OP("LeakyRelu", ops::internal::LeakyRelu, + [](float x) { return (x > 0.0f) ? x : x * kLeakyReluAlpha; }); + ADD_OP("Relu", ops::Relu, [](float x) { return (x > 0.0f) ? x : 0.0f; }); + ADD_OP("Relu6", ops::Relu6, + [](float x) { return std::min(std::max(x, 0.0f), 6.0f); }); + ADD_OP("Sigmoid", ops::Sigmoid, + [](float x) { return 1.0f / (1.0f + std::exp(-x)); }); + ADD_OP("Tanh", ops::Tanh, static_cast<ValFunc>(std::tanh)); + ADD_OP("Elu", ops::Elu, + [](float x) { return (x > 0.0f) ? x : std::exp(x) - 1; }); + ADD_OP("Selu", ops::Selu, [](float x) { + return (x > 0.0f) ? kSeluScale * x + : kSeluScale * kSeluAlpha * (std::exp(x) - 1); + }); + ADD_OP("Softsign", ops::Softsign, + [](float x) { return x / (std::abs(x) + 1); }); + ADD_OP("Softplus", ops::Softplus, + [](float x) { return std::log(std::exp(x) + 1); }); +#undef ADD_OP // Get list of ops to test. std::vector<string> ops_to_test; - // Add all ops supported by ConvertUnary. + // Add all ops supported by ConvertActivation. auto* map = ActivationTypeMap(); ops_to_test.reserve(map->size()); for (auto& pair : *map) { @@ -2983,16 +3048,30 @@ TEST_F(OpConverterTest, ConvertActivation) { // Add other activation ops to test. ops_to_test.push_back("Relu6"); ops_to_test.push_back("LeakyRelu"); + auto p = TestParamBase{ + {1, 1, 2, 3}, // input dims + {}, // input partial dims + {1, 1, 2, 3}, // expected output dims + }; // Ok. for (const string& op_name : ops_to_test) { + if (!op_map.count(op_name)) { + FAIL() << "Activation op test map does not contain op " << op_name; + } Reset(); - NodeDef node_def = get_act_nodedef(op_name); - AddTestTensor("input", {1, 2, 3}); - RunValidationAndConversion(node_def); + NodeDef node_def = op_map[op_name].first(tf_dtype); + const std::vector<float> input = {-100, -2, -1, 0, 1, 88}; + AddTestTensor("input", p.input_dims, input); + + // std::exp in Softplus will overflow for input > 88 + std::vector<float> output_values; + std::transform(input.begin(), input.end(), + std::back_inserter(output_values), op_map[op_name].second); + TestOpConverter("my_unary", node_def, p.expected_output_dims, Status::OK(), + Status::OK(), ArrayFloatNear(output_values, 0, false)); + TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_act", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions()); + TF_EXPECT_OK(GetTensorOrWeights("my_unary", &output)); // Certain activations should set quantization range automatically. auto ranges = quantization_ranges(); @@ -3002,17 +3081,6 @@ TEST_F(OpConverterTest, ConvertActivation) { op_name == "Softsign") { EXPECT_EQ(ranges[output.tensor()], 1.0f); } - - // std::exp in Softplus will overflow for input > 88 - const std::vector<float> input = {-100, -2, -1, 0, 1, 88}; - const DataVec input_data{{"input", AsTensor<float>(input)}}; - DataVec output_data{{"my_act", ConstructTensor<float>(6)}}; - BuildAndRun(input_data, &output_data); - for (int i = 0; i < input.size(); i++) { - const float expected_output = get_act_output(op_name, input[i]); - EXPECT_FLOAT_EQ(GetSpanForData<float>(output_data[0])[i], - expected_output); - } } } @@ -3112,23 +3180,17 @@ TEST_F(OpConverterTest, ConvertExpandDims) { const DataVec input_data{{"input", AsTensor<float>({1, 2, 3, 4, 5, 6})}}; DataVec output_data{{"my_expanddims", ConstructTensor<float>(6)}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 2, 3, 4, 5, 6)); } } -TEST_P(ParameterizedOpConverterTest, ConvertSqueeze) { - const auto& spec = GetParam(); - const TrtTestMode trt_mode = std::get<0>(spec); +TEST_P(OpConverterTest1, ConvertSqueeze) { const bool use_implicit_batch = (trt_mode == TrtTestMode::kImplicitBatch); - // Data type of TF input tensors - const DataType tf_dtype = std::get<1>(spec); - // Precision mode used for TensorRT engine - TrtPrecisionMode converter_precision = std::get<2>(spec); - // Get the NodeDef for Squeeze. - auto get_squeeze_nodedef = [tf_dtype](std::vector<int> axes) -> NodeDef { + auto get_squeeze_nodedef = [](std::vector<int> axes, + DataType tf_dtype) -> NodeDef { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), tf_dtype); if (!axes.empty()) { @@ -3221,14 +3283,12 @@ TEST_P(ParameterizedOpConverterTest, ConvertSqueeze) { for (TestParamBase p : test_params) { SCOPED_TRACE(p); - Reset(converter_precision, trt_mode); - NodeDef node_def = get_squeeze_nodedef(p.param); - AddTestTensor("input", p.input_dims, TfDataTypeToTrt(tf_dtype), trt_mode, - &p.partial_input_dims); - RunValidationAndConversion(node_def, p.status, "my_squeeze", - p.expected_output_dims); - InstantiateBuildAndRun(tf_dtype, "my_squeeze", this, p, {1, 2, 3, 4, 5, 6}, - ElementsAreArray({1, 2, 3, 4, 5, 6})); + Reset(); + NodeDef node_def = get_squeeze_nodedef(p.param, tf_dtype); + AddTestTensor("input", p.input_dims, {1, 2, 3, 4, 5, 6}, + p.partial_input_dims); + TestOpConverter("my_squeeze", node_def, p.expected_output_dims, p.status, + p.runtime_status, ElementsAreArray({1, 2, 3, 4, 5, 6})); } } @@ -3831,7 +3891,7 @@ TEST_F(OpConverterTest, ConvertStridedSlice) { DataVec output_data{ {"my_strided_slice", ConstructTensor<float>(ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -3971,21 +4031,22 @@ TEST_F(OpConverterTest, ConvertSlice) { const DataVec input_data{{"input", AsTensor<float>({1, 2, 3, 4, 5, 6})}}; DataVec output_data{{"my_slice", ConstructTensor<float>( ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } } -TEST_F(OpConverterTest, ConvertConv2D) { +TEST_P(OpConverterTest1, ConvertConv2D) { // Get nodedef for Conv2D layer. + DataType tf_type = tf_dtype; auto get_conv2d_nodedef = - [](std::vector<int> strides = {1, 1, 1, 1}, string padding = "SAME", - string data_format = "NCHW", - std::vector<int> dilations = {1, 1, 1, 1}) -> NodeDef { + [tf_type](std::vector<int> strides = {1, 1, 1, 1}, + string padding = "SAME", string data_format = "NCHW", + std::vector<int> dilations = {1, 1, 1, 1}) -> NodeDef { Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto filter = ops::Placeholder(s.WithOpName("weights"), tf_type); ops::Conv2D::Attrs attrs = ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations); auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides, @@ -4007,7 +4068,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { // Filter is tensor, should fail. Reset(); NodeDef node_def = get_conv2d_nodedef(); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {3, 1, 2, 1}); AddTestTensor("weights", {3, 3, 1, 1}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, @@ -4017,7 +4078,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { // Filter is not 4D, should fail. Reset(); NodeDef node_def = get_conv2d_nodedef(); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 1, 2, 3}); AddTestWeights<float>("weights", {3, 3, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, @@ -4028,7 +4089,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 1, 1}); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 1, 2, 3}); AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, @@ -4039,7 +4100,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NCHW", {1, 2, 1, 1}); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 1, 2, 3}); AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, "Dilation rate must be 1 for batch and channel " @@ -4050,7 +4111,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 1, 2}); - AddTestTensor("input", {2, 3, 1}); + AddTestTensor("input", {1, 2, 3, 1}); AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, "Dilation rate must be 1 for batch and channel " @@ -4061,7 +4122,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef({1, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1}); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 1, 2, 3}); AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion( node_def, error::INVALID_ARGUMENT, @@ -4072,12 +4133,23 @@ TEST_F(OpConverterTest, ConvertConv2D) { Reset(); NodeDef node_def = get_conv2d_nodedef({1, 2, 1, 1}, "SAME", "NCHW", {1, 1, 1, 1}); - AddTestTensor("input", {1, 2, 3}); + AddTestTensor("input", {1, 1, 2, 3}); AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "Stride must be 1 for batch and channel dimensions, at my_conv2d"); } + if (trt_mode == TrtTestMode::kDynamicShape) { + Reset(); + NodeDef node_def = get_conv2d_nodedef(); + // Channel dim unknown, should fail. + AddTestTensorWithTFDims("input", {-1, -1, -1, -1}, + TfDataTypeToTrt(tf_type)); + AddTestWeights<float>("weights", {1, 2, 1, 1}, {-1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Channel dimension must be static, at my_conv2d"); + } struct TestParams { std::vector<int> input_dims; @@ -4095,7 +4167,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { // Ok. std::vector<TestParams> ok_params = { // Basic - TestParams{/*input_dims=*/{1, 2, 3}, + TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, /*filter_dims=*/{1, 2, 1, 1}, /*filter=*/{-1, 1}, @@ -4103,10 +4175,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 2, 2}, + /*expected_output_dims=*/{1, 1, 2, 2}, /*expected_output=*/{1, 1, 0, 1}}, // SAME padding (Asymmetric) - TestParams{/*input_dims=*/{1, 2, 3}, + TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, /*filter_dims=*/{1, 2, 1, 1}, /*filter=*/{-1, 1}, @@ -4114,10 +4186,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"SAME", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 2, 3}, + /*expected_output_dims=*/{1, 1, 2, 3}, /*expected_output=*/{1, 1, -2, 0, 1, -4}}, // SAME padding (Symmetric) - TestParams{/*input_dims=*/{1, 2, 3}, + TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, /*filter_dims=*/{1, 3, 1, 1}, /*filter=*/{-1, 0, 1}, @@ -4125,10 +4197,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"SAME", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 2, 3}, + /*expected_output_dims=*/{1, 1, 2, 3}, /*expected_output=*/{1, 2, -1, 3, 1, -3}}, // NHWC - TestParams{/*input_dims=*/{2, 3, 1}, + TestParams{/*input_dims=*/{1, 2, 3, 1}, /*input=*/{0, 1, 2, 3, 3, 4}, /*filter_dims=*/{1, 2, 1, 1}, /*filter=*/{-1, 1}, @@ -4136,10 +4208,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NHWC", /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{2, 2, 1}, + /*expected_output_dims=*/{1, 2, 2, 1}, /*expected_output=*/{1, 1, 0, 1}}, // Dilated - TestParams{/*input_dims=*/{1, 2, 3}, + TestParams{/*input_dims=*/{1, 1, 2, 3}, /*input=*/{0, 1, 2, 3, 3, 4}, /*filter_dims=*/{1, 2, 1, 1}, /*filter=*/{-1, 1}, @@ -4147,10 +4219,10 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 2}, - /*expected_output_dims=*/{1, 2, 1}, + /*expected_output_dims=*/{1, 1, 2, 1}, /*expected_output=*/{2, 1}}, // Strided - TestParams{/*input_dims=*/{1, 2, 4}, + TestParams{/*input_dims=*/{1, 1, 2, 4}, /*input=*/{0, 1, 2, 2, 3, 4, 4, 7}, /*filter_dims=*/{1, 2, 1, 1}, /*filter=*/{-1, 1}, @@ -4158,7 +4230,7 @@ TEST_F(OpConverterTest, ConvertConv2D) { /*padding=*/"VALID", /*data_format=*/"NCHW", /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 2, 2}, + /*expected_output_dims=*/{1, 1, 2, 2}, /*expected_output=*/{1, 0, 1, 3}}, }; @@ -4167,23 +4239,22 @@ TEST_F(OpConverterTest, ConvertConv2D) { NodeDef node_def = get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format, ok_params[i].dilations); - AddTestTensor("input", ok_params[i].input_dims); + std::vector<int> partial_input_shape; + if (trt_mode == TrtTestMode::kDynamicShape) { + // The channel dim cannot have unknown size, fix that. + partial_input_shape.resize(ok_params[i].input_dims.size(), -1); + int channel_id = (ok_params[i].data_format == "NCHW") ? 1 : 3; + partial_input_shape[channel_id] = ok_params[i].input_dims[channel_id]; + } + + AddTestTensor("input", ok_params[i].input_dims, tf_dtype, + ok_params[i].input, partial_input_shape); AddTestWeights<float>("weights", ok_params[i].filter_dims, ok_params[i].filter); - RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, - output.tensor()->getDimensions()); - const DataVec input_data{{"input", AsTensor<float>(ok_params[i].input)}}; - DataVec output_data{ - {"my_conv2d", - ConstructTensor<float>(ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); - EXPECT_THAT(GetSpanForData<float>(output_data[0]), - ElementsAreArray(ok_params[i].expected_output)); + TestOpConverter("my_conv2d", node_def, ok_params[i].expected_output_dims, + Status::OK(), Status::OK(), + ElementsAreArray(ok_params[i].expected_output)); } } @@ -4308,7 +4379,7 @@ TEST_F(OpConverterTest, ConvertConv2DBackpropInput) { DataVec output_data{ {"my_conv2d_backprop_input", ConstructTensor<float>(ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -4640,7 +4711,7 @@ TEST_F(OpConverterTest, ConvertConv3D) { DataVec output_data{ {"my_conv3d", ConstructTensor<float>(ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -4827,7 +4898,7 @@ TEST_F(OpConverterTest, ConvertPool3D) { DataVec output_data{ {expected_node_name, ConstructTensor<float>(ok_params[i].expected_output.size())}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -4872,7 +4943,7 @@ TEST_F(OpConverterTest, ConvertTopK) { {"input", AsTensor<float>({-9, 3, 5, 1, 6, -5, 7, 1, 0, -1})}}; DataVec output_data{{"my_topk", ConstructTensor<float>(4)}, {"my_topk:1", ConstructTensor<int32>(4)}}; - BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(6, 5, 7, 1)); EXPECT_THAT(GetSpanForData<int32>(output_data[1]), @@ -5059,8 +5130,8 @@ void TestConvertGather(OpConverterTest* test) { } DataVec output_data{ {"my_gather", test->ConstructTensor<CType>(expected_output.size())}}; - test->BuildAndRun(input_data, &output_data, - /*batch_size=*/expected_output_shape[0]); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data, + /*batch_size=*/expected_output_shape[0])); EXPECT_THAT(GetSpanForData<CType>(output_data[0]), ElementsAreArray(converted_expected_output)); } @@ -5131,28 +5202,25 @@ TEST_F(OpConverterTest, ConvertGather) { TestConvertGather<DT_INT32>(this); } -template <typename T> -NodeDef CreateUnaryOp() { +NodeDef CreateCastOp(DataType tf_dtype) { Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - return T(s.WithOpName("my_unary"), input).operation.node()->def(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_HALF); + return ops::Cast(s.WithOpName("my_unary"), input, DT_FLOAT) + .operation.node() + ->def(); } -TEST_P(ParameterizedOpConverterTest, ConvertUnary) { - const auto& spec = GetParam(); - const TrtTestMode trt_mode = std::get<0>(spec); - const DataType tf_dtype = std::get<1>(spec); - TrtPrecisionMode converter_precision = std::get<2>(spec); +TEST_P(OpConverterTest1, ConvertUnary) { { // Input is weights, should fail. - Reset(converter_precision, trt_mode); - const NodeDef node_def = CreateUnaryOp<ops::Neg>(); + Reset(); + const NodeDef node_def = CreateUnaryOp<ops::Neg>(tf_dtype); AddTestWeights<float>("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "The input \"x\" for Neg must be a tensor, at my_unary"); } - using OpFunc = std::function<NodeDef(void)>; + using OpFunc = std::function<NodeDef(DataType)>; using ValFunc = float (*)(float); std::map<std::string, std::pair<OpFunc, ValFunc>> op_map; #define ADD_OP(name, op, compute) \ @@ -5165,6 +5233,7 @@ TEST_P(ParameterizedOpConverterTest, ConvertUnary) { ADD_OP("Asinh", ops::Asinh, std::asinh); ADD_OP("Atan", ops::Atan, std::atan); ADD_OP("Atanh", ops::Atanh, std::atanh); + op_map["Cast"] = std::make_pair(CreateCastOp, [](float x) { return x; }); ADD_OP("Ceil", ops::Ceil, std::ceil); ADD_OP("Cos", ops::Cos, std::cos); ADD_OP("Cosh", ops::Cosh, std::cosh); @@ -5197,22 +5266,27 @@ TEST_P(ParameterizedOpConverterTest, ConvertUnary) { }; for (const string& op_name : ops_to_test) { SCOPED_TRACE(op_name); - Reset(converter_precision, trt_mode); + Reset(); if (!op_map.count(op_name)) { FAIL() << "Unary op test map does not contain op " << op_name; } - NodeDef node_def = op_map[op_name].first(); + NodeDef node_def = op_map[op_name].first(tf_dtype); - AddTestTensor("input", p.input_dims, TfDataTypeToTrt(tf_dtype), trt_mode); - RunValidationAndConversion(node_def, Status::OK(), "my_unary", - p.expected_output_dims); + // TODO(bixia): we assume this test is only instantiated for DT_FLOAT for + // now. Need to find a better way to express input and output types. + // + // TODO(tfeher): improve tests by defining an expected output data type and + // check that. Currently only the shape and values of the output are + // checked. + DataType input_tf_dtype = op_name == "Cast" ? DT_HALF : tf_dtype; std::vector<float> input_values{-0.9f, 0.6f, 0.0f, -3.5f, 100.0f, 2.9f}; + AddTestTensor("input", p.input_dims, input_tf_dtype, input_values); std::vector<float> output; std::transform(input_values.begin(), input_values.end(), std::back_inserter(output), op_map[op_name].second); - InstantiateBuildAndRun(tf_dtype, "my_unary", this, p, input_values, - ArrayFloatNear(output, 0.0001, true)); + TestOpConverter("my_unary", node_def, p.expected_output_dims, Status::OK(), + p.runtime_status, ArrayFloatNear(output, 0.0001, true)); } } @@ -5316,7 +5390,7 @@ void TestConvertConcat(OpConverterTest* test) { DataVec output_data{ {"my_concat", test->ConstructTensor<CType>(ok_params[i].expected_output.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData<CType>(output_data[0]), ElementsAreArray(ok_params[i].expected_output)); } @@ -5481,7 +5555,7 @@ void TestConvertSplit(OpConverterTest* test) { // Verify output values are correct. const DataVec input_data{ {"value", test->AsTensor<CType>(ok_params[i].value)}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); for (int j = 0; j < outputs.size(); ++j) { EXPECT_THAT(GetSpanForData<CType>(output_data[j]), ElementsAreArray(ok_params[i].expected_outputs[j])); @@ -5658,7 +5732,7 @@ void TestConvertUnpack(OpConverterTest* test) { // Verify output values are correct. const DataVec input_data{ {"value", test->AsTensor<CType>(ok_params[i].value)}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); for (int j = 0; j < outputs.size(); ++j) { EXPECT_THAT(GetSpanForData<CType>(output_data[j]), ElementsAreArray(ok_params[i].expected_outputs[j])); @@ -5827,7 +5901,7 @@ void TestConvertPack(OpConverterTest* test) { } DataVec output_data{{"my_pack", test->ConstructTensor<CType>( params[i].expected_output.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData<CType>(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -5975,7 +6049,7 @@ void TestConvertArgMinMax(OpConverterTest* test) { DataVec output_data{ {"my_arg", test->ConstructTensor<int32>( params[i].expected_argmax_output.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); if (node_def.op() == "ArgMax") { EXPECT_THAT(GetSpanForData<int32>(output_data[0]), @@ -6074,7 +6148,7 @@ void TestConvertDepthSpaceShuffle( DataVec input_data{{"input", test->AsTensor<CType>(params[i].input_value)}}; DataVec output_data{{"my_shuffle", test->ConstructTensor<CType>( params[i].expected_output.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData<CType>(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -6350,7 +6424,7 @@ void TestConvertClipByValue(OpConverterTest* test) { DataVec input_data{{"t", test->AsTensor<CType>(params[i].input_value)}}; DataVec output_data{{"my_clip", test->ConstructTensor<CType>( params[i].expected_output.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData<CType>(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -6458,7 +6532,7 @@ void TestConvertSquaredDifference(OpConverterTest* test) { DataVec output_data{ {"my_squared_diff", test->ConstructTensor<CType>(params[i].expected_output.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); EXPECT_THAT(GetSpanForData<CType>(output_data[0]), ElementsAreArray(params[i].expected_output)); } @@ -6563,7 +6637,7 @@ void TestConvertResize(OpConverterTest* test) { {"my_resize", test->ConstructTensor<CType>( params[i].expected_nearest_output_values.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); if (node_def.op() == "ResizeBilinear") { ExpectArrayAlmostEqual(params[i].expected_bilinear_output_values, @@ -6663,7 +6737,7 @@ void TestConvertPad(OpConverterTest* test) { {"my_pad", test->ConstructTensor<CType>( params[i].expected_output_values.size())}}; - test->BuildAndRun(input_data, &output_data); + TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); ExpectArrayAlmostEqual(params[i].expected_output_values, GetSpanForData<CType>(output_data[0]), CType(1e-5)); } diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index 6ab719db54d..72f4fe5ef9b 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -228,9 +228,6 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster, << "This can result in poor performance."; } } - grappler::GraphProperties static_graph_properties(item); - TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); - ConversionParams cp; if (use_calibration_ && precision_mode_ != TrtPrecisionMode::INT8) { VLOG(1) << "Calibration with FP32 or FP16 is not implemented. " @@ -255,7 +252,9 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster, } nodes_to_preserve.push_back(s); } - cp.input_graph_def = &item.graph; + + ConversionParams cp; + cp.grappler_item = &item; cp.output_names = &nodes_to_preserve; cp.trt_logger_name = trt_logger_name_; cp.max_batch_size = maximum_batch_size_; @@ -263,7 +262,6 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster, cp.output_graph_def = optimized_graph; cp.precision_mode = precision_mode_; cp.minimum_segment_size = minimum_segment_size_; - cp.graph_properties = &static_graph_properties; cp.cluster = cluster; cp.is_dyn_op = is_dynamic_op_; cp.max_cached_engines = max_cached_batches_; diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index fb3ae6943d3..a4b64ec0dc5 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace tensorrt { @@ -185,6 +186,40 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims, return Status::OK(); } +Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) { + switch (tf_type) { + case DT_FLOAT: + *trt_type = nvinfer1::DataType::kFLOAT; + break; + case DT_HALF: + *trt_type = nvinfer1::DataType::kHALF; + break; + case DT_INT32: + *trt_type = nvinfer1::DataType::kINT32; + break; + default: + return errors::Internal("Unsupported tensorflow type"); + } + return Status::OK(); +} + +Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) { + switch (trt_type) { + case nvinfer1::DataType::kFLOAT: + *tf_type = DT_FLOAT; + break; + case nvinfer1::DataType::kHALF: + *tf_type = DT_HALF; + break; + case nvinfer1::DataType::kINT32: + *tf_type = DT_INT32; + break; + default: + return errors::Internal("Invalid TRT type"); + } + return Status::OK(); +} + int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) { int n_bindings = engine->getNbBindings(); int n_input = 0; diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index 5d4cf1bb851..59eeb420134 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -106,6 +106,9 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims, bool use_implicit_batch, int batch_size, TensorShape& shape); +Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type); +Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type); + // Returns a string that includes compile time TensorRT library version // information {Maj, Min, Patch}. string GetLinkedTensorRTVersion(); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 55341c0a01f..37110442b26 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -350,6 +350,7 @@ cc_library( ":sharding_util", ":side_effect_util", ":tf2xla_util", + "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:shape_inference", "//tensorflow/compiler/jit:xla_cluster_util", diff --git a/tensorflow/compiler/tf2xla/graph_compiler_util.cc b/tensorflow/compiler/tf2xla/graph_compiler_util.cc index 57278eea292..a9385e05564 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler_util.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler_util.cc @@ -49,10 +49,12 @@ typedef std::unordered_map<string, Node*> NodeMap; // Each feed id identifies the positional output of some node, which may consist // of multiple edges. AddPlaceholdersForFeeds has already replaced each fed // tensor with a placeholder. For each feed tensor, replaces all edges so they -// point from a new _Arg node instead. +// point from a new _Arg node instead. The newly created _Arg nodes are added to +// `arg_nodes`. Status AddArgNodes(Graph* graph, const NodeMap& node_map, const protobuf::RepeatedPtrField<tf2xla::Feed>& feeds, - const std::unordered_map<string, string>& feed_remapping) { + const std::unordered_map<string, string>& feed_remapping, + std::unordered_set<const Node*>* arg_nodes) { for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) { const tf2xla::Feed& feed = feeds[arg_index]; // All feeds have been replaced by placeholders. @@ -86,6 +88,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map, .Attr(kShapeAttr, TensorShape(feed.shape())) .Attr(kDebugNameAttr, feed.name()) .Finalize(graph, &arg_node)); + arg_nodes->insert(arg_node); // Collects out-edges from the feed node that have a matching edge index; // these will be replaced with edges from the arg node instead. @@ -149,13 +152,13 @@ Status RewriteAndPruneGraph( for (Node* n : graph->nodes()) { node_map[n->name()] = n; } + std::unordered_set<const Node*> nodes_to_keep; + TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed(), feed_remapping, + &nodes_to_keep)); TF_RETURN_IF_ERROR( - AddArgNodes(graph, node_map, config.feed(), feed_remapping)); - std::unordered_set<const Node*> retval_nodes; - TF_RETURN_IF_ERROR( - AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes)); + AddRetvalNodes(graph, node_map, config.fetch(), &nodes_to_keep)); VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph); - PruneForReverseReachability(graph, std::move(retval_nodes)); + PruneForReverseReachability(graph, std::move(nodes_to_keep)); FixupSourceAndSinkEdges(graph); VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph); // Sanity-check, to make sure the feeds and fetches still exist post-pruning. @@ -277,8 +280,16 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, // Prune the GraphDef first so that unknown ops that we aren't compiling get // filtered out. GraphDef second_copy_def; + // Add the placeholder nodes as "fetches" in prune_config, such that they will + // be preserved in PruneGraphDefInto. + auto prune_config = config; + for (const auto& entry : feed_remapping) { + auto ph = prune_config.add_fetch(); + *ph->mutable_id()->mutable_node_name() = entry.second; + ph->mutable_id()->set_output_index(0); + } TF_RETURN_IF_ERROR( - PruneGraphDefInto(config, first_copy_def, &second_copy_def)); + PruneGraphDefInto(prune_config, first_copy_def, &second_copy_def)); TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef( &second_copy_def, *g->op_registry(), /*node_offset=*/0)); diff --git a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc index fb89742b139..c1f60abc0d6 100644 --- a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc @@ -106,8 +106,9 @@ class DataFormatVecPermuteOp : public XlaOpKernel { errors::InvalidArgument( "Input must be a vector or matrix, but got shape ", input_tensor_shape.DebugString())); + const int dim0 = input_tensor_shape.dim_size(0); OP_REQUIRES( - ctx, input_tensor_shape.dim_size(0) == 4, + ctx, dim0 == 2 || dim0 == 4, errors::InvalidArgument( "First dimension of input must be of size 4, but got shape ", input_tensor_shape.DebugString())); @@ -118,10 +119,25 @@ class DataFormatVecPermuteOp : public XlaOpKernel { "Second dimension of 2D input must be of size 2, but got shape ", input_tensor_shape.DebugString())); } - int32 dst_indices[4]; - for (int i = 0; i < 4; ++i) { - for (int j = 0; j < 4; ++j) { - if (src_format_[i] == dst_format_[j]) { + + string src_format_str = src_format_; + string dst_format_str = dst_format_; + if (dim0 == 2) { + // If the input is a vector of size 2, treat the two elements as spatial + // dimensions. + auto keep_only_spatial_dimensions = [](string* format_str) -> void { + auto new_end = std::remove_if( + format_str->begin(), format_str->end(), + [](const char dim) { return dim != 'H' && dim != 'W'; }); + format_str->erase(new_end, format_str->end()); + }; + keep_only_spatial_dimensions(&src_format_str); + keep_only_spatial_dimensions(&dst_format_str); + } + std::vector<int32> dst_indices(dim0); + for (int i = 0; i < dim0; ++i) { + for (int j = 0; j < dim0; ++j) { + if (src_format_str[i] == dst_format_str[j]) { dst_indices[j] = i; break; } diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index a3fcb4d4b8f..bd6f58453df 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -81,9 +82,7 @@ class MatMulOp : public XlaOpKernel { b = xla::ConvertElementType(b, xla::F32); } } - auto lhs = (transpose_a_) ? xla::Transpose(a, {1, 0}) : a; - auto rhs = (transpose_b_) ? xla::Transpose(b, {1, 0}) : b; - ctx->SetOutput(0, xla::Dot(lhs, rhs)); + ctx->SetOutput(0, xla::BatchDot(a, transpose_a_, b, transpose_b_)); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 8431724f438..beb8e7aa174 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -36,10 +36,8 @@ limitations under the License. namespace tensorflow { namespace { -// TODO(phawkins): implement double-sized windowed reductions in XLA and remove -// the type constraint. -constexpr std::array<DataType, 4> kScanOpTypes = { - {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_INT32}}; +constexpr std::array<DataType, 5> kScanOpTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32}}; class ScanOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index fa5a96ca6bd..d01f094dc2e 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -431,6 +431,120 @@ class TensorListStackOp : public XlaOpKernel { REGISTER_XLA_OP(Name("TensorListStack"), TensorListStackOp); +class TensorListConcatOp : public XlaOpKernel { + public: + explicit TensorListConcatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp input = ctx->Input(0); + + // Check that the TensorList is initialized. + bool is_initialized; + OP_REQUIRES_OK(ctx, (IsTensorListInitialized(input, &is_initialized))); + OP_REQUIRES(ctx, is_initialized, + errors::InvalidArgument("TensorList is not initialized")); + + // Only non-nested TensorList is supported for now. + bool is_nested; + OP_REQUIRES_OK(ctx, IsNestedTensorList(input, &is_nested)); + OP_REQUIRES(ctx, !is_nested, + errors::Unimplemented("Only non-nested TensorList is supported " + "for TensorListConcat.")); + + xla::XlaOp buffer; + OP_REQUIRES_OK(ctx, GetTensorListBuffer(input, &buffer)); + + xla::XlaBuilder* b = input.builder(); + auto shape_or = b->GetShape(buffer); + OP_REQUIRES_OK(ctx, shape_or.status()); + xla::Shape element_shape = shape_or.ConsumeValueOrDie(); + std::vector<int64> element_dims = + xla::SpanToVector(element_shape.dimensions()); + OP_REQUIRES( + ctx, element_dims.size() > 1, + errors::Unimplemented("TensorList of scalars is not supported")); + int64 num_elements = element_dims[0]; + int64 tensor_lengths = element_dims[1]; + + std::vector<int64> new_dims = {num_elements * tensor_lengths}; + + for (int i = 2; i < element_dims.size(); i++) { + new_dims.push_back(element_dims[i]); + } + + xla::XlaOp out = xla::Reshape(buffer, new_dims); + ctx->SetOutput(0, out); + + // Second output is a tensor of lengths of returned tensors. + xla::XlaOp lengths = xla::ConstantR1(b, num_elements, tensor_lengths); + ctx->SetOutput(1, lengths); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TensorListConcatOp); +}; + +REGISTER_XLA_OP(Name("TensorListConcatV2"), TensorListConcatOp); + +class TensorListSplitOp : public XlaOpKernel { + public: + explicit TensorListSplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + // Only non-nested TensorList is supported for now. + OP_REQUIRES( + ctx, dtype_ != DT_VARIANT, + errors::Unimplemented( + "Only non-nested TensorList is supported for TensorListReserve.")); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaOp input_tensor = ctx->Input(0); + + xla::XlaBuilder* b = input_tensor.builder(); + auto shape_or = b->GetShape(input_tensor); + OP_REQUIRES_OK(ctx, shape_or.status()); + xla::Shape element_shape = shape_or.ConsumeValueOrDie(); + std::vector<int64> element_dims = + xla::SpanToVector(element_shape.dimensions()); + OP_REQUIRES( + ctx, !element_dims.empty(), + errors::Unimplemented("Element dimensions have to be non-empty")); + + std::vector<int64> lengths; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &lengths)); + OP_REQUIRES(ctx, !lengths.empty(), + errors::Unimplemented("Length has to be non-empty")); + int64 length = lengths[0]; + for (int64 len : lengths) { + OP_REQUIRES(ctx, len == length, + errors::Unimplemented("All lengths have to be the same")); + } + OP_REQUIRES( + ctx, element_dims[0] % length == 0, + errors::Unimplemented("Buffer size has to be a multiple of length")); + std::vector<int64> new_dims = {element_dims[0] / length, length}; + for (int i = 1; i < element_dims.size(); i++) { + new_dims.push_back(element_dims[i]); + } + + xla::XlaOp reshaped = xla::Reshape(input_tensor, new_dims); + + xla::XlaOp result; + OP_REQUIRES_OK(ctx, ExecuteTensorListFromTensor(length, reshaped, &result)); + ctx->SetTensorListOutput(0, result); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListSplitOp); +}; + +REGISTER_XLA_OP(Name("TensorListSplit") + .CompileTimeConstantInput("element_shape") + .CompileTimeConstantInput("lengths"), + TensorListSplitOp); + class TensorListFromTensorOp : public XlaOpKernel { public: explicit TensorListFromTensorOp(OpKernelConstruction* ctx) diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 24afe595b18..7ea69f734c9 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -99,5 +99,42 @@ TEST(ConvertGraphDefToXla, Sum) { ConvertGraphDefToXla(graph_def, config, client, &computation))); } +TEST(ConvertGraphDefToXla, SumWithUnusedArgument) { + GraphDef graph_def = SumGraph(); + tf2xla::Config config = SumConfig(); + NodeDef* unused = graph_def.add_node(); + unused->set_name("unused"); + unused->set_op("Placeholder"); + (*unused->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32); + config.add_feed()->mutable_id()->set_node_name("unused"); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + xla::XlaComputation computation; + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); + + // Set up arguments. + auto x_literal = xla::LiteralUtil::CreateR0<int32>(10); + auto y_literal = xla::LiteralUtil::CreateR0<int32>(32); + auto x_global_or = client->TransferToServer(x_literal); + auto y_global_or = client->TransferToServer(y_literal); + auto unused_global_or = client->TransferToServer(y_literal); + TF_EXPECT_OK(x_global_or.status()); + TF_EXPECT_OK(y_global_or.status()); + TF_EXPECT_OK(unused_global_or.status()); + std::unique_ptr<xla::GlobalData> x_global = + std::move(x_global_or.ValueOrDie()); + std::unique_ptr<xla::GlobalData> y_global = + std::move(y_global_or.ValueOrDie()); + std::unique_ptr<xla::GlobalData> unused_global = + std::move(unused_global_or.ValueOrDie()); + + // Execute and check result. + auto result_or = client->ExecuteAndTransfer( + computation, {x_global.get(), y_global.get(), unused_global.get()}); + TF_EXPECT_OK(result_or.status()); + xla::Literal result = std::move(result_or.ValueOrDie()); + EXPECT_EQ("(\ns32[] 42\n)", result.ToString()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 3d6083621f4..1cf3e10b774 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/variant.h" +#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" @@ -571,6 +572,10 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) { std::unique_ptr<Graph> graph(new Graph(options_.flib_def)); CopyGraph(*fbody->graph, graph.get()); + bool is_inside_mustcompile = false; + TryGetNodeAttr(AttrSlice(&fbody->fdef.attr()), kXlaMustCompileAttr, + &is_inside_mustcompile); + // Performs a first function inlining pass before shape inference, since // otherwise shape inference can't see inside functions and a comprehensive // shape_map, including function ops, is needed to constant-propagate Shape @@ -622,6 +627,8 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) { graph_optimizer_options.inline_multi_device_functions = true; graph_optimizer_options.inline_impl_selection_group_functions = true; graph_optimizer_options.inline_with_single_device_body_placer = true; + graph_optimizer_options.ignore_noinline = is_inside_mustcompile; + optimizer.Optimize(flib_runtime_, flib_runtime_->env(), /*device=*/nullptr, &graph, graph_optimizer_options); diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index 9b8156efe5b..cb79b2ef7db 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -236,6 +236,19 @@ XLA_TEST_F(MathTest, SqrtF32) { ComputeAndCompareR0<float>(&builder, 0.0f, {zero_data.get()}, error_spec_); } +XLA_TEST_F(MathTest, SqrtF64) { + XlaBuilder builder(TestName()); + Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F64); + + std::unique_ptr<GlobalData> zero_data = + client_->TransferToServer(zero_literal).ConsumeValueOrDie(); + + XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero"); + Sqrt(zero); + + ComputeAndCompareR0<double>(&builder, 0.0f, {zero_data.get()}, error_spec_); +} + #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64 XLA_TEST_F(MathTest, ErfInvF64) { XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index a4e5b936153..58365c0f498 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -3188,6 +3188,10 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs, broadcast_dimensions, direction); } +XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) { + return Compare(lhs, rhs, {}, direction); +} + XlaOp Dot(const XlaOp lhs, const XlaOp rhs, const PrecisionConfig* precision_config) { return lhs.builder()->Dot(lhs, rhs, precision_config); diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index b631514248c..426b6d83207 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -889,6 +889,7 @@ class XlaBuilder { friend XlaOp Compare(XlaOp lhs, XlaOp rhs, absl::Span<const int64> broadcast_dimensions, ComparisonDirection direction); + friend XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction); friend XlaOp Dot(XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config); friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, @@ -1498,10 +1499,12 @@ XlaOp Lt(XlaOp lhs, XlaOp rhs, XlaOp Le(XlaOp lhs, XlaOp rhs, absl::Span<const int64> broadcast_dimensions = {}); -// Enqueues a comparison instruction onto the computation. +// Enqueues a comparison instruction onto the computation (optionally without +// broadcast_dimensions for consistency with others). XlaOp Compare(XlaOp lhs, XlaOp rhs, absl::Span<const int64> broadcast_dimensions, ComparisonDirection direction); +XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction); // Enqueues a dot instruction onto the computation. XlaOp Dot(XlaOp lhs, XlaOp rhs, diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index e1733cd179c..4fa47077fca 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -381,7 +381,18 @@ TEST_F(XlaBuilderTest, Transpose) { EXPECT_THAT(root, op::Transpose(op::Parameter())); } -TEST_F(XlaBuilderTest, AllGather) { +TEST_F(XlaBuilderTest, AllGatherR1) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x"); + AllGather(x, /*all_gather_dimension=*/0, /*shard_count=*/4); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + + EXPECT_EQ(root->opcode(), HloOpcode::kAllGather); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {16}))); +} + +TEST_F(XlaBuilderTest, AllGatherR2) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x"); AllGather(x, /*all_gather_dimension=*/1, /*shard_count=*/4); diff --git a/tensorflow/compiler/xla/client/xla_computation.h b/tensorflow/compiler/xla/client/xla_computation.h index 3ccbfb28bd0..6a3b17a154a 100644 --- a/tensorflow/compiler/xla/client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_computation.h @@ -29,8 +29,8 @@ namespace xla { class XlaComputation { public: XlaComputation() : unique_id_(-1) {} - XlaComputation(const HloModuleProto& proto) - : unique_id_(proto.id()), proto_(proto) {} + XlaComputation(HloModuleProto proto) + : unique_id_(proto.id()), proto_(std::move(proto)) {} ~XlaComputation() {} diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 60a563ee956..4152982bf4c 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -55,9 +55,16 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { // b/77879207. opts.set_xla_gpu_disable_multi_streaming(true); - // TODO(jlebar): Disable fastmath once doing so is not a performance - // regression. + // Disable forms of fast math that have caused users problems in the past. opts.set_xla_cpu_enable_fast_math(true); + opts.set_xla_cpu_fast_math_honor_nans(true); + opts.set_xla_cpu_fast_math_honor_infs(true); + opts.set_xla_cpu_fast_math_honor_functions(true); + opts.set_xla_cpu_fast_math_honor_division(true); + + // By default, copy TF's Eigen style min_max behavior with nans. + opts.set_xla_cpu_enable_fast_min_max(false); + opts.set_xla_gpu_enable_fast_min_max(true); opts.set_xla_allow_excess_precision(true); @@ -261,6 +268,12 @@ static void AllocateFlags() { "When xla_cpu_enable_fast_math is true then this controls whether we " "forbid to approximate calculations for functions. Ignored when " "xla_cpu_enable_fast_math is false.")); + flag_objects->push_back(tensorflow::Flag( + "xla_cpu_enable_fast_min_max", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_min_max), + flag_values->xla_cpu_enable_fast_min_max(), + "Enable fast floating point min/max lowering that always propagates " + "NaNs.")); flag_objects->push_back(tensorflow::Flag( "xla_gpu_enable_fast_min_max", bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max), diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 495701eaac2..002d07184a7 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -2299,20 +2299,26 @@ The output is guaranteed to be a deterministic function of the initial state but it is *not* guaranteed to be deterministic between backends and different compiler versions. -<b>`RngBitGenerator(algorithm, key, shape)`</b> | Arguments | Type | Semantics | -|---------------- | ----------------- | ------------------------------------- | -| `algorithm` | `RandomAlgorithm` | PRNG algorithm to be used. | | -`initial_state` | `XlaOp` | Initial state for the PRNG algorithm. | | `shape` | -`Shape` | Output shape for generated data. | +<b>`RngBitGenerator(algorithm, key, shape)`</b> -Available values for `algorithm`: * `rng_default`: Backend specific algorithm -with backend specific shape requirements. * `rng_three_fry`: ThreeFry -counter-based PRNG algorithm. The `initial_state` shape is `u64[2]` with -arbitrary values. -[Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) -* `rng_philox`: Philox algorithm to generate random numbers in parallel. The -`initial_state` shape is `u64[3]` with arbitrary values. -[Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) +Arguments | Type | Semantics +--------------- | ----------------- | ------------------------------------- +`algorithm` | `RandomAlgorithm` | PRNG algorithm to be used. +`initial_state` | `XlaOp` | Initial state for the PRNG algorithm. +`shape` | `Shape` | Output shape for generated data. + +Available values for `algorithm`: + +- `rng_default`: Backend specific algorithm with backend specific shape + requirements. + +- `rng_three_fry`: ThreeFry counter-based PRNG algorithm. The `initial_state` + shape is `u64[2]` with arbitrary values. + [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + +- `rng_philox`: Philox algorithm to generate random numbers in parallel. The + `initial_state` shape is `u64[3]` with arbitrary values. + [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) ## Scatter diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index cbbad741ce3..73c37d6b2f3 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -2104,6 +2104,32 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, root_piece_->set_subshape(shape_.get()); } +MutableBorrowingLiteral::MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs, + const Shape& shape) + : MutableLiteralBase() { + shape_ = absl::make_unique<Shape>(shape); + if (!shape_->IsTuple()) { + CHECK_EQ(src_buf_ptrs.size(), 1); + root_piece_ = new Piece(); + root_piece_->set_buffer(const_cast<char*>(src_buf_ptrs[0])); + root_piece_->set_subshape(shape_.get()); + } else { + CHECK(!ShapeUtil::IsNestedTuple(*shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + + for (int i = 0; i < src_buf_ptrs.size(); ++i) { + Piece child_piece; + const auto& src_shape = shape_->tuple_shapes(i); + CHECK(src_shape.IsArray()); + child_piece.set_subshape(&src_shape); + child_piece.set_buffer(src_buf_ptrs[i]); + root_piece_->emplace_back(std::move(child_piece)); + } + } +} + MutableBorrowingLiteral::~MutableBorrowingLiteral() { if (root_piece_ != nullptr) { delete root_piece_; diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 1553d042e80..a2be92fbf5b 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -776,6 +776,10 @@ class MutableBorrowingLiteral : public MutableLiteralBase { const ShapeIndex& view_root); MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape); + // Create a literal from a list of buffers and a shape. + // Returns a tuple literal if `shape` is a tuple type. + MutableBorrowingLiteral(absl::Span<char*> src_buf_ptrs, const Shape& shape); + private: // Recursively copies the subtree from the `src_piece` at the given child // index to the `dest_piece`. For buffers only the pointers are copied, but diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 8c6bc84cf8e..10737489331 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -1,5 +1,6 @@ load("//tensorflow/core/platform:build_config.bzl", "pyx_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_py_test_deps") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "pybind_extension") @@ -186,6 +187,89 @@ cc_library( ], ) +cc_library( + name = "ops", + srcs = ["ops.cc"], + hdrs = ["ops.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":types", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:comparators", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:qr", + "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", + "//tensorflow/compiler/xla/client/lib:sorting", + "//tensorflow/compiler/xla/client/lib:svd", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@pybind11", + ], +) + +cc_library( + name = "outfeed_receiver", + srcs = ["outfeed_receiver.cc"], + hdrs = ["outfeed_receiver.h"], + deps = [ + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/core/profiler/lib:traceme", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + ], +) + +tf_cc_test( + name = "cpu_outfeed_receiver_test", + size = "small", + srcs = ["outfeed_receiver_test.cc"], + deps = [ + ":outfeed_receiver", + "//tensorflow/compiler/jit:xla_cpu_jit", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/pjrt:cpu_device", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "outfeed_receiver_py", + srcs = ["outfeed_receiver_py.cc"], + hdrs = ["outfeed_receiver_py.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":outfeed_receiver", + ":types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + "@pybind11", + ], +) + config_setting( name = "enable_gpu", values = {"define": "xla_python_enable_gpu=true"}, @@ -205,7 +289,9 @@ pybind_extension( deps = [ ":bfloat16", ":dlpack", + ":ops", ":python_ref_manager", + ":outfeed_receiver_py", ":types", "@com_google_absl//absl/base", "@com_google_absl//absl/hash", @@ -228,12 +314,6 @@ pybind_extension( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:comparators", - "//tensorflow/compiler/xla/client/lib:math", - "//tensorflow/compiler/xla/client/lib:qr", - "//tensorflow/compiler/xla/client/lib:self_adjoint_eig", - "//tensorflow/compiler/xla/client/lib:sorting", - "//tensorflow/compiler/xla/client/lib:svd", "//tensorflow/compiler/xla/pjrt:cpu_device", "//tensorflow/compiler/xla/pjrt:nvidia_gpu_device", "//tensorflow/compiler/xla/pjrt:pjrt_client", @@ -260,8 +340,8 @@ pybind_extension( "//tensorflow/core:lib_internal_impl", # buildcleaner: keep "//tensorflow/core/profiler/lib:profiler_backends", "//tensorflow/core/profiler/lib:profiler_session", - "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/rpc:profiler_server", + "//tensorflow/python/profiler/internal:traceme_wrapper", "//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:platform", ] + select({ diff --git a/tensorflow/compiler/xla/python/ops.cc b/tensorflow/compiler/xla/python/ops.cc new file mode 100644 index 00000000000..89891d39f78 --- /dev/null +++ b/tensorflow/compiler/xla/python/ops.cc @@ -0,0 +1,356 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/ops.h" + +#include <string> +#include <vector> + +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "pybind11/attr.h" +#include "pybind11/pybind11.h" +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/qr.h" +#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/lib/svd.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/python/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +namespace py = pybind11; + +void BuildOpsSubmodule(py::module* m) { + // ops submodule, containing free functions that add operators to an + // XlaBuilder. + py::module ops = m->def_submodule("ops", "XLA operations"); + + py::enum_<TriangularSolveOptions::Transpose>( + ops, "TriangularSolveOptions_Transpose") + .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID) + .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE) + .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE) + .value("ADJOINT", TriangularSolveOptions::ADJOINT); + + ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens")); + ops.def( + "AllReduce", + static_cast<XlaOp (*)( + XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>, + const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>( + &AllReduce), + py::arg("operand"), py::arg("computation"), + py::arg("replica_groups") = py::list(), + py::arg("channel_id") = absl::nullopt, + py::arg("shape_with_layout") = absl::nullopt); + ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"), + py::arg("concat_dimension"), py::arg("split_count"), + py::arg("replica_groups") = py::list(), + py::arg("layout") = absl::nullopt); + ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"), + py::arg("source_target_pairs")); + ops.def("CreateToken", &CreateToken, py::arg("builder")); + ops.def("CrossReplicaSum", + static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>( + &CrossReplicaSum), + py::arg("operand"), py::arg("replica_groups") = py::list()); + ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"), + py::arg("new_element_type")); + ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes")); + ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"), + py::arg("shape"), py::arg("broadcast_dimensions")); + ops.def("Call", &Call, py::arg("builder"), py::arg("computation"), + py::arg("operands")); + ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true); + ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max")); + ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions")); + ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"), + py::arg("dimension")); + ops.def("Conditional", + static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>, + absl::Span<const XlaOp>)>(&Conditional), + py::arg("branch_index"), py::arg("branch_computations"), + py::arg("branch_operands")); + ops.def("Conditional", + static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp, + const XlaComputation&)>(&Conditional), + py::arg("predicate"), py::arg("true_operand"), + py::arg("true_computation"), py::arg("false_operand"), + py::arg("false_computation")); + ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal")); + ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"), + py::arg("literal")); + ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"), + py::arg("rhs"), py::arg("window_strides"), py::arg("padding"), + py::arg("lhs_dilation"), py::arg("rhs_dilation"), + py::arg("dimension_numbers"), py::arg("feature_group_count") = 1, + py::arg("batch_group_count") = 1, + py::arg("precision_config") = nullptr); + ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"), + py::arg("new_element_type")); + ops.def( + "CustomCall", + [](XlaBuilder* builder, const py::bytes& call_target_name, + absl::Span<const XlaOp> operands, const Shape& shape, + const py::bytes& opaque) -> XlaOp { + return CustomCall(builder, call_target_name, operands, shape, opaque); + }, + py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), + py::arg("shape"), py::arg("opaque") = py::bytes("")); + ops.def( + "CustomCallWithLayout", + [](XlaBuilder* builder, const py::bytes& call_target_name, + absl::Span<const XlaOp> operands, const Shape& shape_with_layout, + absl::Span<const Shape> operand_shapes_with_layout, + const py::bytes& opaque) -> XlaOp { + return CustomCallWithLayout(builder, call_target_name, operands, + shape_with_layout, + operand_shapes_with_layout, opaque); + }, + py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), + py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"), + py::arg("opaque") = py::bytes("")); + ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"), + py::arg("precision_config") = nullptr); + ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"), + py::arg("dimension_numbers"), py::arg("precision_config") = nullptr); + ops.def("DynamicSlice", + static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>, + absl::Span<const int64>)>(&DynamicSlice), + py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes")); + ops.def("DynamicUpdateSlice", + static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>( + &DynamicUpdateSlice), + py::arg("operand"), py::arg("update"), py::arg("start_indices")); + + ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"), + py::arg("fft_length")); + + ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"), + py::arg("dimension_numbers"), py::arg("slice_sizes"), + py::arg("indices_are_sorted") = false); + ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"), + py::arg("index")); + ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"), + py::arg("shape"), py::arg("config") = ""); + ops.def("Iota", + static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64)>(&Iota), + py::arg("builder"), py::arg("shape"), py::arg("iota_dimension")); + ops.def("Iota", + static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota), + py::arg("builder"), py::arg("type"), py::arg("size")); + ops.def("Map", &Map, py::arg("builder"), py::arg("operands"), + py::arg("computation"), py::arg("dimensions"), + py::arg("static_operands") = py::list()); + ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to")); + ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"), + py::arg("token"), py::arg("shape_with_layout"), + py::arg("outfeed_config") = ""); + ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"), + py::arg("padding_config")); + ops.def("Parameter", + static_cast<XlaOp (*)(XlaBuilder*, int64, const Shape&, + const std::string&, const std::vector<bool>&)>( + &Parameter), + py::arg("builder"), py::arg("parameter_number"), py::arg("shape"), + py::arg("name") = "", + py::arg("replicated_at_leaf_buffers") = std::vector<bool>()); + ops.def( + "QR", + [](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> { + TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices)); + return std::make_pair(qr.q, qr.r); + }, + py::arg("operand"), py::arg("full_matrices")); + ops.def( + "Eigh", + [](XlaOp a, bool lower, int64 max_iter, + float epsilon) -> std::pair<XlaOp, XlaOp> { + auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon); + return std::make_pair(eigh.v, eigh.w); + }, + py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100, + py::arg("epsilon") = 1e-6); + ops.def( + "SVD", + [](XlaOp a, int64 max_iter, + float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> { + auto svd = SVD(a, max_iter, epsilon); + return std::make_tuple(svd.u, svd.d, svd.v); + }, + py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6); + ops.def("Reduce", + static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>, + absl::Span<const XlaOp>, const XlaComputation&, + absl::Span<const int64>)>(&Reduce), + py::arg("builder"), py::arg("operands"), py::arg("init_values"), + py::arg("computation"), py::arg("dimensions_to_reduce")); + ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"), + py::arg("exponent_bits"), py::arg("mantissa_bits")); + ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding, + py::arg("operand"), py::arg("init_value"), py::arg("computation"), + py::arg("window_dimensions"), py::arg("window_strides"), + py::arg("base_dilations"), py::arg("window_dilations"), + py::arg("padding")); + ops.def("ReplicaId", &ReplicaId, py::arg("builder")); + ops.def("Reshape", + static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>, + absl::Span<const int64>)>(&Reshape), + py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes")); + ops.def("Reshape", + static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>)>(&Reshape), + py::arg("operand"), py::arg("new_sizes")); + ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions")); + ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"), + py::arg("shape")); + ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"), + py::arg("shape")); + ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"), + py::arg("updates"), py::arg("update_computation"), + py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false, + py::arg("unique_indices") = false); + ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"), + py::arg("on_false")); + ops.def("SelectAndScatterWithGeneralPadding", + &SelectAndScatterWithGeneralPadding, py::arg("operand"), + py::arg("select"), py::arg("window_dimensions"), + py::arg("window_strides"), py::arg("padding"), py::arg("source"), + py::arg("init_value"), py::arg("scatter")); + ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"), + py::arg("limit_indices"), py::arg("strides")); + ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"), + py::arg("limit_index"), py::arg("stride"), py::arg("dimno")); + ops.def( + "Sort", + [](XlaBuilder* builder, absl::Span<const XlaOp> operands, + absl::optional<const XlaComputation*> comparator, int64 dimension, + bool is_stable) -> XlaOp { + return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + std::vector<PrimitiveType> operand_types; + for (const auto& operand : operands) { + TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand)); + operand_types.push_back(operand_shape.element_type()); + } + + if (comparator) { + return Sort(operands, **comparator, dimension, is_stable); + } else { + return Sort(operands, + CreateScalarLtComputation(operand_types, builder), + dimension, is_stable); + } + }); + }, + py::arg("builder"), py::arg("operands"), + py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1, + py::arg("is_stable") = false); + ops.def("TopK", &TopK, py::arg("input"), py::arg("k")); + ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation")); + ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"), + py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"), + py::arg("transpose_a")); + ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements")); + ops.def("While", &While, py::arg("condition"), py::arg("body"), + py::arg("init")); + + ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x")); + ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x")); + ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x")); + ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x")); + ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"), + py::arg("b"), py::arg("x")); + +#define BINARY_OP(op) \ + ops.def( \ + #op, \ + [](XlaOp a, XlaOp b, absl::optional<std::vector<int64>> dims) { \ + return dims ? op(a, b, *dims) : op(a, b); \ + }, \ + py::arg("lhs"), py::arg("rhs"), \ + py::arg("broadcast_dimensions") = absl::nullopt) + BINARY_OP(Eq); + BINARY_OP(Ne); + BINARY_OP(Ge); + BINARY_OP(Gt); + BINARY_OP(Lt); + BINARY_OP(Le); + BINARY_OP(Add); + BINARY_OP(Sub); + BINARY_OP(Mul); + BINARY_OP(Div); + BINARY_OP(Rem); + BINARY_OP(Max); + BINARY_OP(Min); + BINARY_OP(And); + BINARY_OP(Or); + BINARY_OP(Xor); + BINARY_OP(ShiftLeft); + BINARY_OP(ShiftRightArithmetic); + BINARY_OP(ShiftRightLogical); + BINARY_OP(Atan2); + BINARY_OP(Pow); + BINARY_OP(Complex); +#undef BINARY_OP + +#define UNARY_OP(op) ops.def(#op, &op) + UNARY_OP(Not); + UNARY_OP(PopulationCount); + UNARY_OP(Clz); + UNARY_OP(Abs); + UNARY_OP(Exp); + UNARY_OP(Expm1); + UNARY_OP(Floor); + UNARY_OP(Ceil); + UNARY_OP(Round); + UNARY_OP(Log); + UNARY_OP(Log1p); + UNARY_OP(Sign); + UNARY_OP(Cos); + UNARY_OP(Sin); + UNARY_OP(Tanh); + UNARY_OP(IsFinite); + UNARY_OP(Neg); + UNARY_OP(Sqrt); + UNARY_OP(Rsqrt); + UNARY_OP(Square); + UNARY_OP(Reciprocal); + UNARY_OP(Erfc); + UNARY_OP(Erf); + UNARY_OP(ErfInv); + UNARY_OP(Lgamma); + UNARY_OP(Digamma); + UNARY_OP(BesselI0e); + UNARY_OP(BesselI1e); + UNARY_OP(Acos); + UNARY_OP(Asin); + UNARY_OP(Atan); + UNARY_OP(Tan); + UNARY_OP(Acosh); + UNARY_OP(Asinh); + UNARY_OP(Atanh); + UNARY_OP(Cosh); + UNARY_OP(Sinh); + UNARY_OP(Real); + UNARY_OP(Imag); + UNARY_OP(Conj); +#undef UNARY_OP +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/ops.h b/tensorflow/compiler/xla/python/ops.h new file mode 100644 index 00000000000..7fe34e941ba --- /dev/null +++ b/tensorflow/compiler/xla/python/ops.h @@ -0,0 +1,27 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_ + +#include "pybind11/pybind11.h" + +namespace xla { + +void BuildOpsSubmodule(pybind11::module* m); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_OPS_H_ diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.cc b/tensorflow/compiler/xla/python/outfeed_receiver.cc new file mode 100644 index 00000000000..0be4167c397 --- /dev/null +++ b/tensorflow/compiler/xla/python/outfeed_receiver.cc @@ -0,0 +1,492 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/outfeed_receiver.h" + +#include <sys/types.h> + +#include <memory> +#include <sstream> + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/profiler/lib/traceme.h" + +// Implementation notes: +// +// Startup: +// ------- +// +// The startup is initiated by a call from Python to StartOutfeedReceiver, +// which starts N threads for listening to the N devices and for enqueueing +// the received data into a callback queue. There is one additional callback +// thread for dequeing the data and invoking the Python callback. +// +// Framing protocol +// ---------------- +// +// The outfeed mechanism has a single channel and the receiver must know +// exactly the shape and number of outfeed operations issued by the compiled +// code. This makes it hard to use outfeed in conditionals and loops and +// especially when outfeeding different-shaped data. +// +// To address this, when we compile the code we capture the shape of the +// data being outfed, and we generate a consumer ID (uint32_t) that is unique +// across the lifetime of the program to: the Python callable to callback to, +// the shape of the arguments, the keyword arguments to pass to the callable. +// Each outfeed payload is preceeded by a header (of shape u32[2]) with a +// special first value and the consumer ID. We maintain a registry of shapes +// by consumer ID. When receiving we lookup the shape by consumer ID, and then +// we read the payload. +// +// Back pressure: +// -------------- +// +// We maintain a sum of the bytes from all the data waiting in the callback +// queue. The listening threads will wait for the sum to drop below a +// configurable threshold, default 256Mb. While the listening thread is waiting, +// on CPU and GPU the next outfeed operation from the device will block. On +// TPU there is a buffer, but eventually the TPU will also block. +// +// Shutdown: +// --------- +// +// The shutdown is initiated automatically when the last reference to the +// outfeed receiver object is dropped, and the Python garbage collector invokes +// the destructor. +// +// The shutdown sequence is implemented as follows: +// * we enqueue on all devices a computation that outfeeds a special header +// with customer ID kOutfeedCidShutdown. +// * when each listening threads gets the shutdown header, it decrements +// a counter of listening threads, and if the counter reaches 0, it +// enqueues a special shutdown callback. +// * when the callback thread gets the shutdown callback marker, it terminates. +// * the shutdown code waits until all threads terminate. +// +// Since we currently keep the shape registry in the OutfeedReceiver, it is +// not safe to replace the OutfeedReceiver instance during the lifetime of +// the JAX program, or else previously cached jitted computations may refer +// to previously cached shapes. This can be solved, but for now we disallow +// replacing the OutfeedReceiver, and do not provide a Shutdown API to the +// Python program. + +namespace xla { + +// The header contains: +// 0. kOutfeedHeaderStart +// 1. consumer id +int constexpr kOutfeedHeaderWords = 2; +uint32_t constexpr kOutfeedHeaderStart = 271828; +// Special consumer IDs, without outfeed payload. +uint32_t constexpr kOutfeedCidShutdown = 0; + +// A Device and its PjRtClient. +struct DeviceWithClient { + Device* device; + std::shared_ptr<PjRtClient> client; +}; + +// Encapsulates data received from a device outfeed. +class OutfeedData { + public: + OutfeedData(DeviceWithClient device_client, uint32_t consumer_id, Shape shape) + : device_client_(device_client), + consumer_id_(consumer_id), + shape_(shape), + literal_(nullptr), + literal_size_bytes_(0) {} + + DeviceWithClient device_client() { return device_client_; } + uint32_t consumer_id() const { return consumer_id_; } + Shape shape() const { return shape_; } + std::unique_ptr<Literal> literal() { + CHECK(literal_); + return std::move(literal_); + } + + void SetLiteral(std::unique_ptr<Literal> literal); + + ssize_t literal_size_bytes() const { return literal_size_bytes_; } + + std::string DebugString() const; + + private: + DeviceWithClient device_client_; + uint32_t consumer_id_; + Shape shape_; + std::unique_ptr<Literal> literal_; + ssize_t literal_size_bytes_; +}; + +void OutfeedData::SetLiteral(std::unique_ptr<Literal> literal) { + literal_ = std::move(literal); + shape_ = literal_->shape(); + int total_size_bytes = 0; + ShapeUtil::ForEachSubshape( + shape_, [&](const Shape& literal_subshape, const ShapeIndex& index) { + if (!literal_subshape.IsTuple()) { + total_size_bytes += ShapeUtil::ByteSizeOf(literal_subshape, 8); + } + }); + literal_size_bytes_ = total_size_bytes; +} + +std::string OutfeedData::DebugString() const { + return absl::StrFormat("dev=%s; cons=%d; shape=%s", + device_client_.device->DebugString(), consumer_id_, + shape_.ToString()); +} + +class OutfeedReceiverImpl { + public: + OutfeedReceiverImpl(OutfeedReceiver::Callback callback, + std::vector<std::shared_ptr<PjRtClient>> clients, + ssize_t max_callback_queue_size_bytes); + + OutfeedReceiverImpl(const OutfeedReceiverImpl&) = delete; + OutfeedReceiverImpl& operator=(const OutfeedReceiverImpl&) = delete; + + // Blocks until all data has been received from devices and all data + // in the queue has been passed to Python. + ~OutfeedReceiverImpl(); + + void Start(); + + StatusOr<XlaOp> AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token, + uint32_t consumer_id, + std::vector<XlaOp> arrays); + + private: + bool CallbackQueueNotEmpty() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return !callback_queue_.empty(); + } + + bool CallbackQueueHasSpace() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return callback_queue_size_bytes_ < max_callback_queue_size_bytes_; + } + + bool ShutdownDone() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return (num_working_callback_threads_ == 0 && num_listening_threads_ == 0); + } + + void CallbackThreadLoop(); + void DeviceListenerThreadLoop(int device_idx); + + // Enqueues to a device an outfeed operation with a shutdown consumer ID. + Status SendShutdownOutfeedHeader(int device_idx); + + // Receives a raw Literal from a device outfeed. + StatusOr<std::unique_ptr<Literal>> ReceiveRawFromOutfeed(const Device* device, + const Shape& shape); + + // Enqueues received data in the callbaback queue. + void EnqueueReceivedData(std::unique_ptr<OutfeedData> received) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Shuts down the threads. See implementation notes at top of file. + // It is not safe to restart an OutfeedReceiver after shutting down one. + void Shutdown(); + + OutfeedReceiver::Callback callback_; + // The devices on which we are listening, with their clients. + std::vector<DeviceWithClient> devices_; + // Maximum bytes capacity of the callback queue. + uint64_t max_callback_queue_size_bytes_; + + absl::Mutex mu_; + // Registered shapes by consumer id. + // The shape registry must be alive as long as the program exists. + // Right now we tell the user to never restart after Shutdown. + absl::flat_hash_map<uint32_t, Shape> shape_registry_ TF_GUARDED_BY(mu_); + // How many bytes of Literal are in the callback queue. + uint64_t callback_queue_size_bytes_ TF_GUARDED_BY(mu_); + // Threads listening. + int num_listening_threads_ TF_GUARDED_BY(mu_); + bool shutdown_started_ TF_GUARDED_BY(mu_); + + // How many callback threads are still working. Used for shutdown. + int num_working_callback_threads_ TF_GUARDED_BY(mu_); + + std::queue<std::unique_ptr<OutfeedData>> callback_queue_ TF_GUARDED_BY(mu_); + // The threadpool must come last to ensure the queue exists + // when the pool destructor is called. + std::unique_ptr<tensorflow::thread::ThreadPool> threads_; +}; + +OutfeedReceiverImpl::OutfeedReceiverImpl( + OutfeedReceiver::Callback callback, + std::vector<std::shared_ptr<PjRtClient>> clients, + ssize_t max_callback_queue_size_bytes) { + callback_ = callback; + max_callback_queue_size_bytes_ = max_callback_queue_size_bytes; + for (const auto& client : clients) { + for (const auto& device : client->devices()) { + devices_.push_back(DeviceWithClient{device.get(), client}); + } + } + CHECK_GT(devices_.size(), 0); + + callback_queue_size_bytes_ = 0; + num_listening_threads_ = 0; + num_working_callback_threads_ = 0; + shutdown_started_ = false; +} + +void OutfeedReceiverImpl::Start() { + { + absl::MutexLock lock(&mu_); + CHECK(!shutdown_started_); + } + int num_threads = 1 + devices_.size(); + threads_ = absl::make_unique<tensorflow::thread::ThreadPool>( + tensorflow::Env::Default(), "outfeed_receiver", num_threads); + threads_->Schedule([this]() { CallbackThreadLoop(); }); + for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) { + threads_->Schedule( + [this, device_idx]() { DeviceListenerThreadLoop(device_idx); }); + } +} + +void OutfeedReceiverImpl::Shutdown() { + VLOG(2) << "Shutdown start"; + { + absl::MutexLock lock(&mu_); + CHECK(!shutdown_started_); + shutdown_started_ = true; + } + for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) { + CHECK(SendShutdownOutfeedHeader(device_idx).ok()); + } + VLOG(2) << "Shutdown waiting for listening and callback threads to stop"; + absl::MutexLock lock(&mu_); + mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::ShutdownDone)); + VLOG(2) << "Shutdown done"; +} + +OutfeedReceiverImpl::~OutfeedReceiverImpl() { + VLOG(2) << "~OutfeedReceiverImpl"; + Shutdown(); +} + +void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) { + { + absl::MutexLock lock(&mu_); + ++num_listening_threads_; + } + DeviceWithClient device_client = devices_[device_idx]; + while (true) { + Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords}); + std::unique_ptr<Literal> header = + ReceiveRawFromOutfeed(device_client.device, header_shape).ValueOrDie(); + absl::Span<uint32_t> header_data = header->data<uint32>(); + CHECK_EQ(header_data.size(), kOutfeedHeaderWords); + CHECK_EQ(header_data[0], kOutfeedHeaderStart); + uint32_t consumer_id = header_data[1]; + Shape shape; + { + absl::MutexLock lock(&mu_); + auto registered_shape = shape_registry_.find(consumer_id); + if (registered_shape == shape_registry_.end()) { + LOG(FATAL) + << "[" << device_client.device->DebugString() + << "] Cannot find registered shape for consumer ID " << consumer_id + << ". Perhaps the code was compiled with a different instance " + << "of OutfeedReceiver."; + } + shape = registered_shape->second; + } + auto received = + absl::make_unique<OutfeedData>(device_client, consumer_id, shape); + VLOG(2) << "Listener received header " << received->DebugString(); + if (consumer_id == kOutfeedCidShutdown) { + VLOG(2) << "[" << device_client.device->DebugString() + << "] Listener received shutdown header"; + absl::MutexLock lock(&mu_); + --num_listening_threads_; + if (num_listening_threads_ == 0) { + VLOG(2) << "Last listener shutdown; enqueue shutdown callback"; + EnqueueReceivedData(std::move(received)); + } + return; + } + std::unique_ptr<Literal> data = + ReceiveRawFromOutfeed(device_client.device, shape).ValueOrDie(); + received->SetLiteral(std::move(data)); + absl::MutexLock lock(&mu_); + EnqueueReceivedData(std::move(received)); + } +} + +void OutfeedReceiverImpl::EnqueueReceivedData( + std::unique_ptr<OutfeedData> received) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::CallbackQueueHasSpace)); + ssize_t literal_size_bytes = received->literal_size_bytes(); + callback_queue_size_bytes_ += literal_size_bytes; + VLOG(2) << "Listener enqueues data " << received->DebugString() << " of size " + << literal_size_bytes << " bytes; " << (1 + callback_queue_.size()) + << " callbacks in queue of total size " << callback_queue_size_bytes_ + << " bytes.\n"; + callback_queue_.push(std::move(received)); +} + +StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed( + const Device* device, const Shape& shape) { + std::shared_ptr<Literal> literal_shared; + + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, + device->GetLocalDeviceState()); + TF_ASSIGN_OR_RETURN(Literal literal, + local_device->client()->TransferFromOutfeedLocal( + shape, local_device->device_ordinal())); + + return absl::make_unique<Literal>(std::move(literal)); +} + +void OutfeedReceiverImpl::CallbackThreadLoop() { + { + absl::MutexLock lock(&mu_); + num_working_callback_threads_++; + CHECK_EQ(num_working_callback_threads_, 1); + } + while (true) { + std::unique_ptr<OutfeedData> received; + { + absl::MutexLock lock(&mu_); + mu_.Await( + absl::Condition(this, &OutfeedReceiverImpl::CallbackQueueNotEmpty)); + received = std::move(callback_queue_.front()); + callback_queue_.pop(); + callback_queue_size_bytes_ -= received->literal_size_bytes(); + VLOG(2) << "Dequeued callback for " << received->DebugString() << "; " + << callback_queue_.size() << " callbacks in queue of total size " + << callback_queue_size_bytes_ << " bytes.\n"; + } + if (received->consumer_id() == kOutfeedCidShutdown) { + VLOG(2) << "Callback loop received shutdown signal"; + { + absl::MutexLock lock(&mu_); + CHECK(callback_queue_.empty()); + CHECK_EQ(callback_queue_size_bytes_, 0); + --num_working_callback_threads_; + } + VLOG(2) << "Callback loop done"; + return; + } + { + tensorflow::profiler::TraceMe traceme("OutfeedReceiver::Callback"); + DeviceWithClient device_client = received->device_client(); + callback_(device_client.device, std::move(device_client.client), + received->consumer_id(), received->literal()); + } + } +} + +Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) { + const Device* device = devices_[device_idx].device; + constexpr int consumer_id = kOutfeedCidShutdown; + VLOG(2) << "[" << device->DebugString() + << "] SendSpecialHeader cons=" << consumer_id; + XlaBuilder builder( + absl::StrFormat("special_outfeed_header_%d_%d", consumer_id, device_idx)); + XlaOp send = + AddOutfeedToBuilder(&builder, CreateToken(&builder), consumer_id, {}) + .ValueOrDie(); + XlaComputation computation = builder.Build(send).ValueOrDie(); + + CompileOptions compile_options; + compile_options.executable_build_options.set_num_replicas(1); + compile_options.executable_build_options.set_num_partitions(1); + DeviceAssignment device_assignment(1, 1); + device_assignment(0, 0) = device->id(); + compile_options.executable_build_options.set_device_assignment( + device_assignment); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr<PjRtExecutable> executable, + PjRtExecutable::Compile(computation, devices_[device_idx].client.get(), + std::move(compile_options))); + ExecuteOptions execute_options; + TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers, + executable->Execute({}, execute_options)); + return Status::OK(); +} + +StatusOr<XlaOp> OutfeedReceiverImpl::AddOutfeedToBuilder( + XlaBuilder* builder, XlaOp token, uint32_t consumer_id, + std::vector<XlaOp> arrays) { + XlaOp data = Tuple(builder, std::move(arrays)); + Shape shape_with_layout = builder->GetShape(data).ValueOrDie(); + ShapeUtil::ForEachMutableSubshape( + &shape_with_layout, [](Shape* subshape, const ShapeIndex&) { + if (!subshape->has_layout()) { + LayoutUtil::SetToDefaultLayout(subshape); + } + }); + VLOG(2) << "RegisterShape cons=" << consumer_id + << "; shape=" << shape_with_layout.ToString(); + { + absl::MutexLock lock(&mu_); + auto found = shape_registry_.find(consumer_id); + if (found != shape_registry_.end()) { + if (!ShapeUtil::Equal(shape_with_layout, found->second)) { + return InvalidArgument( + "Shape %s does not match previous shape %s used " + "for consumer id %d", + shape_with_layout.DebugString(), found->second.DebugString(), + consumer_id); + } + } else { + shape_registry_.insert({consumer_id, shape_with_layout}); + } + } + + std::vector<uint32_t> header{kOutfeedHeaderStart, consumer_id}; + XlaOp header_op = ConstantR1<uint32_t>(builder, header); + token = OutfeedWithToken( + header_op, token, ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords}), ""); + if (consumer_id != kOutfeedCidShutdown) { + token = OutfeedWithToken(data, token, shape_with_layout, ""); + } + return token; +} + +OutfeedReceiver::OutfeedReceiver( + Callback callback, std::vector<std::shared_ptr<PjRtClient>> clients, + ssize_t max_callback_queue_size_bytes) { + p_impl_ = absl::make_unique<OutfeedReceiverImpl>( + callback, std::move(clients), max_callback_queue_size_bytes); +} + +OutfeedReceiver::~OutfeedReceiver() {} + +void OutfeedReceiver::Start() { p_impl_->Start(); } + +StatusOr<XlaOp> OutfeedReceiver::AddOutfeedToBuilder( + XlaBuilder* builder, XlaOp token, uint32_t consumer_id, + std::vector<XlaOp> arrays) { + if (consumer_id == kOutfeedCidShutdown) { + return InvalidArgument("Consumer ID cannot be a reserved value: %d", + consumer_id); + } + return p_impl_->AddOutfeedToBuilder(builder, token, consumer_id, arrays); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.h b/tensorflow/compiler/xla/python/outfeed_receiver.h new file mode 100644 index 00000000000..a0fdfcd36f0 --- /dev/null +++ b/tensorflow/compiler/xla/python/outfeed_receiver.h @@ -0,0 +1,77 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_H_ + +#include <memory> + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +class OutfeedReceiverImpl; + +// Implements a multithreaded receiver of outfeeds from devices. +class OutfeedReceiver { + public: + // A callback takes: device, client (for the device), consumer id, received. + // The client pointer should be alive while the device is used. + using Callback = std::function<void(Device*, std::shared_ptr<PjRtClient>, + uint32_t, std::shared_ptr<Literal>)>; + + // Constructs the receiver for the given clients and callback function. + // + // Args: + // callback: a function to be called when an outfeed is ready for + // processing. + // clients: the clients for whose devices to listen. + // max_callback_queue_size_bytes: the maximum number of bytes for all + // received outfeeds queued to be processed. When this limit is reached + // we pause receiving outfeeds from devices. + OutfeedReceiver(Callback callback, + std::vector<std::shared_ptr<PjRtClient>> clients, + ssize_t max_callback_queue_size_bytes); + + OutfeedReceiver(const OutfeedReceiver&) = delete; + OutfeedReceiver& operator=(const OutfeedReceiver&) = delete; + + // Blocks until all data has been received from devices and all data + // in the queue has been passed to Python. + ~OutfeedReceiver(); + + // Starts the listener threads and the callback thread. + void Start(); + + // Adds to the computation builder the outfeed of the arrays. + // Has the side-effect of registering the sent shape for the consumer_id. + // Returns error status if the outfeed shape is different than the + // previously used shape for the same consumer_id or the consumer id is + // invalid. + StatusOr<XlaOp> AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token, + uint32_t consumer_id, + std::vector<XlaOp> arrays); + + private: + std::unique_ptr<OutfeedReceiverImpl> p_impl_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_H_ diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_py.cc b/tensorflow/compiler/xla/python/outfeed_receiver_py.cc new file mode 100644 index 00000000000..a6256cfe86c --- /dev/null +++ b/tensorflow/compiler/xla/python/outfeed_receiver_py.cc @@ -0,0 +1,156 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/outfeed_receiver_py.h" + +#include <memory> + +#include "absl/memory/memory.h" +#include "absl/synchronization/mutex.h" +#include "pybind11/functional.h" +#include "pybind11/pybind11.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/python/outfeed_receiver.h" +#include "tensorflow/compiler/xla/python/types.h" + +namespace xla { + +namespace py = pybind11; + +namespace { + +// A wrapper for OutfeedReceiver for use from Python, useful for ensuring +// that the GIL is released before destroying the OutfeedReceiver. +class OutfeedReceiverForPython { + public: + // A callback to Python takes: consumer id, received literal. + using CallbackToPython = + std::function<void(ClientAndPtr<Device>, uint32_t, pybind11::object)>; + + OutfeedReceiverForPython(CallbackToPython callback_python, + std::vector<std::shared_ptr<PjRtClient>> clients, + ssize_t max_callback_queue_size_bytes) { + callback_python_ = callback_python; + outfeed_receiver_shutting_down_ = false; + OutfeedReceiver::Callback callback = + [this](Device* device, std::shared_ptr<PjRtClient> client, + uint32_t consumer_id, std::shared_ptr<Literal> literal) { + this->Callback(device, client, consumer_id, literal); + }; + outfeed_receiver_ = absl::make_unique<OutfeedReceiver>( + callback, std::move(clients), max_callback_queue_size_bytes); + } + OutfeedReceiverForPython(const OutfeedReceiverForPython&) = delete; + OutfeedReceiverForPython& operator=(const OutfeedReceiverForPython&) = delete; + + ~OutfeedReceiverForPython() { + // This destructor is called from the Python GC. Release it for the duration + // of the destruction, including the destruction of the OutfeedReceiver, + // when we may actually have to wait for threads to end. During this time + // we do not callback to Python (sometimes we get an exception + // "std::runtime_error: scoped_acquire::dec_ref(): thread state must + // be current!""). + { + absl::MutexLock lock(&mu_); + outfeed_receiver_shutting_down_ = true; + } + py::gil_scoped_release gil_release; + outfeed_receiver_ = nullptr; // Shutdown the outfeed receiver. + } + + void Start() { outfeed_receiver_->Start(); } + + StatusOr<XlaOp> AddOutfeed(XlaBuilder* builder, XlaOp token, + uint32_t consumer_id, std::vector<XlaOp> arrays) { + return outfeed_receiver_->AddOutfeedToBuilder(builder, token, consumer_id, + arrays); + } + + void Callback(Device* device, std::shared_ptr<PjRtClient> client, + uint32_t consumer_id, std::shared_ptr<Literal> literal) { + { + absl::MutexLock lock(&mu_); + if (outfeed_receiver_shutting_down_) { + VLOG(2) << "Ignoring unsafe callback to Python during shutdown"; + return; + } + } + py::gil_scoped_acquire gil_acquire; // Need GIL also for LiteralToPython + py::object literal_python = + LiteralToPython(std::move(literal)).ValueOrDie(); + // The callback_ should handle all exceptions in user-code. If we get + // an exception here, it is a bug in the callback and we should stop. + callback_python_(WrapWithClient<Device>(std::move(client), device), + consumer_id, std::move(literal_python)); + } + + private: + CallbackToPython callback_python_; + absl::Mutex mu_; + bool outfeed_receiver_shutting_down_ TF_GUARDED_BY(mu_); + std::unique_ptr<OutfeedReceiver> outfeed_receiver_; +}; + +} // namespace + +void BuildOutfeedReceiverSubmodule(py::module* m) { + py::module outfeed_receiver = + m->def_submodule("outfeed_receiver", "Outfeed receiver"); + outfeed_receiver.def( + "start", + [](OutfeedReceiverForPython::CallbackToPython callback_to_python, + std::vector<std::shared_ptr<PjRtClient>> clients, + ssize_t max_callback_queue_size_bytes) + -> std::unique_ptr<OutfeedReceiverForPython> { + auto server = absl::make_unique<OutfeedReceiverForPython>( + callback_to_python, clients, max_callback_queue_size_bytes); + server->Start(); + return server; + }, + py::arg("callback_to_python"), py::arg("backends"), + py::arg("max_queue_size_bytes") = 256 * 1024 * 1024, + R"(Starts a multithreaded outfeed receiver. + + There is one thread for each of the specified devices. When Python + drops the last reference to the returned object, the receiver is shut + down. The destructor will block until all data is received from + devices. + + Args: + * callback_to_python: a Python callback to call, with <consumer_id> + and the data received. + * backends: the list of backends to listen on. + * max_queue_size_bytes: an optional integer to bound the maximum size + of arrays in the callback queue. When this limit is reached the + device listener pauses. + )", + py::call_guard<py::gil_scoped_release>()); + + py::class_<OutfeedReceiverForPython> outfeed_receiver_class( + outfeed_receiver, "OutfeedReceiverForPython"); + + outfeed_receiver_class.def( + "add_outfeed", &OutfeedReceiverForPython::AddOutfeed, py::arg("builder"), + py::arg("token"), py::arg("consumer_id"), py::arg("arrays"), + R"(Adds an outfeed into the given computation builder. + + Has the side-effect of registering the sent shape along with the consumer + ID. Returns error if the outfeed shape is not compatible with previously + used shape for the same consumer ID.)", + py::call_guard<py::gil_scoped_release>()); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_py.h b/tensorflow/compiler/xla/python/outfeed_receiver_py.h new file mode 100644 index 00000000000..6b1a712327a --- /dev/null +++ b/tensorflow/compiler/xla/python/outfeed_receiver_py.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_PY_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_PY_H_ + +#include "pybind11/pybind11.h" + +namespace xla { + +void BuildOutfeedReceiverSubmodule(pybind11::module* m); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_OUTFEED_RECEIVER_PY_H_ diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc new file mode 100644 index 00000000000..ea84b4e18d6 --- /dev/null +++ b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc @@ -0,0 +1,258 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/outfeed_receiver.h" + +#include <memory> + +#include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/pjrt/cpu_device.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/test.h" + +namespace xla { + +namespace { + +Status CompileAndExecute(XlaBuilder* builder, XlaOp root, int device_id, + PjRtClient* client) { + XlaComputation computation = builder->Build(root).ValueOrDie(); + + CompileOptions compile_options; + compile_options.executable_build_options.set_num_replicas(1); + compile_options.executable_build_options.set_num_partitions(1); + DeviceAssignment device_assignment(1, 1); + device_assignment(0, 0) = device_id; + compile_options.executable_build_options.set_device_assignment( + device_assignment); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr<PjRtExecutable> executable, + PjRtExecutable::Compile(computation, client, std::move(compile_options))); + ExecuteOptions execute_options; + TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers, + executable->Execute({}, execute_options)); + return Status::OK(); +} + +// Accumulates the received data. +class Accumulator { + public: + struct Data { + uint32_t consumer_id; + std::shared_ptr<Literal> data; + }; + + void Receive(uint32_t consumer_id, std::shared_ptr<Literal> data) { + absl::MutexLock lock(&mutex_); + received_.push_back(Data{consumer_id, data}); + } + + std::vector<Data> received() { + absl::MutexLock lock(&mutex_); + return received_; + } + + private: + absl::Mutex mutex_; + std::vector<Data> received_ TF_GUARDED_BY(mutex_); +}; + +TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client, + GetCpuClient(true)); + std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client}; + + auto receiver = absl::make_unique<Accumulator>(); + OutfeedReceiver::Callback callback = + [&receiver](Device* device, std::shared_ptr<PjRtClient> client, + uint32_t consumer_id, std::shared_ptr<Literal> data) { + receiver->Receive(consumer_id, data); + }; + auto outfeed_receiver = + std::make_shared<OutfeedReceiver>(callback, clients, 128); + outfeed_receiver->Start(); + + XlaBuilder builder("execute_test_outfeed"); + constexpr int consumer_id0 = 5; + const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); + XlaOp data = Iota(&builder, shape0, 0); + XlaOp send = outfeed_receiver + ->AddOutfeedToBuilder(&builder, CreateToken(&builder), + consumer_id0, {data}) + .ValueOrDie(); + EXPECT_TRUE(CompileAndExecute(&builder, send, 0, cpu_client.get()).ok()); + + // Shutdown the receiver, to force it to wait to deliver the callbacks. + outfeed_receiver = nullptr; + std::vector<Accumulator::Data> received = receiver->received(); + EXPECT_EQ(1, received.size()); + EXPECT_EQ(consumer_id0, received[0].consumer_id); + EXPECT_EQ(ShapeUtil::MakeTupleShape({shape0}), received[0].data->shape()); +} + +TEST(OutfeedReceiverTest, ReceiveOutfeedTwoComputations) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client, + GetCpuClient(true)); + std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client}; + + auto receiver = absl::make_unique<Accumulator>(); + OutfeedReceiver::Callback callback = + [&receiver](Device* device, std::shared_ptr<PjRtClient> client, + uint32_t consumer_id, std::shared_ptr<Literal> data) { + receiver->Receive(consumer_id, data); + }; + auto outfeed_receiver = + std::make_shared<OutfeedReceiver>(callback, clients, 128); + outfeed_receiver->Start(); + + XlaBuilder builder0("execute_test_outfeed_0"); + constexpr int consumer_id0 = 5; + const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); + XlaOp data0 = Iota(&builder0, shape0, 0); + XlaOp send0 = outfeed_receiver + ->AddOutfeedToBuilder(&builder0, CreateToken(&builder0), + consumer_id0, {data0}) + .ValueOrDie(); + EXPECT_TRUE(CompileAndExecute(&builder0, send0, 0, cpu_client.get()).ok()); + + XlaBuilder builder1("execute_test_outfeed_1"); + constexpr int consumer_id1 = 6; + const Shape shape1 = ShapeUtil::MakeShape(U32, {128}); + XlaOp data1 = Iota(&builder1, shape1, 0); + XlaOp send1 = outfeed_receiver + ->AddOutfeedToBuilder(&builder1, CreateToken(&builder1), + consumer_id1, {data1}) + .ValueOrDie(); + EXPECT_TRUE(CompileAndExecute(&builder1, send1, 0, cpu_client.get()).ok()); + + // Shutdown the receiver, to force it to wait to deliver the callbacks. + outfeed_receiver = nullptr; + std::vector<Accumulator::Data> received = receiver->received(); + EXPECT_EQ(2, received.size()); + EXPECT_EQ(consumer_id0, received[0].consumer_id); + EXPECT_EQ(ShapeUtil::MakeTupleShape({shape0}), received[0].data->shape()); + EXPECT_EQ(consumer_id1, received[1].consumer_id); + EXPECT_EQ(ShapeUtil::MakeTupleShape({shape1}), received[1].data->shape()); +} + +TEST(OutfeedReceiverTest, ReceiveOutfeedTwoOutfeed) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client, + GetCpuClient(true)); + std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client}; + + auto receiver = absl::make_unique<Accumulator>(); + OutfeedReceiver::Callback callback = + [&receiver](Device* device, std::shared_ptr<PjRtClient> client, + uint32_t consumer_id, std::shared_ptr<Literal> data) { + receiver->Receive(consumer_id, data); + }; + auto outfeed_receiver = + std::make_shared<OutfeedReceiver>(callback, clients, 128); + outfeed_receiver->Start(); + + XlaBuilder builder("execute_test_outfeed"); + constexpr int consumer_id0 = 5; + const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); + XlaOp data0 = Iota(&builder, shape0, 0); + XlaOp send0 = outfeed_receiver + ->AddOutfeedToBuilder(&builder, CreateToken(&builder), + consumer_id0, {data0}) + .ValueOrDie(); + + constexpr int consumer_id1 = 6; + const Shape shape1 = ShapeUtil::MakeShape(U32, {128}); + XlaOp data1 = Iota(&builder, shape1, 0); + XlaOp send1 = + outfeed_receiver + ->AddOutfeedToBuilder(&builder, send0, consumer_id1, {data1}) + .ValueOrDie(); + EXPECT_TRUE(CompileAndExecute(&builder, send1, 0, cpu_client.get()).ok()); + + // Shutdown the receiver, to force it to wait to deliver the callbacks. + outfeed_receiver = nullptr; + std::vector<Accumulator::Data> received = receiver->received(); + EXPECT_EQ(2, received.size()); + EXPECT_EQ(consumer_id0, received[0].consumer_id); + EXPECT_EQ(ShapeUtil::MakeTupleShape({shape0}), received[0].data->shape()); + EXPECT_EQ(consumer_id1, received[1].consumer_id); + EXPECT_EQ(ShapeUtil::MakeTupleShape({shape1}), received[1].data->shape()); +} + +TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client, + GetCpuClient(true)); + std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client}; + + auto receiver = absl::make_unique<Accumulator>(); + OutfeedReceiver::Callback callback = + [&receiver](Device* device, std::shared_ptr<PjRtClient> client, + uint32_t consumer_id, std::shared_ptr<Literal> data) { + receiver->Receive(consumer_id, data); + }; + auto outfeed_receiver = + std::make_shared<OutfeedReceiver>(callback, clients, 128); + outfeed_receiver->Start(); + + XlaBuilder builder("execute_test_outfeed"); + constexpr int consumer_id0 = 5; + const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); + XlaOp data0 = Iota(&builder, shape0, 0); + XlaOp send0 = outfeed_receiver + ->AddOutfeedToBuilder(&builder, CreateToken(&builder), + consumer_id0, {data0}) + .ValueOrDie(); + + const Shape shape1 = ShapeUtil::MakeShape(U32, {128}); + XlaOp data1 = Iota(&builder, shape1, 0); + // A different shape for the same consumer ID. + StatusOr<XlaOp> send1 = outfeed_receiver->AddOutfeedToBuilder( + &builder, send0, consumer_id0, {data1}); + EXPECT_FALSE(send1.ok()); + EXPECT_THAT(send1.status().ToString(), + testing::HasSubstr("does not match previous shape element_type")); +} + +TEST(OutfeedReceiverTest, InvalidConsumerIdError) { + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client, + GetCpuClient(true)); + std::vector<std::shared_ptr<PjRtClient>> clients{cpu_client}; + + auto receiver = absl::make_unique<Accumulator>(); + OutfeedReceiver::Callback callback = + [&receiver](Device* device, std::shared_ptr<PjRtClient> client, + uint32_t consumer_id, std::shared_ptr<Literal> data) { + receiver->Receive(consumer_id, data); + }; + auto outfeed_receiver = + std::make_shared<OutfeedReceiver>(callback, clients, 128); + outfeed_receiver->Start(); + + XlaBuilder builder("execute_test_outfeed"); + const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); + XlaOp data0 = Iota(&builder, shape0, 0); + StatusOr<XlaOp> send0 = outfeed_receiver->AddOutfeedToBuilder( + &builder, CreateToken(&builder), 0, {data0}); + + EXPECT_FALSE(send0.ok()); + EXPECT_THAT(send0.status().ToString(), + testing::HasSubstr("Consumer ID cannot be a reserved value")); +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index f03595bf677..0b6824e83e9 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -24,17 +24,12 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "pybind11/attr.h" #include "pybind11/cast.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/lib/comparators.h" -#include "tensorflow/compiler/xla/client/lib/math.h" -#include "tensorflow/compiler/xla/client/lib/qr.h" -#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" -#include "tensorflow/compiler/xla/client/lib/sorting.h" -#include "tensorflow/compiler/xla/client/lib/svd.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -47,6 +42,8 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/bfloat16.h" #include "tensorflow/compiler/xla/python/dlpack.h" +#include "tensorflow/compiler/xla/python/ops.h" +#include "tensorflow/compiler/xla/python/outfeed_receiver_py.h" #include "tensorflow/compiler/xla/python/python_ref_manager.h" #include "tensorflow/compiler/xla/python/types.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" @@ -62,15 +59,16 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/rpc/profiler_server.h" +#include "tensorflow/python/profiler/internal/traceme_wrapper.h" #include "tensorflow/stream_executor/platform.h" namespace xla { +namespace { namespace py = pybind11; -namespace { +using ::tensorflow::profiler::TraceMeWrapper; struct Uniquer { absl::Mutex mu; @@ -304,358 +302,6 @@ StatusOr<py::dict> PjRtBufferCudaArrayInterface(const PjRtBuffer& buffer) { return result; } -void BuildOpsSubmodule(py::module* m) { - // ops submodule, containing free functions that add operators to an - // XlaBuilder. - py::module ops = m->def_submodule("ops", "XLA operations"); - - py::enum_<TriangularSolveOptions::Transpose>( - ops, "TriangularSolveOptions_Transpose") - .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID) - .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE) - .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE) - .value("ADJOINT", TriangularSolveOptions::ADJOINT); - - ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens")); - ops.def( - "AllReduce", - static_cast<XlaOp (*)( - XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>, - const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>( - &AllReduce), - py::arg("operand"), py::arg("computation"), - py::arg("replica_groups") = py::list(), - py::arg("channel_id") = absl::nullopt, - py::arg("shape_with_layout") = absl::nullopt); - ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"), - py::arg("concat_dimension"), py::arg("split_count"), - py::arg("replica_groups") = py::list(), - py::arg("layout") = absl::nullopt); - ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"), - py::arg("source_target_pairs")); - ops.def("CreateToken", &CreateToken, py::arg("builder")); - ops.def("CrossReplicaSum", - static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>( - &CrossReplicaSum), - py::arg("operand"), py::arg("replica_groups") = py::list()); - ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"), - py::arg("new_element_type")); - ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes")); - ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"), - py::arg("shape"), py::arg("broadcast_dimensions")); - ops.def("Call", &Call, py::arg("builder"), py::arg("computation"), - py::arg("operands")); - ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true); - ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max")); - ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions")); - ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"), - py::arg("dimension")); - ops.def("Conditional", - static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>, - absl::Span<const XlaOp>)>(&Conditional), - py::arg("branch_index"), py::arg("branch_computations"), - py::arg("branch_operands")); - ops.def("Conditional", - static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp, - const XlaComputation&)>(&Conditional), - py::arg("predicate"), py::arg("true_operand"), - py::arg("true_computation"), py::arg("false_operand"), - py::arg("false_computation")); - ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal")); - ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"), - py::arg("literal")); - ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"), - py::arg("rhs"), py::arg("window_strides"), py::arg("padding"), - py::arg("lhs_dilation"), py::arg("rhs_dilation"), - py::arg("dimension_numbers"), py::arg("feature_group_count") = 1, - py::arg("batch_group_count") = 1, - py::arg("precision_config") = nullptr); - ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"), - py::arg("new_element_type")); - ops.def( - "CustomCall", - [](XlaBuilder* builder, const py::bytes& call_target_name, - absl::Span<const XlaOp> operands, const Shape& shape, - const py::bytes& opaque) -> XlaOp { - return CustomCall(builder, call_target_name, operands, shape, opaque); - }, - py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), - py::arg("shape"), py::arg("opaque") = py::bytes("")); - ops.def( - "CustomCallWithLayout", - [](XlaBuilder* builder, const py::bytes& call_target_name, - absl::Span<const XlaOp> operands, const Shape& shape_with_layout, - absl::Span<const Shape> operand_shapes_with_layout, - const py::bytes& opaque) -> XlaOp { - return CustomCallWithLayout(builder, call_target_name, operands, - shape_with_layout, - operand_shapes_with_layout, opaque); - }, - py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), - py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"), - py::arg("opaque") = py::bytes("")); - ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"), - py::arg("precision_config") = nullptr); - ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"), - py::arg("dimension_numbers"), py::arg("precision_config") = nullptr); - ops.def("DynamicSlice", - static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>, - absl::Span<const int64>)>(&DynamicSlice), - py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes")); - ops.def("DynamicUpdateSlice", - static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>( - &DynamicUpdateSlice), - py::arg("operand"), py::arg("update"), py::arg("start_indices")); - - ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"), - py::arg("fft_length")); - - ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"), - py::arg("dimension_numbers"), py::arg("slice_sizes"), - py::arg("indices_are_sorted") = false); - ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"), - py::arg("index")); - ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"), - py::arg("shape"), py::arg("config") = ""); - ops.def("Iota", - static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64)>(&Iota), - py::arg("builder"), py::arg("shape"), py::arg("iota_dimension")); - ops.def("Iota", - static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota), - py::arg("builder"), py::arg("type"), py::arg("size")); - ops.def("Map", &Map, py::arg("builder"), py::arg("operands"), - py::arg("computation"), py::arg("dimensions"), - py::arg("static_operands") = py::list()); - ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to")); - ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"), - py::arg("token"), py::arg("shape_with_layout"), - py::arg("outfeed_config") = ""); - ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"), - py::arg("padding_config")); - ops.def("Parameter", - static_cast<XlaOp (*)(XlaBuilder*, int64, const Shape&, - const std::string&, const std::vector<bool>&)>( - &Parameter), - py::arg("builder"), py::arg("parameter_number"), py::arg("shape"), - py::arg("name") = "", - py::arg("replicated_at_leaf_buffers") = std::vector<bool>()); - ops.def( - "QR", - [](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> { - TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices)); - return std::make_pair(qr.q, qr.r); - }, - py::arg("operand"), py::arg("full_matrices")); - ops.def( - "Eigh", - [](XlaOp a, bool lower, int64 max_iter, - float epsilon) -> std::pair<XlaOp, XlaOp> { - auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon); - return std::make_pair(eigh.v, eigh.w); - }, - py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100, - py::arg("epsilon") = 1e-6); - ops.def( - "SVD", - [](XlaOp a, int64 max_iter, - float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> { - auto svd = SVD(a, max_iter, epsilon); - return std::make_tuple(svd.u, svd.d, svd.v); - }, - py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6); - ops.def("Reduce", - static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>, - absl::Span<const XlaOp>, const XlaComputation&, - absl::Span<const int64>)>(&Reduce), - py::arg("builder"), py::arg("operands"), py::arg("init_values"), - py::arg("computation"), py::arg("dimensions_to_reduce")); - ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"), - py::arg("exponent_bits"), py::arg("mantissa_bits")); - ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding, - py::arg("operand"), py::arg("init_value"), py::arg("computation"), - py::arg("window_dimensions"), py::arg("window_strides"), - py::arg("base_dilations"), py::arg("window_dilations"), - py::arg("padding")); - ops.def("ReplicaId", &ReplicaId, py::arg("builder")); - ops.def("Reshape", - static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>, - absl::Span<const int64>)>(&Reshape), - py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes")); - ops.def("Reshape", - static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>)>(&Reshape), - py::arg("operand"), py::arg("new_sizes")); - ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions")); - ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"), - py::arg("shape")); - ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"), - py::arg("shape")); - ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"), - py::arg("updates"), py::arg("update_computation"), - py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false, - py::arg("unique_indices") = false); - ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"), - py::arg("on_false")); - ops.def("SelectAndScatterWithGeneralPadding", - &SelectAndScatterWithGeneralPadding, py::arg("operand"), - py::arg("select"), py::arg("window_dimensions"), - py::arg("window_strides"), py::arg("padding"), py::arg("source"), - py::arg("init_value"), py::arg("scatter")); - ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"), - py::arg("limit_indices"), py::arg("strides")); - ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"), - py::arg("limit_index"), py::arg("stride"), py::arg("dimno")); - ops.def( - "Sort", - [](XlaBuilder* builder, absl::Span<const XlaOp> operands, - absl::optional<const XlaComputation*> comparator, int64 dimension, - bool is_stable) -> XlaOp { - return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { - std::vector<PrimitiveType> operand_types; - for (const auto& operand : operands) { - TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand)); - operand_types.push_back(operand_shape.element_type()); - } - - if (comparator) { - return Sort(operands, **comparator, dimension, is_stable); - } else { - return Sort(operands, - CreateScalarLtComputation(operand_types, builder), - dimension, is_stable); - } - }); - }, - py::arg("builder"), py::arg("operands"), - py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1, - py::arg("is_stable") = false); - ops.def("TopK", &TopK, py::arg("input"), py::arg("k")); - ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation")); - ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"), - py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"), - py::arg("transpose_a")); - ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements")); - ops.def("While", &While, py::arg("condition"), py::arg("body"), - py::arg("init")); - - ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x")); - ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x")); - ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x")); - ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x")); - ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"), - py::arg("b"), py::arg("x")); - -#define BINARY_OP(op) \ - ops.def( \ - #op, \ - [](XlaOp a, XlaOp b, absl::optional<std::vector<int64>> dims) { \ - return dims ? op(a, b, *dims) : op(a, b); \ - }, \ - py::arg("lhs"), py::arg("rhs"), \ - py::arg("broadcast_dimensions") = absl::nullopt) - BINARY_OP(Eq); - BINARY_OP(Ne); - BINARY_OP(Ge); - BINARY_OP(Gt); - BINARY_OP(Lt); - BINARY_OP(Le); - BINARY_OP(Add); - BINARY_OP(Sub); - BINARY_OP(Mul); - BINARY_OP(Div); - BINARY_OP(Rem); - BINARY_OP(Max); - BINARY_OP(Min); - BINARY_OP(And); - BINARY_OP(Or); - BINARY_OP(Xor); - BINARY_OP(ShiftLeft); - BINARY_OP(ShiftRightArithmetic); - BINARY_OP(ShiftRightLogical); - BINARY_OP(Atan2); - BINARY_OP(Pow); - BINARY_OP(Complex); -#undef BINARY_OP - -#define UNARY_OP(op) ops.def(#op, &op) - UNARY_OP(Not); - UNARY_OP(PopulationCount); - UNARY_OP(Clz); - UNARY_OP(Abs); - UNARY_OP(Exp); - UNARY_OP(Expm1); - UNARY_OP(Floor); - UNARY_OP(Ceil); - UNARY_OP(Round); - UNARY_OP(Log); - UNARY_OP(Log1p); - UNARY_OP(Sign); - UNARY_OP(Cos); - UNARY_OP(Sin); - UNARY_OP(Tanh); - UNARY_OP(IsFinite); - UNARY_OP(Neg); - UNARY_OP(Sqrt); - UNARY_OP(Rsqrt); - UNARY_OP(Square); - UNARY_OP(Reciprocal); - UNARY_OP(Erfc); - UNARY_OP(Erf); - UNARY_OP(ErfInv); - UNARY_OP(Lgamma); - UNARY_OP(Digamma); - UNARY_OP(BesselI0e); - UNARY_OP(BesselI1e); - UNARY_OP(Acos); - UNARY_OP(Asin); - UNARY_OP(Atan); - UNARY_OP(Tan); - UNARY_OP(Acosh); - UNARY_OP(Asinh); - UNARY_OP(Atanh); - UNARY_OP(Cosh); - UNARY_OP(Sinh); - UNARY_OP(Real); - UNARY_OP(Imag); - UNARY_OP(Conj); -#undef UNARY_OP -} - -// Helper to implement TraceMe as a context manager in Python. -class TraceMeContextManager { - public: - explicit TraceMeContextManager(py::str name, py::kwargs kwargs) - : name_(std::move(name)), kwargs_(std::move(kwargs)) {} - - void Enter() { - if (IsEnabled()) { - std::string name(name_); - if (!kwargs_.empty()) { - absl::StrAppend(&name, "#"); - bool first = true; - for (const auto entry : kwargs_) { - absl::StrAppend(&name, first ? "" : ",", - std::string(py::str(entry.first)), "=", - std::string(py::str(entry.second))); - first = false; - } - absl::StrAppend(&name, "#"); - } - traceme_.emplace(std::move(name)); - } - } - py::object Exit(const py::object& ex_type, const py::object& ex_value, - const py::object& traceback) { - traceme_.reset(); - return py::none(); - } - - static bool IsEnabled() { return tensorflow::profiler::TraceMe::Active(); } - - private: - py::str name_; - py::kwargs kwargs_; - absl::optional<tensorflow::profiler::TraceMe> traceme_; -}; void BuildProfilerSubmodule(py::module* m) { py::module profiler = @@ -672,11 +318,19 @@ void BuildProfilerSubmodule(py::module* m) { }, py::arg("port")); - py::class_<TraceMeContextManager> traceme_class(profiler, "TraceMe"); + py::class_<TraceMeWrapper> traceme_class(profiler, "TraceMe", + py::module_local()); traceme_class.def(py::init<py::str, py::kwargs>()) - .def("__enter__", &TraceMeContextManager::Enter) - .def("__exit__", &TraceMeContextManager::Exit) - .def_static("is_enabled", &TraceMeContextManager::IsEnabled); + .def("__enter__", [](py::object self) -> py::object { return self; }) + .def("__exit__", + [](py::object self, const py::object& ex_type, + const py::object& ex_value, + const py::object& traceback) -> py::object { + py::cast<TraceMeWrapper*>(self)->Stop(); + return py::none(); + }) + .def("set_metadata", &TraceMeWrapper::SetMetadata) + .def_static("is_enabled", &TraceMeWrapper::IsEnabled); } } // namespace @@ -872,11 +526,7 @@ PYBIND11_MODULE(xla_extension, m) { DebugOptions* debug_options = options.executable_build_options.mutable_debug_options(); // Sets fast-math-disabling default options expected by JAX. - // TODO(phawkins): make these XLA-wide defaults. - debug_options->set_xla_cpu_fast_math_honor_infs(true); - debug_options->set_xla_cpu_fast_math_honor_nans(true); - debug_options->set_xla_cpu_fast_math_honor_division(true); - debug_options->set_xla_cpu_fast_math_honor_functions(true); + debug_options->set_xla_cpu_enable_fast_min_max(false); debug_options->set_xla_gpu_enable_fast_min_max(false); return options; })) @@ -934,34 +584,6 @@ PYBIND11_MODULE(xla_extension, m) { "client", [](const ClientAndPtr<Device>& device) { return device.client; }) .def("__str__", &Device::DebugString) - // TODO(phawkins): remove capitalized names after updating callers. - .def("TransferToInfeed", - [](const Device& device, const LiteralSlice& literal) { - GlobalPyRefManager()->CollectGarbage(); - py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device.GetLocalDeviceState()); - return local_device->client()->TransferToInfeedLocal( - literal, local_device->device_ordinal()); - }) - .def( - "TransferFromOutfeed", - [](const Device& device, const Shape& shape) -> StatusOr<py::object> { - GlobalPyRefManager()->CollectGarbage(); - std::shared_ptr<Literal> literal_shared; - { - py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device.GetLocalDeviceState()); - TF_ASSIGN_OR_RETURN( - Literal literal, - local_device->client()->TransferFromOutfeedLocal( - shape, local_device->device_ordinal())); - - literal_shared = std::make_shared<Literal>(std::move(literal)); - } - return LiteralToPython(std::move(literal_shared)); - }) .def("transfer_to_infeed", [](const Device& device, const LiteralSlice& literal) { GlobalPyRefManager()->CollectGarbage(); @@ -1248,28 +870,6 @@ PYBIND11_MODULE(xla_extension, m) { .def("size_of_generated_code_in_bytes", &PjRtExecutable::SizeOfGeneratedCodeInBytes) .def("delete", &PjRtExecutable::Delete) - // TODO(phawkins): delete capitalized methods after updating callers. - .def("Delete", &PjRtExecutable::Delete) - .def( - "Execute", - [](const PjRtExecutable& executable, - absl::Span<PjRtBuffer* const> args) - -> StatusOr<std::vector<ClientAndUniquePtr<PjRtBuffer>>> { - py::gil_scoped_release gil_release; - ExecuteOptions options; - options.untuple_result = true; - TF_ASSIGN_OR_RETURN( - std::vector<std::unique_ptr<PjRtBuffer>> output_buffers, - executable.Execute(args, options)); - std::vector<ClientAndUniquePtr<PjRtBuffer>> outputs; - outputs.reserve(output_buffers.size()); - for (auto& buffer : output_buffers) { - outputs.push_back(WrapWithClient( - executable.client()->shared_from_this(), std::move(buffer))); - } - return outputs; - }, - py::arg("arguments")) .def( "execute", [](const PjRtExecutable& executable, @@ -1290,33 +890,6 @@ PYBIND11_MODULE(xla_extension, m) { return outputs; }, py::arg("arguments")) - // TODO(phawkins): delete capitalized methods after updating callers. - .def( - "ExecuteOnLocalDevices", - [](const PjRtExecutable& executable, - absl::Span<const std::vector<PjRtBuffer*>> args) - -> StatusOr< - std::vector<std::vector<ClientAndUniquePtr<PjRtBuffer>>>> { - py::gil_scoped_release gil_release; - ExecuteOptions options; - options.untuple_result = true; - TF_ASSIGN_OR_RETURN( - std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> - output_buffers, - executable.ExecuteOnLocalDevices(args, options)); - std::vector<std::vector<ClientAndUniquePtr<PjRtBuffer>>> outputs; - outputs.resize(output_buffers.size()); - for (int computation = 0; computation < output_buffers.size(); - ++computation) { - for (auto& buffer : output_buffers[computation]) { - outputs[computation].push_back( - WrapWithClient(executable.client()->shared_from_this(), - std::move(buffer))); - } - } - return outputs; - }, - py::arg("arguments")) .def( "execute_on_local_devices", [](const PjRtExecutable& executable, @@ -1377,7 +950,19 @@ PYBIND11_MODULE(xla_extension, m) { &DebugOptions::set_xla_cpu_fast_math_honor_functions) .def_property("xla_gpu_enable_fast_min_max", &DebugOptions::xla_gpu_enable_fast_min_max, - &DebugOptions::set_xla_gpu_enable_fast_min_max); + &DebugOptions::set_xla_gpu_enable_fast_min_max) + .def_property("xla_backend_optimization_level", + &DebugOptions::xla_backend_optimization_level, + &DebugOptions::set_xla_backend_optimization_level) + .def_property("xla_cpu_enable_xprof_traceme", + &DebugOptions::xla_cpu_enable_xprof_traceme, + &DebugOptions::set_xla_cpu_enable_xprof_traceme) + .def_property("xla_llvm_disable_expensive_passes", + &DebugOptions::xla_llvm_disable_expensive_passes, + &DebugOptions::set_xla_llvm_disable_expensive_passes) + .def_property("xla_test_all_input_layouts", + &DebugOptions::xla_test_all_input_layouts, + &DebugOptions::set_xla_test_all_input_layouts); py::class_<ExecutableBuildOptions>(m, "ExecutableBuildOptions") .def(py::init<>()) @@ -1406,7 +991,10 @@ PYBIND11_MODULE(xla_extension, m) { options.device_assignment()) : absl::nullopt; }, - &ExecutableBuildOptions::set_device_assignment); + &ExecutableBuildOptions::set_device_assignment) + .def_property("use_spmd_partitioning", + &ExecutableBuildOptions::use_spmd_partitioning, + &ExecutableBuildOptions::set_use_spmd_partitioning); py::class_<XlaComputation>(m, "XlaComputation") .def(py::init([](const py::bytes& serialized_hlo_module_proto) @@ -1415,12 +1003,6 @@ PYBIND11_MODULE(xla_extension, m) { proto.ParseFromString(serialized_hlo_module_proto); return absl::make_unique<XlaComputation>(proto); })) - // TODO(phawkins): delete capitalized names after updating callers. - .def("GetProgramShape", &XlaComputation::GetProgramShape) - .def("GetSerializedProto", &GetComputationSerializedProto) - .def("GetHloText", &GetComputationHloText) - .def("GetHloDotGraph", &GetComputationHloDotGraph) - .def("Hash", &HashComputation) .def("get_hlo_module", &GetHloModule) .def("program_shape", &XlaComputation::GetProgramShape) .def("as_serialized_hlo_module_proto", &GetComputationSerializedProto) @@ -1513,28 +1095,7 @@ PYBIND11_MODULE(xla_extension, m) { }, "Builds a computation from the contents of the builder.", py::arg("root") = absl::nullopt) - .def("ClearOpMetadata", &XlaBuilder::ClearOpMetadata) .def("GetShape", &XlaBuilder::GetShape) - .def( - "GetProgramShape", - [](const XlaBuilder& builder, - absl::optional<XlaOp> root) -> StatusOr<ProgramShape> { - return root ? builder.GetProgramShape(*root) - : builder.GetProgramShape(); - }, - py::arg("root") = absl::nullopt) - .def("IsConstant", &XlaBuilder::IsConstant) - .def("SetOpMetadata", &XlaBuilder::SetOpMetadata) - .def("SetSharding", &XlaBuilder::SetSharding) - .def("ClearSharding", &XlaBuilder::ClearSharding) - .def("SetUpAlias", - [](XlaBuilder& builder, const std::vector<int64>& output_index, - int64 param_number, const std::vector<int64>& param_index) { - builder.SetUpAlias( - ShapeIndex(output_index.begin(), output_index.end()), - param_number, - ShapeIndex(param_index.begin(), param_index.end())); - }) .def( "build", [](XlaBuilder& builder, absl::optional<XlaOp> root) { @@ -1565,17 +1126,7 @@ PYBIND11_MODULE(xla_extension, m) { ShapeIndex(param_index.begin(), param_index.end())); }); - // TODO(phawkins): delete capitalized names after updating callers - m.def("BufferToDLPackManagedTensor", BufferToDLPackManagedTensor); m.def("buffer_to_dlpack_managed_tensor", BufferToDLPackManagedTensor); - m.def("DLPackManagedTensorToBuffer", - [](const py::capsule& tensor, std::shared_ptr<PjRtClient> client) - -> StatusOr<ClientAndUniquePtr<PjRtBuffer>> { - TF_ASSIGN_OR_RETURN( - std::unique_ptr<PjRtBuffer> buffer, - DLPackManagedTensorToBuffer(tensor, client.get())); - return WrapWithClient(std::move(client), std::move(buffer)); - }); m.def("dlpack_managed_tensor_to_buffer", [](const py::capsule& tensor, std::shared_ptr<PjRtClient> client) -> StatusOr<ClientAndUniquePtr<PjRtBuffer>> { @@ -1615,6 +1166,7 @@ PYBIND11_MODULE(xla_extension, m) { BuildOpsSubmodule(&m); BuildProfilerSubmodule(&m); + BuildOutfeedReceiverSubmodule(&m); py::class_<DistributedRuntimeService, std::unique_ptr<DistributedRuntimeService>> diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index d9cd906939d..76c3bc33a91 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -300,13 +300,13 @@ CompileOptions = _xla.CompileOptions # An Executable is a C++ class that duck types with the following API: # class Executable(object): # def local_devices(self) -> [Device]: -# def Execute(self, arguments : [Buffer]) -> Buffer: +# def execute(self, arguments : [Buffer]) -> Buffer: # """Execute on one replica with Buffer arguments and return value.""" # -# def SizeOfGeneratedCodeInBytes(self) -> int: +# def size_of_generated_code_in_bytes(self) -> int: # """Return generated binary size, or -1 if not known.""" # -# def ExecuteOnLocalDevices(self, arguments: [[Buffer]]) -> [Buffer]: +# def execute_on_local_devices(self, arguments: [[Buffer]]) -> [Buffer]: # """Execute on many replicas with Buffer arguments and return value. # # Args: @@ -329,7 +329,7 @@ def execute_with_python_values(executable, arguments, backend): return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) arguments = [put(arg) for arg in arguments] - outputs = executable.Execute(arguments) + outputs = executable.execute(arguments) return [x.to_py() for x in outputs] @@ -359,7 +359,7 @@ def execute_with_python_values_replicated(executable, arguments, backend): flat_arg_buffers = flat_arg_buffers[len(replica_args):] return [[x.to_py() for x in xs] - for xs in executable.ExecuteOnLocalDevices(arg_buffers)] + for xs in executable.execute_on_local_devices(arg_buffers)] class PaddingType(enum.Enum): diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index fbdd9921a40..000db2cb16b 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -115,6 +115,10 @@ def TestFactory(xla_backend, cloud_tpu=False): """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" return np.array(*args, dtype=np.float32, **kwargs) + def NumpyArrayF64(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float64 dtype.""" + return np.array(*args, dtype=np.float64, **kwargs) + def NumpyArrayS32(*args, **kwargs): """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" return np.array(*args, dtype=np.int32, **kwargs) @@ -882,12 +886,20 @@ def TestFactory(xla_backend, cloud_tpu=False): ops.Abs(ops.Constant(c, arr)) self._ExecuteAndCompareClose(c, expected=[np.abs(arr)]) - def testTanh(self): + def testTanhF32(self): c = self._NewComputation() - arr = NumpyArrayF32([3.3, 12.1]) + arr = NumpyArrayF32([-0.2, 3.3, 12.1, 0.1, 0.0001]) ops.Tanh(ops.Constant(c, arr)) self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)]) + def testTanhF64(self): + if self.backend.platform == "tpu": + self.skipTest("TPU doesn't support 64bit tanh") + c = self._NewComputation() + arr = NumpyArrayF64([-0.2, 3.3, 12.1, 0.1, 0.0001]) + ops.Tanh(ops.Constant(c, arr)) + self._ExecuteAndCompareClose(c, expected=[np.tanh(arr)], rtol=1e-12) + def testTranspose(self): def _TransposeAndTest(array, permutation): diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 126b62a8eb2..125a42bb2f9 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -491,6 +491,66 @@ tf_cc_test( ], ) +cc_library( + name = "sharding_propagation", + srcs = [ + "sharding_propagation.cc", + ], + hdrs = [ + "sharding_propagation.h", + ], + deps = [ + ":dot_as_convolution_util", + ":hlo", + ":hlo_graph_dumper", + ":hlo_pass", + ":hlo_sharding_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "sharding_propagation_test", + srcs = [ + "sharding_propagation_test.cc", + ], + deps = [ + "hlo_matchers", + ":hlo_parser", + ":sharding_propagation", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "dot_as_convolution_util", + srcs = [ + "dot_as_convolution_util.cc", + ], + hdrs = [ + "dot_as_convolution_util.h", + ], + deps = [ + ":hlo", + ":shape_inference", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "@com_google_absl//absl/types:optional", + ], +) + tf_cc_test( name = "dynamic_parameter_binding_test", srcs = ["dynamic_parameter_binding_test.cc"], @@ -3284,6 +3344,7 @@ cc_library( ":heap_simulator", ":hlo_cost_analysis", "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/core/lib/math:math_util", ], ) @@ -4497,6 +4558,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/base", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 55af8726dc8..4025cb46f18 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -472,8 +472,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { HloInstruction* dot); HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) { - if (scalar_add_computation_) { - return scalar_add_computation_; + HloComputation*& scalar_add_computation = scalar_add_computations_[type]; + if (scalar_add_computation) { + return scalar_add_computation; } HloComputation::Builder b("scalar_add_computation"); @@ -485,9 +486,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { HloInstruction::CreateParameter(1, shape, "scalar_rhs")); auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); - scalar_add_computation_ = + scalar_add_computation = computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); - return scalar_add_computation_; + return scalar_add_computation; } // Tries to fold a kPad in the input or filter into the convolution @@ -508,6 +509,13 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Tries to convert slice(reshape(X)) into reshape(slice(X)) StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice); + // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into + // `(< a N)`. This is crucial for being able to figure out the loop trip + // count. + // + // Assumes that the input is conjunction. + StatusOr<bool> TrySimplifyTautologicalCompare(HloInstruction* conjunction); + // Useful when we want to use the same visitor over multiple computations. void ResetState(HloComputation* computation); @@ -521,8 +529,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Whether algebraic simplification has occurred. bool changed_ = false; - // Cached computation for adding two scalar F32. - HloComputation* scalar_add_computation_ = nullptr; + // Cached computation for adding two scalars of a given type. + absl::flat_hash_map<PrimitiveType, HloComputation*> scalar_add_computations_; AlgebraicSimplifier* simplifier_ = nullptr; }; @@ -856,6 +864,50 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { return Status::OK(); } +StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare( + HloInstruction* conjunction) { + HloInstruction *lhs, *rhs; + if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) { + return false; + } + struct LessThanCompareInfo { // (LT var constant) + HloInstruction* var; + int64 constant; + }; + + auto get_compare_info = + [&](HloInstruction* cmp) -> absl::optional<LessThanCompareInfo> { + HloInstruction *lhs, *rhs; + auto scalar_shape_matcher = + m::Shape().IsEffectiveScalar().WithElementType(PrimitiveType::S32); + if (Match(cmp, m::Compare(m::Op(&lhs), + m::Constant(&rhs).WithShape(scalar_shape_matcher)) + .WithComparisonDirection(ComparisonDirection::kLt))) { + return {LessThanCompareInfo{lhs, *rhs->literal().GetFirstInteger()}}; + } else if (Match( + cmp, + m::Compare(m::Constant(&lhs).WithShape(scalar_shape_matcher), + m::Op(&rhs)) + .WithComparisonDirection(ComparisonDirection::kGt))) { + return {LessThanCompareInfo{rhs, *lhs->literal().GetFirstInteger()}}; + } + return absl::nullopt; + }; + + absl::optional<LessThanCompareInfo> lhs_info = get_compare_info(lhs); + absl::optional<LessThanCompareInfo> rhs_info = get_compare_info(rhs); + if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) { + int64 new_bound = std::min(lhs_info->constant, rhs_info->constant); + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + conjunction, + HloInstruction::CreateCompare(lhs->shape(), lhs_info->var, + MakeScalarLike(lhs_info->var, new_bound), + ComparisonDirection::kLt))); + return true; + } + return false; +} + Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { HloInstruction *lhs, *rhs; CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs)))); @@ -890,6 +942,13 @@ Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { return Status::OK(); } + // Simplify tautological conjunctions. + TF_ASSIGN_OR_RETURN(bool found_tautological_compare, + TrySimplifyTautologicalCompare(logical_and)); + if (found_tautological_compare) { + return Status::OK(); + } + return Status::OK(); } @@ -1423,6 +1482,22 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return ReplaceInstruction(divide, new_divide); } + // If X is a convert from pred, then + // X / broadcast(Y) => broadcast(1/Y) * X + if (Match(divide, + m::Divide( + m::Convert(&a, + m::Op().WithShape(m::Shape().WithElementType(PRED))), + m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) { + TF_ASSIGN_OR_RETURN( + auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b)); + auto recip_bcast = computation_->AddInstruction( + HloInstruction::CreateBroadcast(divide->shape(), recip, {})); + TF_ASSIGN_OR_RETURN(auto mul, + MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a)); + return ReplaceInstruction(divide, mul); + } + return Status::OK(); } @@ -2983,6 +3058,20 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( return false; } HloInstruction* operand = broadcast->mutable_operand(0); + auto is_scalar_broadcast = [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(instruction->operand(0)->shape()); + }; + auto is_equal_broadcast = [operand, + broadcast](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kBroadcast && + ShapeUtil::Equal(operand->shape(), + instruction->operand(0)->shape()) && + broadcast->dimensions() == instruction->dimensions(); + }; + auto is_compatible_broadcast = [&](const HloInstruction* instruction) { + return is_scalar_broadcast(instruction) || is_equal_broadcast(instruction); + }; for (HloInstruction* user : broadcast->users()) { if (user->user_count() == 0 && user != computation_->root_instruction()) { continue; @@ -3001,18 +3090,20 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( continue; } - // Find the unique non-scalar operand or continue if there isn't one. - int64 scalar_broadcast_count = 0; + // Check if all the operands of the user are compatible broadcasts for + // sinking. (They are either scalar broadcasts or broadcasts casting + // from/to the same shape/dimensions) + int64 compatible_broadcast_count = 0; int64 broadcast_use_count = 0; for (HloInstruction* user_operand : user->operands()) { - if (user_operand->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { - ++scalar_broadcast_count; + if (is_compatible_broadcast(user_operand)) { + ++compatible_broadcast_count; } else if (broadcast == user_operand) { ++broadcast_use_count; } } - if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) { + if (compatible_broadcast_count + broadcast_use_count != + user->operand_count()) { continue; } std::vector<HloInstruction*> new_operands; @@ -3020,14 +3111,24 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( Shape changed_shape; for (HloInstruction* user_operand : user->operands()) { - if (user_operand->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { - changed_shape = ShapeUtil::ChangeElementType( - operand->shape(), user_operand->shape().element_type()); - simplifier_->UpdateLayout(&changed_shape); - new_operands.push_back( - computation_->AddInstruction(HloInstruction::CreateBroadcast( - changed_shape, user_operand->mutable_operand(0), {}))); + // If this is a broadcast operand that is not our original broadcast input + // to this function then we might need to change the input. + if (is_compatible_broadcast(user_operand)) { + // If this is a broadcast from a scalar value rewrite a broadcast from + // the scalar to the new shape enforced from the other broadcast + // operands. + if (is_scalar_broadcast(user_operand)) { + changed_shape = ShapeUtil::ChangeElementType( + operand->shape(), user_operand->shape().element_type()); + simplifier_->UpdateLayout(&changed_shape); + new_operands.push_back( + computation_->AddInstruction(HloInstruction::CreateBroadcast( + changed_shape, user_operand->mutable_operand(0), {}))); + } else { + // For the non-scalar broadcasts we guarantee that the shape of the + // operand of the broadcast needs to be already a compatible shape. + new_operands.push_back(user_operand->mutable_operand(0)); + } } else { CHECK_EQ(broadcast, user_operand); new_operands.push_back(operand); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 6c8e80aa963..bcfc2fdc740 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -338,6 +338,79 @@ TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeBroadcastedConstants) { m::ConstantScalar(3.0)))))); } +TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsScalar) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + b0 = f32[4] broadcast(p0), dimensions={} + b1 = f32[4] broadcast(p1), dimensions={} + ROOT multiply = f32[4] multiply(b1, b0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::Multiply(m::Broadcast(m::Parameter(1)), + m::Broadcast(m::Parameter(0)))))); +} + +TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsConstantMix) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + c0 = f32[] constant(2.0) + b0 = f32[4,2] broadcast(c0), dimensions={} + b1 = f32[4,2] broadcast(p0), dimensions={0} + ROOT multiply = f32[4,2] multiply(b1, b0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::Multiply( + m::Parameter(0), m::Broadcast(m::ConstantScalar(2.0)))))); +} + +TEST_F(AlgebraicSimplifierTest, ElementwiseSinkMultipleBroadcastsNonScalar) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + p1 = f32[4] parameter(1) + b0 = f32[4,2] broadcast(p0), dimensions={0} + b1 = f32[4,2] broadcast(p1), dimensions={0} + ROOT multiply = f32[4,2] multiply(b1, b0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::Multiply(m::Parameter(1), m::Parameter(0))))); +} + +TEST_F(AlgebraicSimplifierTest, ElementwiseNoSinkBroadcastsDifferentDims) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + p1 = f32[8] parameter(1) + b0 = f32[4,8] broadcast(p0), dimensions={0} + b1 = f32[4,8] broadcast(p1), dimensions={1} + ROOT multiply = f32[4,8] multiply(b1, b0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Multiply(m::Broadcast(m::Parameter(1)), + m::Broadcast(m::Parameter(0))))); +} + TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMultiplyOfConstantAndBroadcast) { const char* kModuleStr = R"( @@ -5761,6 +5834,44 @@ TEST_F(AlgebraicSimplifierTest, CompareSame) { GmockMatch(m::Broadcast(m::ConstantScalar(true)))); } +TEST_F(AlgebraicSimplifierTest, CompareSimplified) { + const char* kModuleStr = R"( + HloModule m + test { + param = s32[] parameter(0) + c1 = s32[] constant(10) + c2 = s32[] constant(100) + cmp1 = pred[] compare(param, c1), direction=LT + cmp2 = pred[] compare(param, c2), direction=LT + ROOT out = pred[] and(cmp1, cmp2) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10)) + .WithComparisonDirection(ComparisonDirection::kLt))); +} + +TEST_F(AlgebraicSimplifierTest, CompareSimplifiedReversed) { + const char* kModuleStr = R"( + HloModule m + test { + param = s32[] parameter(0) + c1 = s32[] constant(10) + c2 = s32[] constant(100) + cmp1 = pred[] compare(param, c1), direction=LT + cmp2 = pred[] compare(c2, param), direction=GT + ROOT out = pred[] and(cmp1, cmp2) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10)) + .WithComparisonDirection(ComparisonDirection::kLt))); +} + TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) { // Some backends may have better performance by treating an outer product as a // Dot, rather than a broadcast Multiply @@ -6462,5 +6573,43 @@ TEST_F(AlgebraicSimplifierTest, SwapConvOperands) { EXPECT_EQ(conv->window().dimensions(1).padding_high(), 1); } +TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[2] parameter(0) + cvt = f32[2] convert(p0) + p1 = f32[] parameter(1) + bcast = f32[2] broadcast(p1), dimensions={} + ROOT div = f32[2] divide(cvt, bcast) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::Convert(m::Parameter(0)), + m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1)))))); +} + +TEST_F(AlgebraicSimplifierTest, MultipleDotStrengthReductions) { + constexpr char kModuleStr[] = R"( + HloModule test + ENTRY test { + a = c64[2,2] parameter(0) + b = c64[2] parameter(1) + cd = c64[2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0} + c = f64[2,2] parameter(2) + d = f64[2] parameter(3) + dd = f64[2] dot(c, d), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT tuple = (c64[2], f64[2]) tuple(cd, dd) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_EQ(3, m->computation_count()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/all_gather_decomposer.cc b/tensorflow/compiler/xla/service/all_gather_decomposer.cc index ad63218eca8..00b9adaea43 100644 --- a/tensorflow/compiler/xla/service/all_gather_decomposer.cc +++ b/tensorflow/compiler/xla/service/all_gather_decomposer.cc @@ -50,14 +50,18 @@ HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) { return reduction; } -Status DecomposeAllGather(HloAllGatherInstruction* ag, int64 partition_count, - HloComputation* comp) { +Status DecomposeAllGather(HloAllGatherInstruction* ag, HloComputation* comp) { + const int64 shard_size = + ag->operand(0)->shape().dimensions(ag->all_gather_dimension()); + const int64 ag_size = ag->shape().dimensions(ag->all_gather_dimension()); + TF_RET_CHECK(ag_size % shard_size == 0); + int64 partition_count = ag_size / shard_size; auto zero = comp->AddInstruction(HloInstruction::CreateConstant( LiteralUtil::Zero(ag->shape().element_type()))); zero = comp->AddInstruction( HloInstruction::CreateBroadcast(ag->shape(), zero, {})); auto zero_index = comp->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + HloInstruction::CreateConstant(LiteralUtil::Zero(U32))); std::vector<HloInstruction*> start_indices(ag->shape().rank(), zero_index); auto shard_id_from_subgroup = [&](HloInstruction* replica_or_global_id) { if (ag->replica_groups().empty()) { @@ -79,19 +83,19 @@ Status DecomposeAllGather(HloAllGatherInstruction* ag, int64 partition_count, } // Create a table of shard IDs for each replica_or_global_id, then slice it // using replica_or_global_id. - std::vector<int32> shard_ids(ag->replica_groups().size() * - ag->replica_groups()[0].replica_ids_size()); + std::vector<uint32> shard_ids(ag->replica_groups().size() * + ag->replica_groups()[0].replica_ids_size()); for (const auto& group : ag->replica_groups()) { for (int64 i = 0; i < group.replica_ids_size(); ++i) { shard_ids[group.replica_ids(i)] = i; } } auto id_table = comp->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1<int32>(shard_ids))); + LiteralUtil::CreateR1<uint32>(shard_ids))); auto shard_id = comp->AddInstruction(HloInstruction::CreateDynamicSlice( - ShapeUtil::MakeShape(S32, {1}), id_table, {replica_or_global_id}, {1})); + ShapeUtil::MakeShape(U32, {1}), id_table, {replica_or_global_id}, {1})); shard_id = comp->AddInstruction( - HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), shard_id)); + HloInstruction::CreateReshape(ShapeUtil::MakeShape(U32, {}), shard_id)); return shard_id; }; HloInstruction* shard_id; @@ -100,7 +104,7 @@ Status DecomposeAllGather(HloAllGatherInstruction* ag, int64 partition_count, auto pid = comp->AddInstruction(HloInstruction::CreatePartitionId()); auto rid = comp->AddInstruction(HloInstruction::CreateReplicaId()); auto pcount = comp->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0<int32>(partition_count))); + LiteralUtil::CreateR0<uint32>(partition_count))); auto global_id = comp->AddInstruction(HloInstruction::CreateBinary( pid->shape(), HloOpcode::kAdd, pid, comp->AddInstruction(HloInstruction::CreateBinary( @@ -119,8 +123,7 @@ Status DecomposeAllGather(HloAllGatherInstruction* ag, int64 partition_count, comp->AddInstruction(HloInstruction::CreateBinary( shard_id->shape(), HloOpcode::kMultiply, shard_id, comp->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0<int32>(ag->operand(0)->shape().dimensions( - ag->all_gather_dimension())))))); + LiteralUtil::CreateR0<uint32>(shard_size))))); auto dus = comp->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( zero->shape(), zero, ag->mutable_operand(0), start_indices)); auto ar = comp->AddInstruction(HloInstruction::CreateAllReduce( @@ -143,7 +146,7 @@ StatusOr<bool> AllGatherDecomposer::Run(HloModule* module) { } auto ag = Cast<HloAllGatherInstruction>(hlo); if (should_decompose_(*ag)) { - TF_RETURN_IF_ERROR(DecomposeAllGather(ag, partition_count_, comp)); + TF_RETURN_IF_ERROR(DecomposeAllGather(ag, comp)); changed = true; } } diff --git a/tensorflow/compiler/xla/service/all_gather_decomposer.h b/tensorflow/compiler/xla/service/all_gather_decomposer.h index d1983e37383..6b20765c709 100644 --- a/tensorflow/compiler/xla/service/all_gather_decomposer.h +++ b/tensorflow/compiler/xla/service/all_gather_decomposer.h @@ -26,15 +26,12 @@ namespace xla { // dynamic-update-slices and all-reduces. class AllGatherDecomposer : public HloModulePass { public: - AllGatherDecomposer( - std::function<bool(const HloAllGatherInstruction&)> should_decompose, - int64 partition_count) - : should_decompose_(std::move(should_decompose)), - partition_count_(partition_count) {} - explicit AllGatherDecomposer(int64 partition_count) + explicit AllGatherDecomposer( + std::function<bool(const HloAllGatherInstruction&)> should_decompose) + : should_decompose_(std::move(should_decompose)) {} + AllGatherDecomposer() : should_decompose_( - [](const HloAllGatherInstruction& ag) { return true; }), - partition_count_(partition_count) {} + [](const HloAllGatherInstruction& ag) { return true; }) {} absl::string_view name() const override { return "all_gather_decomposer"; } // Run AllGatherDecomposer pass on computations in 'module'. diff --git a/tensorflow/compiler/xla/service/all_gather_decomposer_test.cc b/tensorflow/compiler/xla/service/all_gather_decomposer_test.cc index ebcd66ffa07..3df5e51a7c2 100644 --- a/tensorflow/compiler/xla/service/all_gather_decomposer_test.cc +++ b/tensorflow/compiler/xla/service/all_gather_decomposer_test.cc @@ -48,7 +48,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, ParseAndReturnUnverifiedModule((module_str))); - AllGatherDecomposer decomposer(/*partition_count=*/4); + AllGatherDecomposer decomposer; TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); EXPECT_TRUE(changed); EXPECT_THAT( @@ -71,7 +71,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, ParseAndReturnUnverifiedModule((module_str))); - AllGatherDecomposer decomposer(/*partition_count=*/4); + AllGatherDecomposer decomposer; TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); EXPECT_TRUE(changed); EXPECT_THAT( @@ -94,7 +94,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, ParseAndReturnUnverifiedModule((module_str))); - AllGatherDecomposer decomposer(/*partition_count=*/4); + AllGatherDecomposer decomposer; TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); EXPECT_TRUE(changed); EXPECT_THAT( @@ -117,11 +117,11 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, ParseAndReturnUnverifiedModule((module_str))); - AllGatherDecomposer decomposer(/*partition_count=*/4); + AllGatherDecomposer decomposer; TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); EXPECT_TRUE(changed); auto id = - AllOf(op::Shape("s32[]"), + AllOf(op::Shape("u32[]"), op::Reshape(op::DynamicSlice(op::Constant(), op::ReplicaId()))); EXPECT_THAT(module->entry_computation()->root_instruction(), op::AllReduce(op::DynamicUpdateSlice( @@ -143,13 +143,12 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, ParseAndReturnUnverifiedModule((module_str))); - AllGatherDecomposer decomposer(/*partition_count=*/4); + AllGatherDecomposer decomposer; TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); EXPECT_TRUE(changed); - LOG(ERROR) << module->ToString(); auto global_id = op::Add(op::PartitionId(), op::Multiply(op::ReplicaId(), op::Constant())); - auto id = AllOf(op::Shape("s32[]"), + auto id = AllOf(op::Shape("u32[]"), op::Reshape(op::DynamicSlice(op::Constant(), global_id))); EXPECT_THAT(module->entry_computation()->root_instruction(), op::AllReduce(op::DynamicUpdateSlice( diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 2f432cd9356..3460e65b0a2 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -118,6 +118,9 @@ cc_library( ":target_machine_features", "@com_google_absl//absl/base", "@com_google_absl//absl/types:span", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:ExecutionEngineUtils", + "@llvm-project//mlir:LLVMDialect", "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:dump", @@ -366,6 +369,7 @@ cc_library( "@llvm-project//llvm:core", "@llvm-project//llvm:support", "@llvm-project//llvm:target", + "@llvm-project//mlir:IR", ], ) @@ -456,6 +460,7 @@ cc_library( ":cpu_options", ":cpu_runtime", ":ir_emission_utils", + ":mlir_emitter", ":target_machine_features", ":tiled_dot_emitter", ":vector_support_library", @@ -474,6 +479,10 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@llvm-project//llvm:core", + "@llvm-project//mlir:EDSC", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:StandardOps", ], ) @@ -1070,3 +1079,24 @@ tf_cc_test( "@llvm-project//llvm:target", ], ) + +cc_library( + name = "mlir_emitter", + srcs = ["mlir_emitter.cc"], + hdrs = ["mlir_emitter.h"], + deps = [ + "//tensorflow/compiler/mlir/xla:hlo_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "@llvm-project//llvm:core", + "@llvm-project//llvm:ipo", + "@llvm-project//llvm:linker", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMTransforms", + "@llvm-project//mlir:LinalgToLLVM", + "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TargetLLVMIR", + "@llvm-project//mlir:VectorToLLVM", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index fe769bbdd2a..b2416ac2799 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -42,6 +42,8 @@ limitations under the License. #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project #include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -158,6 +160,8 @@ CpuCompiler::CpuCompiler() { // Initialize LLVM's MC layer for the native target. llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); + + mlir::registerAllDialects(); } namespace { @@ -606,9 +610,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( user_post_optimization_hook_); // Compile must be thread-safe so create a new LLVM context for the module. - auto llvm_context = absl::make_unique<llvm::LLVMContext>(); - auto llvm_module = - absl::make_unique<llvm::Module>("__compute_module", *llvm_context); + mlir::MLIRContext mlir_context; + auto llvm_module = absl::make_unique<llvm::Module>( + "__compute_module", + mlir_context.getRegisteredDialect<mlir::LLVM::LLVMDialect>() + ->getLLVMContext()); auto jit = absl::make_unique<SimpleOrcJIT>( CompilerTargetOptions(module->config()), @@ -662,7 +668,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( // before a caller computation. LLVMTargetMachineFeatures target_machine_features(jit->target_machine()); - IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), + IrEmitter ir_emitter(&mlir_context, *module, *assignment, llvm_module.get(), std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), &target_machine_features, @@ -816,8 +822,11 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group, opt_level)); // Compile must be thread-safe so create a new LLVM context for the module. - llvm::LLVMContext llvm_context; - llvm::Module llvm_module("__compute_module", llvm_context); + mlir::MLIRContext mlir_context; + llvm::Module llvm_module( + "__compute_module", + mlir_context.getRegisteredDialect<mlir::LLVM::LLVMDialect>() + ->getLLVMContext()); llvm_module.setDataLayout(target_machine->createDataLayout()); llvm_module.setTargetTriple(triple.getTriple()); if (pic_level != llvm::PICLevel::NotPIC) { @@ -866,7 +875,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group, } LLVMTargetMachineFeatures target_machine_features(target_machine.get()); - IrEmitter ir_emitter(*module, *assignment, &llvm_module, + IrEmitter ir_emitter(&mlir_context, *module, *assignment, &llvm_module, std::move(instruction_to_profile_idx), std::move(computation_to_profile_idx), &target_machine_features, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index ff654c83d61..c0222010fd9 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -25,6 +25,7 @@ const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; const char* const kXlaForceEnableExperimentalLlvmIrGemm = "xla_force_enable_experimental_llvm_ir_gemm"; +const char* const kXlaUseLinalgForDot = "xla_use_linalg_for_dot"; const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -63,6 +64,12 @@ bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0; } +bool UseLinalgForDot(const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + return extra_options_map.count(kXlaUseLinalgForDot) > 0; +} + static absl::string_view RemoveSuffix(absl::string_view str, absl::string_view suffix) { CHECK_GE(str.size(), suffix.size()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 99e6702d14a..5d25aef6912 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -27,6 +27,7 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config); +bool UseLinalgForDot(const HloModuleConfig& config); absl::optional<int64> LlvmIrGemvTilingFactor(const HloModuleConfig& config); absl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize( const HloModuleConfig& config); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index fae9670051a..e21ed7ad60e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -154,7 +154,8 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits<int32>::max()) { - return InvalidArgument("Infeed shape is too large: needs %d bytes", size); + return InvalidArgument("CPU infeed of %d bytes exceeds maximum of %d bytes", + size, std::numeric_limits<int32>::max()); } if (size <= 0) { diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 7dba826b65c..e1ad14600d7 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -23,8 +23,17 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" // from @llvm-project +#include "mlir/EDSC/Builders.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" @@ -89,6 +98,9 @@ enum class DotImplementationStrategy { // and the output have to be row major. kTiledLlvmIrGemm, + // The dot operation is lowered into linalg.matmul op and lowered to LLVM IR. + kLinalgMatmul, + // The dot operation is lowered into a call into an Eigen routine. No fusions // are supported today. The two inputs and the output have to be row major. // However, we do allow transposing either the LHS or the RHS as part of the @@ -112,7 +124,7 @@ class DotOpEmitter { const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, + llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); @@ -163,6 +175,9 @@ class DotOpEmitter { // Lowers the dot operation as a tiled Matrix*Matrix loop. void EmitTiledLlvmIrGemm(); + // Lowers the dot operation through MLIR's linalg.matmul. + Status EmitLinalgMatmul(); + // Lowers the dot operation as a naive nested loop that computes the result // one element at a time. void EmitNaiveLlvmIrGemm(); @@ -194,20 +209,19 @@ class DotOpEmitter { const llvm_ir::IrArray* addend_array_; llvm::Value* executable_run_options_value_; llvm::IRBuilder<>* b_; + mlir::MLIRContext* mlir_context_; const HloModuleConfig& hlo_module_config_; const TargetMachineFeatures& target_machine_features_; }; } // namespace -DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name, - const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, - const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features) +DotOpEmitter::DotOpEmitter( + DotInfo dot_info, string dot_hlo_name, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, + mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) : dot_info_(std::move(dot_info)), dot_hlo_name_(std::move(dot_hlo_name)), target_array_(target_array), @@ -216,9 +230,36 @@ DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name, addend_array_(addend_array), executable_run_options_value_(executable_run_options_value), b_(b), + mlir_context_(mlir_context), hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} +Status DotOpEmitter::EmitLinalgMatmul() { + Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape}; + llvm::Value* operand_ptrs[] = {lhs_array_.GetBasePointer(), + rhs_array_.GetBasePointer()}; + llvm::Value* target_ptr = target_array_.GetBasePointer(); + + // Zero out the output buffer. + int64 size_bytes = ShapeUtil::ByteSizeOf(dot_info_.result_shape); + b_->CreateMemSet(target_ptr, b_->getInt8(0), /*Size=*/size_bytes, + /*Align=*/llvm::MaybeAlign(1)); + + std::string name = + absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_", + dot_info_.lhs_shape.ToString(true), "_", + dot_info_.rhs_shape.ToString(true)); + return EmitMlirFuncAndCall( + mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr, + operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) { + mlir::edsc::ScopedContext scope(*builder, function.getLoc()); + mlir::Value a = function.getArgument(0), b = function.getArgument(1), + c = function.getArgument(2); + mlir::edsc::intrinsics::linalg_matmul(b, c, a); + mlir::edsc::intrinsics::std_ret(); + }); +} + void DotOpEmitter::EmitTiledLlvmIrGemm() { PrimitiveType primitive_type = dot_info_.result_shape.element_type(); MatMultDims mat_mult_dims = GetMatMultDims(); @@ -418,6 +459,9 @@ Status DotOpEmitter::Emit() { EmitTiledLlvmIrGemm(); return Status::OK(); + case DotImplementationStrategy::kLinalgMatmul: + return EmitLinalgMatmul(); + case DotImplementationStrategy::kEigen: return EmitCallToRuntime(); } @@ -886,9 +930,12 @@ DotImplementationStrategy GetDotImplementationStrategy( } if (IsAlignedGemm(dot_info, target_machine_features)) { - return CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features) - ? DotImplementationStrategy::kTiledLlvmIrGemm - : DotImplementationStrategy::kEigen; + if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) { + return options::UseLinalgForDot(config) + ? DotImplementationStrategy::kLinalgMatmul + : DotImplementationStrategy::kTiledLlvmIrGemm; + } + return DotImplementationStrategy::kEigen; } return DotImplementationStrategy::kNaiveLlvmIr; @@ -899,15 +946,15 @@ Status EmitNonBatchDotOperation( const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, + mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(S32 == type || F16 == type || F32 == type || F64 == type || C64 == type || C128 == type); DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name), target_array, lhs_array, rhs_array, addend_array, - executable_run_options_value, b, hlo_module_config, - target_machine_features); + executable_run_options_value, b, mlir_context, + hlo_module_config, target_machine_features); return dot_emitter.Emit(); } @@ -981,7 +1028,7 @@ Status EmitBatchDotOperation( const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b, - const HloModuleConfig& hlo_module_config, + mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers())); @@ -1039,7 +1086,7 @@ Status EmitBatchDotOperation( // Emit the inner non-batch dot operation. return EmitNonBatchDotOperation( dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr, - executable_run_options_value, b, hlo_module_config, + executable_run_options_value, b, mlir_context, hlo_module_config, target_machine_features); }); } @@ -1089,7 +1136,7 @@ Status EmitDotOperation(const HloInstruction& dot, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, + llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { // This routine assumes that the dot operation is not in a parallelized @@ -1099,13 +1146,13 @@ Status EmitDotOperation(const HloInstruction& dot, if (IsBatchDot(dot)) { TF_RET_CHECK(addend_array == nullptr); return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array, - executable_run_options_value, b, + executable_run_options_value, b, mlir_context, hlo_module_config, target_machine_features); } return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array, lhs_array, rhs_array, addend_array, - executable_run_options_value, b, + executable_run_options_value, b, mlir_context, hlo_module_config, target_machine_features); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 105bd3005c8..d9cf8a2036b 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -18,6 +18,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/IR/IRBuilder.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -63,7 +64,7 @@ Status EmitDotOperation(const HloInstruction& dot, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* b, + llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 5a4c6250293..1e204afb001 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -41,6 +41,7 @@ limitations under the License. #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsX86.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -89,8 +90,8 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { IrEmitter::IrEmitter( - const HloModule& hlo_module, const BufferAssignment& assignment, - llvm::Module* llvm_module, + mlir::MLIRContext* mlir_context, const HloModule& hlo_module, + const BufferAssignment& assignment, llvm::Module* llvm_module, std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx, std::unordered_map<const HloComputation*, int64> computation_to_profile_idx, const TargetMachineFeatures* target_machine_features, @@ -99,6 +100,7 @@ IrEmitter::IrEmitter( module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), b_(llvm_module->getContext()), + mlir_context_(mlir_context), instruction_to_profile_idx_(std::move(instruction_to_profile_idx)), computation_to_profile_idx_(std::move(computation_to_profile_idx)), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), @@ -898,7 +900,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Dot operation is complicated so we delegate to a helper class. return EmitDotOperation(*dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr, - GetExecutableRunOptionsArgument(), &b_, + GetExecutableRunOptionsArgument(), &b_, mlir_context_, hlo_module_config_, target_machine_features_); } @@ -2305,10 +2307,10 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { llvm_ir::IrArray addend_array( GetIrArrayFor(fusion->operand(addend_param_number))); - TF_RETURN_IF_ERROR( - EmitDotOperation(*dot, target_array, lhs_array, rhs_array, - &addend_array, GetExecutableRunOptionsArgument(), &b_, - hlo_module_config_, target_machine_features_)); + TF_RETURN_IF_ERROR(EmitDotOperation( + *dot, target_array, lhs_array, rhs_array, &addend_array, + GetExecutableRunOptionsArgument(), &b_, mlir_context_, + hlo_module_config_, target_machine_features_)); return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); @@ -2343,56 +2345,68 @@ Status IrEmitter::HandleCall(HloInstruction* call) { } Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) { - // TODO(jackcao): Generalize this to generic llvm emitter. - TF_RET_CHECK(hlo->shape().rank() == 1); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); + std::vector<llvm::Value*> dynamic_dims; + int32 raw_data_size = + ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(hlo->shape())); + llvm::Value* dest_buffer = GetEmittedValueFor(hlo); + llvm::Value* raw_buffer = + b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo()); for (int64 i = 1; i < hlo->operand_count(); ++i) { const int64 dim_index = i - 1; llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(i)); - llvm::LoadInst* dim_size = b_.CreateLoad(source_buffer, "dim_size"); - llvm::Value* dest_buffer = GetEmittedValueFor(hlo); - llvm::Value* raw_buffer = - b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo()); + llvm::LoadInst* dyn_dim_size = b_.CreateLoad(source_buffer, "dyn_dim_size"); - int32 raw_data_size = - ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(hlo->shape())); llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32( b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32)); - b_.CreateStore(dim_size, + b_.CreateStore(dyn_dim_size, b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo())); + dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(), + /*isSigned=*/true, + "i64_dyn_dim_size")); } - return EmitTargetElementLoop(hlo, - [=](const llvm_ir::IrArray::Index& dest_index) { - // TODO(jackcao): Properly linearize dest_index - // and delinearize to source index. - return GetIrArrayFor(hlo->operand(0)) - .EmitReadArrayElement(dest_index, &b_); - }); + llvm_ir::IrArray data_array = GetIrArrayFor(hlo); + // Pseudo code for sliceToDynamic: + // + // for (index i in dynamic_dim) + // dest_index = delinearize(linearize(i, dynamic_dim), static_dim) + // dest[dest_index] = source[i] + auto loop_body_emitter = + [&](const llvm_ir::IrArray::Index& array_index) -> Status { + llvm::Value* source_element = + GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(array_index, &b_); + llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_); + // Delinearize the index based on the static shape. + llvm_ir::IrArray::Index dest_index(linear_index, data_array.GetShape(), + &b_); + data_array.EmitWriteArrayElement(dest_index, source_element, &b_); + return Status::OK(); + }; + return llvm_ir::LoopEmitter(loop_body_emitter, data_array.GetShape(), + dynamic_dims, &b_) + .EmitLoop(IrName(hlo)); } Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) { - // TODO(jackcao): Generalize this to generic llvm emitter. - TF_RET_CHECK(hlo->operand(0)->shape().rank() == 1); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice, assignment_.GetUniqueSlice(hlo, {0})); + std::vector<llvm::Value*> dynamic_dims; + std::vector<llvm::Value*> tuple_operand_ptrs; const Shape& data_shape = ShapeUtil::GetSubshape(hlo->shape(), {0}); + const Shape& input_shape = hlo->operand(0)->shape(); llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape); llvm_ir::IrArray data_array(data_address, data_shape); - TF_RETURN_IF_ERROR(llvm_ir::LoopEmitter( - [=](const llvm_ir::IrArray::Index& dest_index) { - // TODO(jackcao): Properly linearize dest_index and - // delinearize to source index. - return GetIrArrayFor(hlo->operand(0)) - .EmitReadArrayElement(dest_index, &b_); - }, - llvm_ir::IrArray(data_address, data_shape), &b_) - .EmitLoop(IrName(hlo))); - std::vector<llvm::Value*> tuple_operand_ptrs; - tuple_operand_ptrs.push_back(data_array.GetBasePointer()); + llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(0)); + llvm::Value* raw_buffer = + b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo()); + int64 raw_data_size = + ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(input_shape)); + // Put a placeholder for the data array's pointer + tuple_operand_ptrs.push_back(data_array.GetBasePointer()); // PadToStatic has a dynamic tensor as input and variadic size of outputs: // (static_tensor, dynamic_dim_0, dynamic_dim_1, ... ) // Dynamic dimension sizes starts from output index 1. @@ -2405,20 +2419,38 @@ Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) { llvm::Value* dest_dim_size_address = EmitBufferPointer(dim_size_slice, data_shape); const int64 dim_index = i - 1; - llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(0)); - llvm::Value* raw_buffer = - b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo()); - int32 raw_data_size = ShapeUtil::ByteSizeOf( - ShapeUtil::MakeStaticShape(hlo->operand(0)->shape())); llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32( b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32)); - llvm::Value* dim_size = b_.CreateLoad( - b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo())); - b_.CreateStore(dim_size, b_.CreateBitCast(dest_dim_size_address, - b_.getInt32Ty()->getPointerTo())); + llvm::Value* dyn_dim_size = b_.CreateLoad( + b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()), + "dyn_dim_size"); + b_.CreateStore(dyn_dim_size, + b_.CreateBitCast(dest_dim_size_address, + b_.getInt32Ty()->getPointerTo())); + dynamic_dims.push_back(b_.CreateIntCast(dyn_dim_size, b_.getInt64Ty(), + /*isSigned=*/true, + "i64_dyn_dim_size")); tuple_operand_ptrs.push_back(dest_dim_size_address); } + // Pseudo code for padToStatic: + // + // for (index i in dynamic_dim) + // source_index = delinearize(inearize(i, dynamic_dim), static_dim) + // dest[i] = source[source_index] + auto loop_body_emitter = + [&](const llvm_ir::IrArray::Index& array_index) -> Status { + llvm::Value* linear_index = array_index.Linearize(dynamic_dims, &b_); + llvm_ir::IrArray::Index source_index(linear_index, input_shape, &b_); + llvm::Value* source_element = + GetIrArrayFor(hlo->operand(0)).EmitReadArrayElement(source_index, &b_); + data_array.EmitWriteArrayElement(array_index, source_element, &b_); + return Status::OK(); + }; + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter(loop_body_emitter, input_shape, dynamic_dims, &b_) + .EmitLoop(IrName(hlo))); + // Emit static tensor and dynamic sizes as one tuple. llvm_ir::EmitTuple(GetIrArrayFor(hlo), tuple_operand_ptrs, &b_); return Status::OK(); @@ -2875,9 +2907,8 @@ Status IrEmitter::HandleRngGetAndUpdateState(HloInstruction* rng_state) { old_state->getType()->getScalarType(), address->getType()->getPointerAddressSpace())); llvm::StoreInst* store = Store(old_state, address); - store->setAlignment( - llvm::MaybeAlign(IrEmitter::MinimumAlignmentForPrimitiveType( - rng_state->shape().element_type()))); + store->setAlignment(llvm::Align(IrEmitter::MinimumAlignmentForPrimitiveType( + rng_state->shape().element_type()))); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 9b0d11e9f3f..661785153d0 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_ #include <stddef.h> + #include <map> #include <memory> #include <string> @@ -32,6 +33,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "llvm/Target/TargetMachine.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/ir_function.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" @@ -69,14 +71,16 @@ class IrEmitter : public DfsHloVisitorWithDefault, // hlo_module: the HLO module we are emitting IR for. // assignment: a BufferAssignment from which we know which buffers are used by // the HLO nodes. - // llvm_module: the LLVM module to emit IR into. + // mlir_context: the MLIR context used for IR emission. + // llvm_module: the LLVM module to emit IR into. It's built using the LLVM + // context inside of mlir_context. // instruction_to_profile_idx: the mapping from HLO instructions to their // index in the profiling array. // computation_to_profile_idx: the mapping from HLO computations to their // index in the profiling array. // emit_code_for_msan: whether emitted code should be compatible with msan. - IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, - llvm::Module* llvm_module, + IrEmitter(mlir::MLIRContext* mlir_context, const HloModule& hlo_module, + const BufferAssignment& assignment, llvm::Module* llvm_module, std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx, std::unordered_map<const HloComputation*, int64> @@ -442,6 +446,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // module's function list). std::unique_ptr<IrFunction> compute_function_; llvm::IRBuilder<> b_; + mlir::MLIRContext* mlir_context_; // The buffer allocation slice for the root of the computation being compiled. // Only relevant for thread local computations. diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc new file mode 100644 index 00000000000..e7d52c288d5 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc @@ -0,0 +1,132 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h" + +#include "llvm/Linker/Linker.h" +#include "llvm/Transforms/IPO/Internalize.h" +#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project +#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Target/LLVMIR.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/hlo_utils.h" + +namespace xla { +namespace cpu { +namespace { + +// Lower an MLIR module to an LLVM module. +std::unique_ptr<llvm::Module> MakeLLVMModule(mlir::OwningModuleRef module) { + mlir::PassManager manager(module->getContext()); + manager.addPass(mlir::createConvertLinalgToLoopsPass()); + manager.addPass(mlir::createConvertLinalgToLLVMPass()); + manager.addPass(mlir::createConvertVectorToLLVMPass()); + manager.addPass(mlir::createLowerToLLVMPass()); + CHECK(succeeded(manager.run(*module))); + return mlir::translateModuleToLLVMIR(*module); +} + +// Get arguments to pass a memref to an mlir function. +void BuildViewForBuffer(llvm::SmallVectorImpl<llvm::Value *> *args, + llvm::IRBuilder<> *b, const Shape &opShape, + llvm::Value *op_val) { + llvm::Type *ty = op_val->getType(); + while (auto aty = llvm::dyn_cast<llvm::ArrayType>( + llvm::cast<llvm::PointerType>(ty)->getElementType())) { + ty = aty->getElementType()->getPointerTo(); + } + op_val = b->CreateBitCast(op_val, ty); + + args->push_back(op_val); // Allocated pointer. + args->push_back(op_val); // Aligned pointer. + args->push_back(b->getInt64(0)); // Offset. + + // Sizes. + for (int64 dim : opShape.dimensions()) { + args->push_back(b->getInt64(dim)); + } + + int64_t accumulated_stride = 1; + llvm::SmallVector<int64_t, 4> strides(opShape.rank(), 1); + for (int64 dim : LayoutUtil::MinorToMajor(opShape)) { + strides[dim] = accumulated_stride; + accumulated_stride *= opShape.dimensions(dim); + } + + // Strides. + for (int64 stride : strides) { + args->push_back(b->getInt64(stride)); + } +} +} // namespace + +Status EmitMlirFuncAndCall( + mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape, + llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr, + llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name, + llvm::function_ref<void(mlir::OpBuilder *, mlir::FuncOp)> emitter) { + llvm::Module *llvm_module = b->GetInsertBlock()->getParent()->getParent(); + mlir::Builder mlir_builder(context); + + // Get memref types for the inputs and output. + TF_ASSIGN_OR_RETURN(mlir::Type ret_memref, ConvertTensorShapeToMemRefType( + result_shape, mlir_builder)); + std::vector<mlir::Type> operand_types = {ret_memref}; + for (int i = 0; i != operand_shapes.size(); ++i) { + TF_ASSIGN_OR_RETURN( + mlir::Type op_memref, + ConvertTensorShapeToMemRefType(operand_shapes[i], mlir_builder)); + operand_types.push_back(op_memref); + } + + // Create the function an call the emission callback. + mlir::Location loc = mlir::UnknownLoc::get(context); + auto function = mlir::FuncOp::create( + loc, func_name, mlir::FunctionType::get(operand_types, {}, context)); + function.addEntryBlock(); + mlir::OwningModuleRef mlir_module = mlir::ModuleOp::create(loc); + mlir_module->push_back(function); + mlir::OpBuilder op_builder(&function.getBody()); + emitter(&op_builder, function); + + // Now link it all into the main LLVM module. + auto mlir_llvm_module = MakeLLVMModule(std::move(mlir_module)); + mlir_llvm_module->setDataLayout(llvm_module->getDataLayout()); + llvm::Linker::linkModules( + *llvm_module, std::move(mlir_llvm_module), llvm::Linker::None, + [](llvm::Module &M, const llvm::StringSet<> &GVS) { + llvm::internalizeModule(M, [&GVS](const llvm::GlobalValue &GV) { + return !GV.hasName() || (GVS.count(GV.getName()) == 0); + }); + }); + + // And leave behind a call to the function generated by MLIR. + llvm::Function *func = llvm_module->getFunction(func_name); + llvm::SmallVector<llvm::Value *, 4> op_vals; + BuildViewForBuffer(&op_vals, b, result_shape, result_ptr); + for (int i = 0; i != operand_shapes.size(); ++i) { + BuildViewForBuffer(&op_vals, b, operand_shapes[i], operand_ptrs[i]); + } + b->CreateCall(func, op_vals); + + return Status::OK(); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.h b/tensorflow/compiler/xla/service/cpu/mlir_emitter.h new file mode 100644 index 00000000000..bc0741e851a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.h @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { +namespace cpu { + +// Create a new MLIR function with the name `func_name`, populate it with +// `emitter` and create a call, passing it the buffers defined by +// resultShape/resultPtr and operandShapes/operandPtrs. The function is added to +// the LLVM module at `b`s insertion point. +Status EmitMlirFuncAndCall( + mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape, + llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr, + llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name, + llvm::function_ref<void(mlir::OpBuilder *, mlir::FuncOp)> emitter); + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 14afe770ede..225102e6ae6 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -142,24 +142,29 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( // in-place will only touch the updated elements). // TODO(b/27458679) Parallelize instructions which are skipped here. auto opcode = instruction->opcode(); - if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant || - opcode == HloOpcode::kCall || opcode == HloOpcode::kCustomCall || - opcode == HloOpcode::kDot || opcode == HloOpcode::kSelectAndScatter || - opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast || - opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || - opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || - opcode == HloOpcode::kSort || - (opcode == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction, - target_machine_features_)) || - (opcode == HloOpcode::kFusion && !instruction->IsLoopFusion()) || - llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) || - instruction->shape().IsTuple()) { + if (llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) || + instruction->shape().IsTuple() || opcode == HloOpcode::kRng) { return 1; } - // Consult 'cost_model_' to compute target parallel task count. - return cost_model_->GetParallelTaskCount(instruction); + // Only allow known good instructions. + if (instruction->IsElementwise() || instruction->IsLoopFusion() || + opcode == HloOpcode::kBroadcast || opcode == HloOpcode::kConcatenate || + opcode == HloOpcode::kDynamicSlice || + opcode == HloOpcode::kDynamicUpdateSlice || + opcode == HloOpcode::kGather || opcode == HloOpcode::kIota || + opcode == HloOpcode::kPad || opcode == HloOpcode::kReduce || + opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kReshape || + opcode == HloOpcode::kReverse || opcode == HloOpcode::kSlice || + opcode == HloOpcode::kTranspose || + (opcode == HloOpcode::kConvolution && + !PotentiallyImplementedAsEigenConvolution(*instruction, + target_machine_features_))) { + // Consult 'cost_model_' to compute target parallel task count. + return cost_model_->GetParallelTaskCount(instruction); + } + + return 1; } StatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index e2c93568b74..e22210a61f2 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -170,5 +170,26 @@ TEST_F(ParallelTaskAssignmentTest, InPlaceDynamicUpdateSliceNotParallelized) { EXPECT_FALSE(changed); } +TEST_F(ParallelTaskAssignmentTest, AllReduceNotParallelized) { + constexpr char hlo_string[] = R"( + HloModule TestTaskParallel_allreduce + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY CRS { + input = f32[1234567] parameter(0) + ROOT crs = f32[1234567] all-reduce(input), replica_groups={}, to_apply=add + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get())); + EXPECT_FALSE(changed); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc index 7831c1b1b5b..0d4e7055ddb 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -60,6 +60,11 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( std::unique_ptr<std::string[]> reordered_values( new std::string[sort_dimension_elements]); for (int64 index = 0; index < num_iteration_elements; ++index) { + // If the sort should be stable, we have to reinitialize indices to iota to + // guarantee that we still keep the relative order in case of ties. + if (is_stable && index > 0) { + std::iota(indices.get(), indices.get() + sort_dimension_elements, 0); + } // 'index' can be split into two values which index into the 'c' dimension // and the 'a' dimension, respectively. 'index' % 'c' is the index into the // 'c' dimension, 'index' / 'c' is the index into the 'a' dimension. When diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index f52de3394fe..1ac8509cdb1 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -35,6 +35,19 @@ cc_library( ], ) +tf_cc_test( + name = "cpu_dyn_shape_test", + srcs = ["cpu_dyn_shape_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "cpu_fusion_test", srcs = ["cpu_fusion_test.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_dyn_shape_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_dyn_shape_test.cc new file mode 100644 index 00000000000..46249caa0c7 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_dyn_shape_test.cc @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <memory> + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" + +namespace xla { +namespace cpu { +namespace { + +using CpuDynamicShapeTest = CpuCodegenTest; + +TEST_F(CpuDynamicShapeTest, DynamicShapeR2) { + HloComputation::Builder builder(TestName()); + + xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4}); + dyn_input_shape.set_dynamic_dimension(0, true); + HloInstruction* param_x = builder.AddInstruction( + HloInstruction::CreateParameter(0, dyn_input_shape, "x")); + + builder.AddInstruction(HloInstruction::CreateUnary( + dyn_input_shape, HloOpcode::kNegate, param_x)); + auto hlo_module = CreateNewVerifiedModule(); + hlo_module->AddEntryComputation(builder.Build()); + + string filecheck_pattern = R"( +; CHECK: %[[dyn_dim_size:.*]] = load i32, i32* +; CHECK: %[[i64_dyn_dim_size:.*]] = sext i32 %[[dyn_dim_size:.*]] to i64 +; CHECK: icmp uge i64 %[[custom:.*]], %[[i64_dyn_dim_size:.*]] +; CHECK: %[[multiplier:.*]] = mul i64 1, %[[i64_dyn_dim_size:.*]] +; CHECK: mul nuw nsw i64 %[[custom:.*]], %[[multiplier:.*]] +)"; + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options, + filecheck_pattern, + /*match_optimized_ir=*/false); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index b6d6de28bc5..efeab3bd31a 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -70,6 +70,13 @@ class CpuUnaryIntrinsicTest return absl::StrCat(opcode, "_On_", triple, (features.empty() ? "" : "_With"), features); } + + private: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + HloTestBase::SetAotFastMathDebugOptions(&debug_options); + return debug_options; + } }; // Creates a module with a call to the unary op, and tests if the diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc index 8a72eb15487..757d878e224 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_vectorization_test.cc @@ -69,6 +69,13 @@ class CpuVectorizationTest return absl::StrCat(opcode, "_On_", triple, (features.empty() ? "" : "_With"), features); } + + private: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + HloTestBase::SetAotFastMathDebugOptions(&debug_options); + return debug_options; + } }; TEST_P(CpuVectorizationTest, DoIt) { diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index caea9d9095a..bdaac32a0e5 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -120,6 +120,8 @@ class DfsHloVisitorBase { virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; + virtual Status HandleCollectivePermuteStart(HloInstructionPtr hlo) = 0; + virtual Status HandleCollectivePermuteDone(HloInstructionPtr hlo) = 0; virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0; virtual Status HandlePartitionId(HloInstructionPtr hlo) = 0; virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 9cd220245ba..b1d674fe467 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -110,6 +110,12 @@ class DfsHloVisitorWithDefaultBase Status HandleCollectivePermute(HloInstructionPtr hlo) override { return DefaultAction(hlo); } + Status HandleCollectivePermuteStart(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } + Status HandleCollectivePermuteDone(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } Status HandleReplicaId(HloInstructionPtr hlo) override { return DefaultAction(hlo); } diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc new file mode 100644 index 00000000000..fcdf85d5ecb --- /dev/null +++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc @@ -0,0 +1,139 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h" + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { +namespace dot_as_convolution_util { + +/* static */ absl::optional<DotGeneralAsConvolutionDimsInfo> +ParseDotGeneralFromConvolution(const HloInstruction* conv) { + CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); + if (conv->feature_group_count() != 1 || conv->batch_group_count() != 1) { + return absl::nullopt; + } + const auto& conv_dims = conv->convolution_dimension_numbers(); + DotGeneralAsConvolutionDimsInfo dims; + dims.lhs_non_contracting_dims.push_back( + {conv_dims.input_batch_dimension(), -1, + conv_dims.output_batch_dimension(), -1}); + dims.rhs_non_contracting_dims.push_back( + {-1, conv_dims.kernel_output_feature_dimension(), + conv_dims.output_feature_dimension(), -1}); + dims.contracting_dims.push_back({conv_dims.input_feature_dimension(), + conv_dims.kernel_input_feature_dimension(), + -1, -1}); + + for (int64 i = 0; i < conv_dims.input_spatial_dimensions_size(); ++i) { + int64 lhs = conv_dims.input_spatial_dimensions(i); + int64 lhs_size = conv->operand(0)->shape().dimensions(lhs); + int64 rhs = conv_dims.kernel_spatial_dimensions(i); + int64 rhs_size = conv->operand(1)->shape().dimensions(rhs); + int64 output = conv_dims.output_spatial_dimensions(i); + const auto& wd = conv->window().dimensions(i); + if (lhs_size == wd.size() && + std::max<int64>(1, lhs_size - 1) == wd.stride() && + lhs_size == wd.base_dilation() && wd.window_dilation() == 1 && + wd.padding_high() == 0 && wd.padding_low() == 0 && + !wd.window_reversal()) { + // A batch dimension in DotGeneral is represented as a spatial dimension + // with window size B (batch dimension size), stride B - 1, and base + // dilation B. + dims.batch_dims.push_back({lhs, rhs, output, i}); + } else if (lhs_size == wd.size() && wd.base_dilation() == 1 && + wd.window_dilation() == 1 && wd.padding_high() == 0 && + wd.padding_low() == 0 && !wd.window_reversal()) { + // A contracting dimension be represented as a spatial dimension with + // window size C (contracting dimension size). Stride can be any size + // since there is only one window. + dims.contracting_dims.push_back({lhs, rhs, output, i}); + } else if (wd.stride() == 1 && wd.window_dilation() == 1 && + wd.base_dilation() == 1) { + if (rhs_size == 1 && wd.size() == 1 && wd.padding_high() == 0 && + wd.padding_low() == 0 && !wd.window_reversal()) { + // A LHS non-contracting dimension can be represented as a spatial + // dimension with window size 1. + dims.lhs_non_contracting_dims.push_back({lhs, rhs, output, i}); + } else if (lhs_size == 1 && wd.size() == rhs_size && + wd.padding_high() == rhs_size - 1 && + wd.padding_low() == rhs_size - 1 && wd.window_reversal()) { + // A RHS non-contracting dimension can be represented as a spatial + // dimension with window size N (non-contracting dimension size), low + // padding N - 1, high padding N - 1 and window reversal. + dims.rhs_non_contracting_dims.push_back({lhs, rhs, output, i}); + } else { + return absl::nullopt; + } + } else { + return absl::nullopt; + } + } + + return dims; +} + +StatusOr<std::unique_ptr<HloInstruction>> +CreateShardedConvForDotGeneralConvolution( + const HloInstruction& conv, + const DotGeneralAsConvolutionDimsInfo& dot_dnums, + HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo) { + CHECK_EQ(conv.opcode(), HloOpcode::kConvolution); + const auto& conv_dnums = conv.convolution_dimension_numbers(); + auto window = conv.window(); + for (const auto& dim : dot_dnums.batch_dims) { + auto wd = window.mutable_dimensions(dim.spatial_dim); + wd->set_size(sharded_lhs_hlo->shape().dimensions( + conv_dnums.input_spatial_dimensions(dim.spatial_dim))); + wd->set_stride(std::max<int64>(1, wd->size() - 1)); + wd->set_base_dilation(wd->size()); + } + for (const auto& dim : dot_dnums.contracting_dims) { + if (dim.spatial_dim < 0) { + continue; + } + auto wd = window.mutable_dimensions(dim.spatial_dim); + wd->set_size(sharded_lhs_hlo->shape().dimensions( + conv_dnums.input_spatial_dimensions(dim.spatial_dim))); + } + for (const auto& dim : dot_dnums.rhs_non_contracting_dims) { + if (dim.spatial_dim < 0) { + continue; + } + auto wd = window.mutable_dimensions(dim.spatial_dim); + wd->set_size(sharded_rhs_hlo->shape().dimensions( + conv_dnums.kernel_spatial_dimensions(dim.spatial_dim))); + wd->set_padding_high(wd->size() - 1); + wd->set_padding_low(wd->size() - 1); + } + TF_ASSIGN_OR_RETURN(Shape sharded_conv_shape, + ShapeInference::InferConvolveShape( + sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(), + /*feature_group_count=*/1, + /*batch_group_count=*/1, window, conv_dnums)); + *sharded_conv_shape.mutable_layout() = conv.shape().layout(); + return HloInstruction::CreateConvolve( + sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, + /*feature_group_count=*/1, + /*batch_group_count=*/1, window, conv_dnums, conv.precision_config()); +} + +} // namespace dot_as_convolution_util +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.h b/tensorflow/compiler/xla/service/dot_as_convolution_util.h new file mode 100644 index 00000000000..a3e829a3d31 --- /dev/null +++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.h @@ -0,0 +1,68 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ + +#include <memory> +#include <vector> + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { +namespace dot_as_convolution_util { + +// Describes the dimensions of a convolution that can be interpreted as a dot. +struct DotGeneralAsConvolutionDimsInfo { + // The dimension numbers for the operands and output corresponding to a + // logical dimension (e.g., batch, contracting, non-contracting). If an + // operand or the output doesn't have the logical dimension, it is set to + // -1. + struct DimNums { + int64 lhs; + int64 rhs; + int64 output; + // The corresponding spatial dimension in the convolution's config. Set to + // -1 if it's not mapped to a spatial dimension. + int64 spatial_dim; + }; + std::vector<DimNums> batch_dims; + std::vector<DimNums> contracting_dims; + std::vector<DimNums> lhs_non_contracting_dims; + std::vector<DimNums> rhs_non_contracting_dims; +}; + +// Parses a convolution and returns a DotGeneralAsConvolutionDimsInfo if it can +// be interpreted as a dot, or absl::nullopt otherwise. +absl::optional<DotGeneralAsConvolutionDimsInfo> ParseDotGeneralFromConvolution( + const HloInstruction* conv); + +// Creates sharded convolution instruction that can be interpreted as a dot. +// This is a utility for per-op partitioners. +// - 'conv' is the original convolution instruction. +// - 'dot_dnums' is the result of ParseDotGeneralFromConvolution() for 'conv'. +// - 'sharded_lhs_hlo' and 'sharded_rhs_hlo' are sharded inputs for the result +// convolution instruction. +StatusOr<std::unique_ptr<HloInstruction>> +CreateShardedConvForDotGeneralConvolution( + const HloInstruction& conv, + const DotGeneralAsConvolutionDimsInfo& dot_dnums, + HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo); + +} // namespace dot_as_convolution_util +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index 353a7f5cebc..40354dec3c6 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -31,7 +31,7 @@ namespace { // Convert a dot into a canonical form where non-contracting and contracting // dimensions are reshaped together and batch dimensions are the most major -// dimensions. The requires transposing and reshapes the lhs and rhs and +// dimensions. This requires transposing and reshapes of the lhs and rhs and // reshaping the output batch to the original shape. Status CanonicalizeDot(HloInstruction* original_dot) { auto computation = original_dot->parent(); @@ -80,7 +80,9 @@ Status CanonicalizeDot(HloInstruction* original_dot) { lhs_shape), original_dot->mutable_operand(0), lhs_transpose)); std::vector<int64> lhs_reshape_dims = batch_dim_sizes; - lhs_reshape_dims.push_back(lhs_non_contracting_size); + if (lhs_non_contracting_size > 1) { + lhs_reshape_dims.push_back(lhs_non_contracting_size); + } lhs_reshape_dims.push_back(lhs_contracting_size); // Reshape the contracting and non-contracting dimensions together. HloInstruction* reshaped_lhs = @@ -126,7 +128,9 @@ Status CanonicalizeDot(HloInstruction* original_dot) { std::vector<int64> rhs_reshape_dims = batch_dim_sizes; rhs_reshape_dims.push_back(rhs_contracting_size); - rhs_reshape_dims.push_back(rhs_non_contracting_size); + if (rhs_non_contracting_size > 1) { + rhs_reshape_dims.push_back(rhs_non_contracting_size); + } // Reshape the contracting and non-contracting dimensions together. HloInstruction* reshaped_rhs = computation->AddInstruction(HloInstruction::CreateReshape( @@ -134,15 +138,20 @@ Status CanonicalizeDot(HloInstruction* original_dot) { transposed_rhs)); std::vector<int64> dot_dims = batch_dim_sizes; - dot_dims.push_back(lhs_non_contracting_size); - dot_dims.push_back(rhs_non_contracting_size); + if (lhs_non_contracting_size > 1) { + dot_dims.push_back(lhs_non_contracting_size); + } + if (rhs_non_contracting_size > 1) { + dot_dims.push_back(rhs_non_contracting_size); + } DotDimensionNumbers dot_dnums; for (int64 i = 0; i < num_batch_dims; ++i) { dot_dnums.add_lhs_batch_dimensions(i); dot_dnums.add_rhs_batch_dimensions(i); } - dot_dnums.add_lhs_contracting_dimensions(num_batch_dims + 1); + dot_dnums.add_lhs_contracting_dimensions( + num_batch_dims + (lhs_non_contracting_size > 1 ? 1 : 0)); dot_dnums.add_rhs_contracting_dimensions(num_batch_dims); HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot( @@ -174,9 +183,9 @@ StatusOr<bool> DotDecomposer::Run(HloModule* module) { } // A dot is not canonical if it has more than one non-contracting // dimension. - if (dnums.lhs_batch_dimensions_size() + 2 != + if (dnums.lhs_batch_dimensions_size() + 2 < instruction->operand(0)->shape().rank() || - dnums.rhs_batch_dimensions_size() + 2 != + dnums.rhs_batch_dimensions_size() + 2 < instruction->operand(1)->shape().rank()) { non_canonical_dots.push_back(instruction); continue; diff --git a/tensorflow/compiler/xla/service/dot_decomposer_test.cc b/tensorflow/compiler/xla/service/dot_decomposer_test.cc index 67fff50eaf6..c4152393933 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer_test.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer_test.cc @@ -50,5 +50,75 @@ TEST_F(DotDecomposerTest, CanonicalizeMultipleNonContractingDims) { op::Shape("f32[4032,512]")))); } +TEST_F(DotDecomposerTest, DontCanonicalizeIfNoNoncontractingDims) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = f32[64,4]{1,0} parameter(0) + p1 = f32[64,4]{1,0} parameter(1) + ROOT dot = f32[64]{0} dot(p0, p1), lhs_batch_dims={0}, + lhs_contracting_dims={1}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool canonicalized, + DotDecomposer().Run(module.get())); + EXPECT_FALSE(canonicalized); +} + +TEST_F(DotDecomposerTest, DontAddLhsNonContractingDimIfOne) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = f32[64,4]{1,0} parameter(0) + p1 = f32[64,4,2,1]{3,2,1,0} parameter(1) + ROOT dot = f32[64,2,1]{2,1,0} dot(p0, p1), lhs_batch_dims={0}, + lhs_contracting_dims={1}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool canonicalized, + DotDecomposer().Run(module.get())); + EXPECT_TRUE(canonicalized); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(), + /*lhs_contracting_dim=*/1, + /*rhs_contracting_dim=*/1), + op::Shape("f32[64,2]")))); +} + +TEST_F(DotDecomposerTest, DontAddRhsNonContractingDimIfOne) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = f32[64,4,2,1]{3,2,1,0} parameter(0) + p1 = f32[64,4]{1,0} parameter(1) + ROOT dot = f32[64,2,1]{2,1,0} dot(p0, p1), lhs_batch_dims={0}, + lhs_contracting_dims={1}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool canonicalized, + DotDecomposer().Run(module.get())); + EXPECT_TRUE(canonicalized); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(), + /*lhs_contracting_dim=*/2, + /*rhs_contracting_dim=*/1), + op::Shape("f32[64,2]")))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 0f6b2cb72e6..958100ecc03 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -17,15 +17,15 @@ load( "tf_cuda_library", ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") +load( + "@local_config_rocm//rocm:build_defs.bzl", + "if_rocm", + "if_rocm_is_configured", +) load( "//tensorflow/core/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) -load( - "@local_config_rocm//rocm:build_defs.bzl", - "if_rocm_is_configured", -) load("//tensorflow:tensorflow.bzl", "if_nccl") package( @@ -901,12 +901,15 @@ cc_library( ":ir_emission_utils", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/compiler/xla/service:multi_output_fusion", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 1be0b1b4e7b..eee0fc83481 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -260,6 +260,13 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, llvm::Value* value) { + // When F64 is being requested, assume performance is less important and use + // the more numerically precise tanh function. + if (prim_type == F64) { + return EmitDeviceMathCall(TargetDeviceFunctionID::kTanh, {value}, + {prim_type}, prim_type); + } + // Emit a fast approximation of tanh instead of calling __nv_tanh. // __nv_tanh is particularly bad because it contains branches, thus // preventing LLVM's load-store vectorizer from working its magic across a diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 1316e8ad1aa..bb4184ff76f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -351,6 +351,9 @@ bool FusionWouldBeTooLarge(const HloInstruction& instr1, const HloInstruction& instr2) { if (SharedMemoryUsage(instr1) + SharedMemoryUsage(instr2) > kSharedMemoryBudgetInBytes) { + VLOG(5) << "Shared memory usage of fusion of " << instr1.ToString() + << " and " << instr2.ToString() << " would be over the budget of " + << kSharedMemoryBudgetInBytes << "B"; return true; } @@ -383,6 +386,14 @@ bool FusionWouldBeTooLarge(const HloInstruction& instr1, num_output_buffers <= kMaxOperandsAndOutputsPerFusion) { return false; + } else { + VLOG(5) << "Operand count of " + << "(" << instr1.ToString() << " ) = " << instr1.operand_count() + << " and ( " << instr2.ToString() + << " ) = " << instr2.operand_count() + << " and num_output_buffers = " << num_output_buffers + << " is bigger than the bound of " + << kMaxOperandsAndOutputsPerFusion; } // Compute the precise number of operands to the new fusion. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 05fa798dc39..cb22b4d9042 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -96,7 +96,8 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( StatusOr<InfeedBuffer> GpuTransferManager::TransferBufferToInfeedInternal( se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits<int32>::max()) { - return InvalidArgument("Infeed shape is too large: needs %d bytes", size); + return InvalidArgument("GPU infeed of %d bytes exceeds maximum of %d bytes", + size, std::numeric_limits<int32>::max()); } if (size == 0) { diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index fc1c1bb4ab1..a0580e2ab04 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -65,12 +65,16 @@ bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer, bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { if (!ShouldFuseInexpensiveChecks(consumer, operand_index)) { + VLOG(5) << "Not fusing inexpensive checks of operand " << operand_index + << " of " << consumer->ToString(); return false; } auto producer = consumer->operand(operand_index); // The following checks are potentially expensive. if (FusionWouldBeTooLarge(*consumer, *producer)) { + VLOG(5) << "Fusion of (" << producer->ToString() << ") into (" + << consumer->ToString() << ") would be too large"; return false; } if (consumer->opcode() != HloOpcode::kFusion) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 011eb07d3bd..744cd7b56bf 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -222,7 +222,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( // Derive a minimum alignment from the type. The optimizer can increase it // later. store->setAlignment( - llvm::MaybeAlign(ShapeUtil::ByteSizeOfPrimitiveType(element_type))); + llvm::Align(ShapeUtil::ByteSizeOfPrimitiveType(element_type))); return true; } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 060a0375271..497dcda4361 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -689,7 +689,7 @@ std::unique_ptr<llvm::TargetMachine> AMDGPUGetTargetMachine( llvm::Triple target_triple, int amdgpu_version, const HloModuleConfig& hlo_module_config) { return GetTargetMachine(target_triple, absl::StrCat("gfx", amdgpu_version), - hlo_module_config, "-code-object-v3"); + hlo_module_config, "+code-object-v3"); } void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) { diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h index 8d2ef53bfa9..e60f3bc3c14 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -16,7 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ -#include "tensorflow/compiler/xla/service/multi_output_fusion.h" +#include <queue> +#include <vector> + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/statusor.h" namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/target_util.cc b/tensorflow/compiler/xla/service/gpu/target_util.cc index 49eadd8c6be..31b590a19ff 100644 --- a/tensorflow/compiler/xla/service/gpu/target_util.cc +++ b/tensorflow/compiler/xla/service/gpu/target_util.cc @@ -111,47 +111,50 @@ struct TargetDeviceFunction { struct TargetDeviceFunction GetDeviceFunctionRoot( TargetDeviceFunctionID func_id) { switch (func_id) { - case TargetDeviceFunctionID::kPow: { - return {"__nv_pow", "__ocml_pow"}; - } - case TargetDeviceFunctionID::kErfcinv: { - return {"__nv_erfcinv", "__ocml_erfcinv"}; - } - case TargetDeviceFunctionID::kLog: { - return {"__nv_log", "__ocml_log"}; - } - case TargetDeviceFunctionID::kLog1p: { - return {"__nv_log1p", "__ocml_log1p"}; - } - case TargetDeviceFunctionID::kSin: { - return {"__nv_sin", "__ocml_sin"}; + case TargetDeviceFunctionID::kAtan2: { + return {"__nv_atan2", "__ocml_atan2"}; } case TargetDeviceFunctionID::kCos: { return {"__nv_cos", "__ocml_cos"}; } + case TargetDeviceFunctionID::kErfcinv: { + return {"__nv_erfcinv", "__ocml_erfcinv"}; + } case TargetDeviceFunctionID::kExp: { return {"__nv_exp", "__ocml_exp"}; } case TargetDeviceFunctionID::kExpm1: { return {"__nv_expm1", "__ocml_expm1"}; } - case TargetDeviceFunctionID::kSqrt: { - return {"__nv_sqrt", "__ocml_sqrt"}; - } - case TargetDeviceFunctionID::kRsqrt: { - return {"__nv_rsqrt", "__ocml_rsqrt"}; - } - case TargetDeviceFunctionID::kAtan2: { - return {"__nv_atan2", "__ocml_atan2"}; - } case TargetDeviceFunctionID::kFmod: { return {"__nv_fmod", "__ocml_fmod"}; } + case TargetDeviceFunctionID::kHypot: { + return {"__nv_hypot", "__ocml_hypot"}; + } + case TargetDeviceFunctionID::kLog: { + return {"__nv_log", "__ocml_log"}; + } + case TargetDeviceFunctionID::kLog1p: { + return {"__nv_log1p", "__ocml_log1p"}; + } + case TargetDeviceFunctionID::kPow: { + return {"__nv_pow", "__ocml_pow"}; + } case TargetDeviceFunctionID::kRound: { return {"__nv_round", "__ocml_round"}; } - case TargetDeviceFunctionID::kHypot: { - return {"__nv_hypot", "__ocml_hypot"}; + case TargetDeviceFunctionID::kRsqrt: { + return {"__nv_rsqrt", "__ocml_rsqrt"}; + } + case TargetDeviceFunctionID::kSin: { + return {"__nv_sin", "__ocml_sin"}; + } + case TargetDeviceFunctionID::kSqrt: { + return {"__nv_sqrt", "__ocml_sqrt"}; + } + case TargetDeviceFunctionID::kTanh: { + return {"__nv_tanh", "__ocml_tanh"}; } } } diff --git a/tensorflow/compiler/xla/service/gpu/target_util.h b/tensorflow/compiler/xla/service/gpu/target_util.h index 4355ed21136..2bdaea7734a 100644 --- a/tensorflow/compiler/xla/service/gpu/target_util.h +++ b/tensorflow/compiler/xla/service/gpu/target_util.h @@ -46,20 +46,21 @@ enum class TargetIntrinsicID { // Enumeration to get target specific device math function. enum class TargetDeviceFunctionID { - kPow = 0, - kErfcinv, - kLog, - kLog1p, - kSin, + kAtan2 = 0, kCos, + kErfcinv, kExp, kExpm1, - kSqrt, - kRsqrt, - kAtan2, kFmod, + kHypot, + kLog, + kLog1p, + kPow, kRound, - kHypot + kRsqrt, + kSin, + kSqrt, + kTanh, }; // Emits IR to call a device function named "callee_name" on the given diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 32a9038b15a..8a31bc5fef4 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -736,6 +736,16 @@ Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) { return Status::OK(); } +Status HloCostAnalysis::HandleCollectivePermuteStart( + const HloInstruction* /*hlo*/) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleCollectivePermuteDone( + const HloInstruction* /*hlo*/) { + return Status::OK(); +} + Status HloCostAnalysis::HandlePartitionId(const HloInstruction* /*hlo*/) { return Status::OK(); } @@ -1031,6 +1041,42 @@ float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_); } +int64 HloCostAnalysis::GetBytesRead(const HloInstruction& hlo, + absl::optional<int64> memory_space) const { + int64 bytes_read = 0; + for (int operand_number = 0; operand_number < hlo.operand_count(); + ++operand_number) { + for (const ShapeUtil::IndexedShape& indexed_shape : + ShapeUtil::GetLeafShapes(hlo.operand(operand_number)->shape())) { + absl::optional<int64> index_memory_space; + if (indexed_shape.shape.has_layout()) { + index_memory_space = indexed_shape.shape.layout().memory_space(); + } + if (!memory_space || memory_space == index_memory_space) { + bytes_read += + operand_bytes_accessed(hlo, operand_number, indexed_shape.index); + } + } + } + return bytes_read; +} + +int64 HloCostAnalysis::GetBytesWritten( + const HloInstruction& hlo, absl::optional<int64> memory_space) const { + int64 bytes_written = 0; + for (const ShapeUtil::IndexedShape& indexed_shape : + ShapeUtil::GetLeafShapes(hlo.shape())) { + absl::optional<int64> index_memory_space; + if (indexed_shape.shape.has_layout()) { + index_memory_space = indexed_shape.shape.layout().memory_space(); + } + if (!memory_space || memory_space == index_memory_space) { + bytes_written += output_bytes_accessed(hlo, indexed_shape.index); + } + } + return bytes_written; +} + StatusOr<HloCostAnalysis::Properties> HloCostAnalysis::ProcessSubcomputation( HloComputation* computation) { auto visitor = CreateNestedCostAnalysis(shape_size_, per_second_rates_); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 9fdb42185fb..d9085dd7785 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -80,6 +80,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleAllReduce(const HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleCollectivePermute(const HloInstruction* hlo) override; + Status HandleCollectivePermuteStart(const HloInstruction* hlo) override; + Status HandleCollectivePermuteDone(const HloInstruction* hlo) override; Status HandleReplicaId(const HloInstruction* hlo) override; Status HandlePartitionId(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; @@ -162,6 +164,14 @@ class HloCostAnalysis : public ConstDfsHloVisitor { ShapeIndex index = {}) const; float optimal_seconds(const HloInstruction& hlo) const; + // Get bytes read/written by this HLO. If memory_space is provided, it returns + // the bytes read/written from/to the given memory space only. + int64 GetBytesRead(const HloInstruction& hlo, + absl::optional<int64> memory_space = absl::nullopt) const; + int64 GetBytesWritten( + const HloInstruction& hlo, + absl::optional<int64> memory_space = absl::nullopt) const; + const Properties& properties() const { return properties_sum_; } const float property(const string& key) const { return GetProperty(key, properties()); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index cd2a61d7eff..3930898d665 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1061,6 +1061,8 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kPartitionId: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 9e9c8b0913b..c02100debc3 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -452,7 +452,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( /*channel_id=*/channel_id, split_dimension); break; } - case HloOpcode::kCollectivePermute: { + case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: { std::vector<std::pair<int64, int64>> source_target_pairs( proto.source_target_pairs_size()); absl::optional<int64> channel_id; @@ -463,8 +464,17 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( source_target_pairs[i].first = proto.source_target_pairs(i).source(); source_target_pairs[i].second = proto.source_target_pairs(i).target(); } - instruction = CreateCollectivePermute(shape, operands(0), - source_target_pairs, channel_id); + + if (opcode == HloOpcode::kCollectivePermute) { + instruction = CreateCollectivePermute(shape, operands(0), + source_target_pairs, channel_id); + } else if (opcode == HloOpcode::kCollectivePermuteStart) { + instruction = CreateCollectivePermuteStart( + shape, operands(0), source_target_pairs, channel_id); + } else { + LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, " + << "but got " << HloOpcodeString(opcode); + } break; } case HloOpcode::kReplicaId: { @@ -805,6 +815,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, case HloOpcode::kRoundNearestAfz: case HloOpcode::kBitcast: case HloOpcode::kCeil: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCopy: case HloOpcode::kCopyStart: case HloOpcode::kCopyDone: @@ -982,7 +993,18 @@ HloInstruction::CreateCollectivePermute( const std::vector<std::pair<int64, int64>>& source_target_pairs, const absl::optional<int64>& channel_id) { return absl::make_unique<HloCollectivePermuteInstruction>( - shape, operand, source_target_pairs, channel_id); + HloOpcode::kCollectivePermute, shape, operand, source_target_pairs, + channel_id); +} + +/* static */ std::unique_ptr<HloInstruction> +HloInstruction::CreateCollectivePermuteStart( + const Shape& shape, HloInstruction* operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs, + const absl::optional<int64>& channel_id) { + return absl::make_unique<HloCollectivePermuteInstruction>( + HloOpcode::kCollectivePermuteStart, shape, operand, source_target_pairs, + channel_id); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId() { @@ -1549,6 +1571,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kConvolution: @@ -1575,6 +1598,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kBitcast: case HloOpcode::kCeil: case HloOpcode::kClz: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCopy: case HloOpcode::kCopyStart: case HloOpcode::kCopyDone: @@ -1928,6 +1952,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kClz: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kComplex: case HloOpcode::kConvert: case HloOpcode::kCopy: @@ -2029,6 +2054,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: @@ -2888,6 +2914,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return visitor->HandleAllToAll(this); case HloOpcode::kCollectivePermute: return visitor->HandleCollectivePermute(this); + case HloOpcode::kCollectivePermuteStart: + return visitor->HandleCollectivePermuteStart(this); + case HloOpcode::kCollectivePermuteDone: + return visitor->HandleCollectivePermuteDone(this); case HloOpcode::kReplicaId: return visitor->HandleReplicaId(this); case HloOpcode::kPartitionId: @@ -3965,6 +3995,10 @@ const PaddingConfig& HloInstruction::padding_config() const { return Cast<HloPadInstruction>(this)->padding_config(); } +PaddingConfig* HloInstruction::mutable_padding_config() { + return Cast<HloPadInstruction>(this)->mutable_padding_config(); +} + int64 HloInstruction::slice_sizes(int64 dimension) const { return Cast<HloDynamicSliceInstruction>(this)->slice_sizes(dimension); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 8be7a034877..7a5d506b681 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -681,7 +681,7 @@ class HloInstruction { const absl::optional<int64>& channel_id, const absl::optional<int64>& split_dimension = absl::nullopt); - // Creates a communication instructions that permutes data cross replicas. + // Creates a communication instruction that permutes data cross replicas. // Data is sent/received according to the (source_replica_id, // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a // target_replica_id in any pair, the output on that replica is a tensor @@ -691,6 +691,13 @@ class HloInstruction { const std::vector<std::pair<int64, int64>>& source_target_pairs, const absl::optional<int64>& channel_id); + // Creates a communication instruction that initiates the start of + // CollectivePermute. + static std::unique_ptr<HloInstruction> CreateCollectivePermuteStart( + const Shape& shape, HloInstruction* operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs, + const absl::optional<int64>& channel_id); + // Creates an instruction that returns a U32 replica ID. static std::unique_ptr<HloInstruction> CreateReplicaId(); @@ -1810,6 +1817,7 @@ class HloInstruction { // Delegates to HloPadInstruction::padding_config. const PaddingConfig& padding_config() const; + PaddingConfig* mutable_padding_config(); // Delegates to HloDynamicSliceInstruction::slice_sizes. int64 slice_sizes(int64 dimension) const; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index d5bdd674563..9c5a66f0040 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -703,10 +703,10 @@ bool HloAllToAllInstruction::IdenticalSlowPath( } HloCollectivePermuteInstruction::HloCollectivePermuteInstruction( - const Shape& shape, HloInstruction* operand, + HloOpcode opcode, const Shape& shape, HloInstruction* operand, const std::vector<std::pair<int64, int64>>& source_target_pairs, const absl::optional<int64>& channel_id) - : HloChannelInstruction(HloOpcode::kCollectivePermute, shape, channel_id), + : HloChannelInstruction(opcode, shape, channel_id), source_target_pairs_(source_target_pairs) { AppendOperand(operand); } @@ -738,6 +738,9 @@ bool HloCollectivePermuteInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function<bool(const HloComputation*, const HloComputation*)>& eq_computations) const { + if (opcode() != other.opcode()) { + return false; + } const auto& casted_other = static_cast<const HloCollectivePermuteInstruction&>(other); return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) && @@ -752,7 +755,7 @@ HloCollectivePermuteInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span<HloInstruction* const> new_operands, HloCloneContext* /*context*/) const { return absl::make_unique<HloCollectivePermuteInstruction>( - shape, new_operands[0], source_target_pairs(), channel_id()); + opcode(), shape, new_operands[0], source_target_pairs(), channel_id()); } HloReverseInstruction::HloReverseInstruction(const Shape& shape, @@ -1864,8 +1867,14 @@ std::unique_ptr<HloInstruction> HloParameterInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span<HloInstruction* const> new_operands, HloCloneContext* context) const { - return absl::make_unique<HloParameterInstruction>(parameter_number_, shape, - name()); + auto clone = absl::make_unique<HloParameterInstruction>(parameter_number_, + shape, name()); + if (parameter_replicated_at_leaf_buffers_ && + ShapeUtil::Equal(shape, this->shape())) { + clone->set_parameter_replicated_at_leaf_buffers( + *parameter_replicated_at_leaf_buffers_); + } + return clone; } HloGetTupleElementInstruction::HloGetTupleElementInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index ae78d365cfa..6da01dc088e 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -463,7 +463,7 @@ class HloAllToAllInstruction : public HloCollectiveInstruction { class HloCollectivePermuteInstruction : public HloChannelInstruction { public: explicit HloCollectivePermuteInstruction( - const Shape& shape, HloInstruction* operand, + HloOpcode opcode, const Shape& shape, HloInstruction* operand, const std::vector<std::pair<int64, int64>>& source_target_pairs, const absl::optional<int64>& channel_id); @@ -706,7 +706,7 @@ class HloMapInstruction : public HloInstruction { // Returns the dimension sizes or numbers associated with this instruction. const std::vector<int64>& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } - std::vector<int64>* mutable_dimensions() { return &dimensions_; } + std::vector<int64>* mutable_dimensions() override { return &dimensions_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1409,6 +1409,7 @@ class HloPadInstruction : public HloInstruction { const PaddingConfig& padding_config); // Returns the padding configuration for a pad node. const PaddingConfig& padding_config() const { return padding_config_; } + PaddingConfig* mutable_padding_config() { return &padding_config_; } // Returns the padding value. const HloInstruction* padding_value() const { return operand(1); } HloInstruction* mutable_padding_value() { return mutable_operand(1); } diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index bc1745a0791..5502665e886 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -17,6 +17,7 @@ limitations under the License. #include <unordered_map> +#include "absl/base/casts.h" #include "absl/strings/ascii.h" #include "absl/strings/escaping.h" #include "absl/strings/numbers.h" @@ -370,6 +371,11 @@ TokKind HloLexer::LexNumberOrPattern() { if (absl::SimpleAtoi(slice, &token_state_.int64_val)) { return TokKind::kInt; } + uint64 uint64_val; + if (absl::SimpleAtoi(slice, &uint64_val)) { + token_state_.int64_val = absl::bit_cast<int64>(uint64_val); + return TokKind::kInt; + } LOG(ERROR) << "Failed to parse int literal: " << slice; return TokKind::kError; } diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index ec048bef9e8..cb1b1d0dae4 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -203,6 +203,7 @@ HLO_MATCHER(Abs); HLO_MATCHER(Add); HLO_MATCHER(AddDependency); HLO_MATCHER(AfterAll); +HLO_MATCHER(AllGather); HLO_MATCHER(AllReduce); HLO_MATCHER(AllToAll); HLO_MATCHER(And); diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 833d0fe59d0..964f83322a4 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -204,7 +204,7 @@ class HloModuleConfig { std::vector<std::vector<int64>>* mutable_dot_config() { return &dot_config_; } - absl::Span<const std::vector<std::vector<int64>>> layout_config() const { + const std::vector<std::vector<std::vector<int64>>>& layout_config() const { return layout_config_; } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 664fa10a990..92359bcbdac 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -63,6 +63,8 @@ namespace xla { V(kCholesky, "cholesky", 1) \ V(kClamp, "clamp", 3) \ V(kCollectivePermute, "collective-permute", 1) \ + V(kCollectivePermuteStart, "collective-permute-start", 1) \ + V(kCollectivePermuteDone, "collective-permute-done", 1) \ V(kClz, "count-leading-zeros", 1) \ V(kCompare, "compare", 2) \ V(kComplex, "complex", 2) \ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 2a90c95850c..d52a60d2555 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -765,6 +765,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kBitcast: case HloOpcode::kCeil: case HloOpcode::kClz: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCopy: case HloOpcode::kCopyStart: case HloOpcode::kCopyDone: @@ -938,7 +939,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, split_dimension)); break; } - case HloOpcode::kCollectivePermute: { + case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteStart: { optional<std::vector<std::vector<int64>>> source_targets; attrs["source_target_pairs"] = { /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets}; @@ -957,9 +959,19 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, pairs[i].first = (*source_targets)[i][0]; pairs[i].second = (*source_targets)[i][1]; } - instruction = - builder->AddInstruction(HloInstruction::CreateCollectivePermute( - shape, operands[0], pairs, channel_id)); + if (opcode == HloOpcode::kCollectivePermute) { + instruction = + builder->AddInstruction(HloInstruction::CreateCollectivePermute( + shape, operands[0], pairs, channel_id)); + } else if (opcode == HloOpcode::kCollectivePermuteStart) { + instruction = builder->AddInstruction( + HloInstruction::CreateCollectivePermuteStart(shape, operands[0], + pairs, channel_id)); + } else { + LOG(FATAL) << "Expect opcode to be CollectivePermute or " + "CollectivePermuteStart, but got " + << HloOpcodeString(opcode); + } break; } case HloOpcode::kReplicaId: { @@ -2598,14 +2610,10 @@ bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) { std::is_same<ParsedElemT, bool>::value)) << "Unimplemented checking for ParsedElemT"; - ParsedElemT upper_bound; - if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) { - upper_bound = std::numeric_limits<ParsedElemT>::max(); - } else { - upper_bound = - static_cast<ParsedElemT>(std::numeric_limits<LiteralNativeT>::max()); - } - if (value > upper_bound || value < 0) { + const uint64 unsigned_value = value; + const uint64 upper_bound = + static_cast<uint64>(std::numeric_limits<LiteralNativeT>::max()); + if (unsigned_value > upper_bound) { // Value is out of range for LiteralNativeT. return Error(loc, StrCat("value ", value, " is out of range for literal's primitive type ", diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index e18014a3071..a687d0e1921 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1553,6 +1553,20 @@ ENTRY CollectivePermute { ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} } +)", +/*replica_count=*/4 +}, +// collective-permute-start and -done +{ +"CollectivePermuteStartAndDone", +R"(HloModule CollectivePermuteStartAndDone + +ENTRY CollectivePermuteStartAndDone { + input = f32[128,32]{0,1} parameter(0) + collective-permute-start.1 = (f32[128,32]{0,1}, f32[128,32]{0,1}, u32[], u32[]) collective-permute-start(input), source_target_pairs={{0,1},{1,2},{2,3}} + ROOT collective-permute-done.1 = f32[128,32]{0,1} collective-permute-done(collective-permute-start.1) +} + )", /*replica_count=*/4 }, @@ -2000,9 +2014,7 @@ TEST_F(HloParserTest, ConstantUnsignedUnderflow) { ROOT %constant = u64[] constant(-1) })"; auto result = ParseAndReturnUnverifiedModule(original); - EXPECT_NE(Status::OK(), result.status()); - ExpectHasSubstr(result.status().error_message(), - "is out of range for literal's primitive type U64"); + EXPECT_EQ(Status::OK(), result.status()); } TEST_F(HloParserTest, ConstantUnsignedOverflow) { @@ -2024,7 +2036,7 @@ TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) { ROOT %constant = u64[] constant(9223372036854775808) })"; auto result = ParseAndReturnUnverifiedModule(original); - EXPECT_NE(Status::OK(), result.status()); + EXPECT_EQ(Status::OK(), result.status()); } TEST_F(HloParserTest, ConstantC64Overflow) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index d15a36532eb..4661b8fd9e3 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -74,7 +74,6 @@ Status CheckParameterCount(const HloInstruction* calling_instruction, } return Status::OK(); } - } // namespace Status ShapeVerifier::Preprocess(HloInstruction* hlo) { @@ -332,7 +331,9 @@ Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) { return CheckShape(hlo, ShapeUtil::MakeShape(U32, {})); } -Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { +namespace { + +Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo) { // A source or target cannot appear twice in the collective-permute's // source-target pairs. absl::flat_hash_set<int64> seen_sources; @@ -351,10 +352,30 @@ Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { p.second, hlo->ToString()); } } + return Status::OK(); +} + +} // namespace + +Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo)); return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( hlo->operand(0)->shape())); } +Status ShapeVerifier::HandleCollectivePermuteStart(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo)); + return CheckShape( + hlo, ShapeUtil::MakeTupleShape( + {hlo->operand(0)->shape(), hlo->operand(0)->shape(), + ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})})); +} + +Status ShapeVerifier::HandleCollectivePermuteDone(HloInstruction* hlo) { + return CheckShape( + hlo, ShapeUtil::GetTupleElementShape(hlo->operand(0)->shape(), 0)); +} + Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), @@ -1375,32 +1396,60 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1, return Status::OK(); } -// Checks CopyStart and CopyDone nodes. -Status VerifyAsynchronousCopies(const HloModule& module) { +Status VerifySingleUser(const HloInstruction* instruction, + HloOpcode expected_user) { + TF_RET_CHECK(instruction->users().size() == 1) + << "The " << HloOpcodeString(instruction->opcode()) + << " instruction requires one consumer, found " + << instruction->users().size(); + + const HloInstruction* user = instruction->users().front(); + TF_RET_CHECK(user->opcode() == expected_user) + << "The consumer of a " << HloOpcodeString(instruction->opcode()) + << " instruction needs to be " << HloOpcodeString(expected_user) + << ", found " << HloOpcodeString(user->opcode()); + return Status::OK(); +} + +Status VerifySingleOperand(const HloInstruction* instruction, + HloOpcode expected_operand) { + TF_RET_CHECK(instruction->operands().size() == 1) + << "The " << HloOpcodeString(instruction->opcode()) + << " instruction requires one consumer, found " + << instruction->users().size(); + + const HloInstruction* operand = instruction->operand(0); + TF_RET_CHECK(operand->opcode() == expected_operand) + << "The operand of a " << HloOpcodeString(instruction->opcode()) + << " instruction needs to be " << HloOpcodeString(expected_operand) + << ", found " << HloOpcodeString(operand->opcode()); + return Status::OK(); +} + +// Checks asynchronous instruction pairs. +Status VerifyAsynchronousInstructionPairs(const HloModule& module) { // CopyStart must have a single CopyDone user. for (const HloComputation* computation : module.computations()) { for (const HloInstruction* instruction : computation->instructions()) { switch (instruction->opcode()) { case HloOpcode::kCopyStart: { - TF_RET_CHECK(instruction->users().size() == 1) - << "CopyStart instruction requires one consumer, found " - << instruction->users().size(); - const HloInstruction* copy_done = instruction->users().front(); - TF_RET_CHECK(copy_done->opcode() == HloOpcode::kCopyDone) - << "The consumer of a CopyStart instruction needs to be " - "CopyDone, found " - << HloOpcodeString(copy_done->opcode()); + TF_RETURN_IF_ERROR( + VerifySingleUser(instruction, HloOpcode::kCopyDone)); break; } case HloOpcode::kCopyDone: { - TF_RET_CHECK(instruction->operands().size() == 1) - << "CopyDone instruction requires one operand, found " - << instruction->operands().size(); - const HloInstruction* copy_start = instruction->operand(0); - TF_RET_CHECK(copy_start->opcode() == HloOpcode::kCopyStart) - << "The operand of a CopyDone instruction needs to be CopyStart, " - "found " - << HloOpcodeString(copy_start->opcode()); + TF_RETURN_IF_ERROR( + VerifySingleOperand(instruction, HloOpcode::kCopyStart)); + break; + } + case HloOpcode::kCollectivePermuteStart: { + TF_RETURN_IF_ERROR( + VerifySingleUser(instruction, HloOpcode::kCollectivePermuteDone)); + break; + } + case HloOpcode::kCollectivePermuteDone: { + TF_RETURN_IF_ERROR(VerifySingleOperand( + instruction, HloOpcode::kCollectivePermuteStart)); break; } default: @@ -1815,7 +1864,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) { } TF_RETURN_IF_ERROR(VerifyHloStructure(module)); - TF_RETURN_IF_ERROR(VerifyAsynchronousCopies(*module)); + TF_RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module)); TF_RETURN_IF_ERROR(VerifyChannels(*module)); std::unique_ptr<ShapeVerifier> shape_verifier = diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 7a2d3dc2e6c..85b02e0518c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -60,6 +60,8 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleAllReduce(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; Status HandleCollectivePermute(HloInstruction* hlo) override; + Status HandleCollectivePermuteStart(HloInstruction* hlo) override; + Status HandleCollectivePermuteDone(HloInstruction* hlo) override; Status HandlePartitionId(HloInstruction* hlo) override; Status HandleReplicaId(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index e2c363e40c5..294dfbf66fa 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -710,7 +710,7 @@ TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) { ASSERT_FALSE(status.ok()); EXPECT_THAT( status.error_message(), - HasSubstr("CopyStart instruction requires one consumer, found 2")); + HasSubstr("copy-start instruction requires one consumer, found 2")); } TEST_F(HloVerifierTest, CopyDoneNoCopyStart) { @@ -730,8 +730,8 @@ TEST_F(HloVerifierTest, CopyDoneNoCopyStart) { auto status = verifier().Run(module.get()).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.error_message(), - HasSubstr("The operand of a CopyDone instruction needs to be " - "CopyStart, found tuple")); + HasSubstr("The operand of a copy-done instruction needs to be " + "copy-start, found tuple")); } TEST_F(HloVerifierTest, IotaNonArrayResult) { @@ -1134,5 +1134,86 @@ TEST_F(HloVerifierTest, CollectiveChannelVerifier) { HasSubstr("used for different types of channel instructions")); } +TEST_F(HloVerifierTestLayoutSensitive, CollectivePermuteStartAndDone) { + const char* const kModuleStr = R"( + HloModule Module + + ENTRY CollectivePermuteStartAndDone { + p0 = f32[2,3]{1,0:S(1)} parameter(0) + collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1 + ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, CollectivePermuteStartAndDoneWrongType) { + const char* const kModuleStr = R"( + HloModule Module + + ENTRY CollectivePermuteStartAndDoneWrongType { + p0 = f32[2,3]{1,0:S(1)} parameter(0) + collective-permute-start.1 = f32[2,3]{1,0:S(1)} collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1 + ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected instruction to have shape equal to " + "(f32[2,3], f32[2,3], u32[], u32[])")); +} + +TEST_F(HloVerifierTest, CollectivePermuteStartAndMultipleDone) { + const char* const kModuleStr = R"( + HloModule Module + + ENTRY CollectivePermuteStartAndMultipleDone { + p0 = f32[2,3]{1,0:S(1)} parameter(0) + collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1 + collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1) + ROOT collective-permute-done.2 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT( + status.error_message(), + HasSubstr("collective-permute-start instruction requires one consumer, " + "found 2")); +} + +TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) { + const char* const kModuleStr = R"( + HloModule Module + + ENTRY CollectivePermuteDoneNoCollectivePermuteStart { + p0 = f32[2,3]{1,0:S(1)} parameter(0) + p1 = f32[2,3]{1,0:S(1)} parameter(1) + p2 = u32[] parameter(2) + tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2) + ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("The operand of a collective-permute-done instruction " + "needs to be collective-permute-start, found tuple")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 1bc3d24274c..02966cc2bf2 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -149,6 +149,8 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kAllReduce: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: + case HloOpcode::kCollectivePermuteDone: + case HloOpcode::kCollectivePermuteStart: case HloOpcode::kCustomCall: case HloOpcode::kDomain: case HloOpcode::kDot: @@ -502,7 +504,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) { while (true) { auto next_entry = fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder(); - auto instruction = next_entry.first; + HloInstruction* instruction = next_entry.first; if (instruction == nullptr) { break; } @@ -512,12 +514,14 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) { continue; } + VLOG(5) << "Considering fusion of: " << instruction->ToString(); std::vector<int64>& sorted_operand_numbers = next_entry.second; for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); if (!operand->IsFusible()) { + VLOG(3) << "Operand (" << operand->ToString() << ") is not fusible"; continue; } @@ -691,6 +695,8 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer, if (FusionWouldDuplicate(*producer, *consumer) && (!may_duplicate_ || is_expensive_(*producer)) && !IsAlwaysDuplicable(*producer)) { + VLOG(4) << "Stopping: fusion may duplicate operand (" + << producer->ToString() << ") , and this is expensive"; return false; } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 13699f3adf9..82c30f1a710 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -2234,6 +2234,8 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kBitcast: case HloOpcode::kBroadcast: case HloOpcode::kCall: + case HloOpcode::kCollectivePermuteStart: + case HloOpcode::kCollectivePermuteDone: case HloOpcode::kConstant: case HloOpcode::kConvolution: case HloOpcode::kCopy: diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 39399df7ad8..cabcc8e06ee 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -64,6 +64,7 @@ cc_library( srcs = ["llvm_util.cc"], hdrs = ["llvm_util.h"], deps = [ + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc index 453a5cd84b2..f7808773592 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -58,7 +58,7 @@ ENTRY while3 { CompileAndVerifyIr(hlo_string, R"( ; CHECK-LABEL: @body(i8* %retval -; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]] +; CHECK: %[[add_result:.*]] = fadd reassoc nsz contract float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]] ; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], align 4, !alias.scope ![[alias_scope_md_for_store:[0-9]+]] ; ; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index da0dbf94ddd..278aa3e1696 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -373,6 +374,28 @@ llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions, return logical_linear_index; } +llvm::Value* IrArray::Index::Linearize( + const std::vector<llvm::Value*>& dynamic_dims, + llvm::IRBuilder<>* builder) const { + // Each dimension is multiplied by the product of the sizes of all + // earlier dimensions and added to the accumulator logical_linear_index. + CHECK_EQ(size(), dynamic_dims.size()); + llvm::Value* logical_linear_index = GetConstantWithIndexType(0); + llvm::Value* multiplier = GetConstantWithIndexType(1); + for (ssize_t i = size() - 1; i >= 0; --i) { + llvm::Value* addend = builder->CreateMul((*this)[i], multiplier, "", + /*HasNUW=*/true, /*HasNSW=*/true); + addend = builder->CreateZExtOrTrunc(addend, index_type_); + logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "", + /*HasNUW=*/true, /*HasNSW=*/true); + if (i) { + multiplier = builder->CreateMul(multiplier, dynamic_dims[i], + /*Name=*/"multiplier"); + } + } + return logical_linear_index; +} + llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, llvm::IRBuilder<>* b, absl::string_view name, diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index e838c4a0534..c71654f5294 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -155,6 +155,10 @@ class IrArray { llvm::Value* Linearize(absl::Span<const int64> dimensions, llvm::IRBuilder<>* builder) const; + // Linearizes the index into the given dynamic dimensions. + llvm::Value* Linearize(const std::vector<llvm::Value*>& dynamic_dims, + llvm::IRBuilder<>* builder) const; + llvm::Type* GetType() const { return index_type_; } llvm::Constant* GetConstantWithIndexType(int64 c) const { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 4c9a8d3e004..6375bf7341f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -30,6 +30,7 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Target/TargetOptions.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" @@ -90,7 +91,9 @@ llvm::CallInst* EmitCallToIntrinsic( llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, llvm::IRBuilder<>* b) { - if (b->getFastMathFlags().noNaNs()) { + // TODO(tpopp): Pass this information down from the HLO's ModuleConfig. + if (b->getFastMathFlags().noNaNs() || + GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) { auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value); return b->CreateSelect(cmp, lhs_value, rhs_value); } else { @@ -103,7 +106,9 @@ llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, llvm::IRBuilder<>* b) { - if (b->getFastMathFlags().noNaNs()) { + // TODO(tpopp): Pass this information down from the HLO's ModuleConfig. + if (b->getFastMathFlags().noNaNs() || + GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) { auto cmp = b->CreateFCmpULE(lhs_value, rhs_value); return b->CreateSelect(cmp, lhs_value, rhs_value); } else { @@ -287,7 +292,7 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, llvm::AllocaInst* alloca = b->CreateAlloca(type, element_count, AsStringRef(name)); if (alignment != 0) { - alloca->setAlignment(llvm::MaybeAlign(alignment)); + alloca->setAlignment(llvm::Align(alignment)); } return alloca; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 83be4334269..b6b3b2dd8b3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -35,6 +35,14 @@ LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, llvm::IRBuilder<>* b) : body_emitter_(body_emitter), shape_(shape), b_(b) {} +LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, + std::vector<llvm::Value*> dynamic_dims, + llvm::IRBuilder<>* b) + : LoopEmitter::LoopEmitter(body_emitter, shape, b) { + CHECK_EQ(dynamic_dims.size(), shape_.dimensions_size()); + dynamic_dims_ = std::move(dynamic_dims); +} + LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, const IrArray& target_array, llvm::IRBuilder<>* b) : body_emitter_([=](const llvm_ir::IrArray::Index array_index) -> Status { @@ -84,6 +92,43 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, } } +IrArray::Index LoopEmitter::EmitStaticIndex(ForLoopNest* loop_nest, + llvm::Type* index_type) { + // Create loop nest with one for-loop for each dimension of the target shape. + // Loops are added from outermost to innermost order with the ForLoopNest + // class so emit loops in order from most-major dimension down to most-minor + // dimension (of the target shape). + std::vector<llvm::Value*> array_multi_index(shape_.dimensions_size()); + for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { + int64 dimension = LayoutUtil::Major(shape_.layout(), i); + std::unique_ptr<ForLoop> loop = loop_nest->AddLoop( + /*start_index=*/0, + /*end_index=*/shape_.dimensions(dimension), + /*suffix=*/absl::StrFormat("dim.%d", dimension)); + array_multi_index[dimension] = loop->GetIndVarValue(); + } + return IrArray::Index(array_multi_index, shape_, index_type); +} + +IrArray::Index LoopEmitter::EmitDynamicIndex(ForLoopNest* loop_nest, + llvm::Type* index_type) { + CHECK_EQ(shape_.is_dynamic(), true); + // Create loop nest with one for-loop for each dynamic dimensions. + // Loops are added from outermost to innermost order with the ForLoopNest + // class so emit loops in order from most-major dimension down to most-minor + // dimension (of the target shape). + std::vector<llvm::Value*> array_multi_index(shape_.dimensions_size()); + for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { + int64 dimension = LayoutUtil::Major(shape_.layout(), i); + std::unique_ptr<ForLoop> loop = loop_nest->AddLoop( + /*suffix=*/absl::StrFormat("dim.%d", dimension), + /*start_index=*/llvm::ConstantInt::get(index_type, 0), + /*end_index=*/dynamic_dims_[dimension]); + array_multi_index[dimension] = loop->GetIndVarValue(); + } + return IrArray::Index(array_multi_index, shape_, index_type); +} + std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock( absl::string_view loop_name, llvm::Type* index_type) { CHECK_NE(index_type, nullptr); @@ -93,21 +138,11 @@ std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock( return {IrArray::Index(index_type)}; } - // Create loop nest with one for-loop for each dimension of the target shape. - // Loops are added from outermost to innermost order with the ForLoopNest - // class so emit loops in order from most-major dimension down to most-minor - // dimension (of the target shape). ForLoopNest loop_nest(loop_name, b_); - std::vector<llvm::Value*> array_multi_index(shape_.dimensions_size()); - for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { - int64 dimension = LayoutUtil::Major(shape_.layout(), i); - std::unique_ptr<ForLoop> loop = loop_nest.AddLoop( - /*start_index=*/0, - /*end_index=*/shape_.dimensions(dimension), - /*suffix=*/absl::StrFormat("dim.%d", dimension)); - array_multi_index[dimension] = loop->GetIndVarValue(); - } - IrArray::Index array_index(array_multi_index, shape_, index_type); + + IrArray::Index array_index = dynamic_dims_.empty() + ? EmitStaticIndex(&loop_nest, index_type) + : EmitDynamicIndex(&loop_nest, index_type); // Set IR builder insertion point to the loop body basic block of the // innermost loop. diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index a537c00066b..008205a642a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { @@ -42,6 +43,12 @@ class LoopEmitter { LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, llvm::IRBuilder<>* b); + + // Constructs a LoopEmitter from an body_emitter that generates + // element of the given target array in the dynamic dimension. + LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, + std::vector<llvm::Value*> dynamic_dims, llvm::IRBuilder<>* b); + // Constructs a LoopEmitter from an element generator that generates each // element of the given target array. LoopEmitter(const ElementGenerator& target_element_generator, @@ -81,11 +88,21 @@ class LoopEmitter { // The shape that the emitted loop iterates through. Shape shape_; + // Dynamic dimensions that emitted loop iterates through. Generate the + // loop based on the dynamic dimensions if this vector is not empty. + std::vector<llvm::Value*> dynamic_dims_; + // Points to the exit block of the emitted loop. If the given shape is // scalar, no loops are emitted and exit_bb_ is nullptr in that case. llvm::BasicBlock* exit_bb_; llvm::IRBuilder<>* b_; + + private: + IrArray::Index EmitStaticIndex(ForLoopNest* loop_nest, + llvm::Type* index_type); + IrArray::Index EmitDynamicIndex(ForLoopNest* loop_nest, + llvm::Type* index_type); }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 742de71e74c..e07431bf46f 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -16,53 +16,52 @@ limitations under the License. #include "tensorflow/compiler/xla/service/memory_space_assignment.h" #include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/core/lib/math/math_util.h" namespace xla { namespace { // Define a dummy chunk for chunks that will be allocated in the default memory // space and for keeping track of number of asynchronous copies. const HeapSimulator::Chunk kDummyChunk{-1, -1}; +// This variable is used by the cost analysis in estimating how many times each +// while loop will execute. Nested loops will be assumed to have executed +// pow(kWhileExecutionCount, nesting_level) times. +const int kWhileExecutionCount = 5; -// Returns a heuristic value that captures how much putting this tensor to -// the alternate memory would help if the op is memory bound, or otherwise -// how far off is the op to memory boundedness. The larger this number, the -// higher priority it will be placed in the alternate memory. -float GetAlternateMemoryBenefit( - const MemorySpaceAssignmentCostAnalysis& cost_analysis, +} // namespace + +float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( const HloInstruction& instruction, - float elapsed_time_due_to_alternate_mem) { + float elapsed_time_due_to_alternate_mem) const { float elapsed_time_due_to_compute = - cost_analysis.GetInstructionElapsedDueToCompute(instruction); + GetInstructionElapsedDueToCompute(instruction); float elapsed_time_due_to_memory = - cost_analysis.GetInstructionElapsedDueToMemory(instruction); + GetInstructionElapsedDueToMemory(instruction); if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) { // Memory bound, return how much alternate memory is better. - return elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem; + int while_nest_level = CalculateWhileLoopNestLevel(&instruction); + return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) * + tensorflow::MathUtil::IPow<float>(kWhileExecutionCount, + while_nest_level); } else { // Compute bound, return how far off are we to memory boundedness. return elapsed_time_due_to_memory - elapsed_time_due_to_compute; } } -// Returns a heuristic value of memory boundedness for the given BufferInterval. -// The larger this number, the higher priority it will be placed in the -// alternate memory. -float GetMemoryBoundedness( - const MemorySpaceAssignmentCostAnalysis& cost_analysis, - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) { +float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { const HloInstruction& defining_instruction = *interval.buffer->defining_instruction(); - float alternate_mem_benefit = - GetAlternateMemoryBenefit(cost_analysis, defining_instruction, - cost_analysis.GetInstructionElapsedDueToMemory( - defining_instruction, - /*operand_in_alternate_mem=*/{}, - /*output_in_alternate_mem=*/true)); + float alternate_mem_benefit = GetAlternateMemoryBenefit( + defining_instruction, + GetInstructionElapsedDueToMemory(defining_instruction, + /*operand_in_alternate_mem=*/{}, + /*output_in_alternate_mem=*/true)); for (const HloUse& use : interval.buffer->uses()) { float use_alternate_mem_benefit = GetAlternateMemoryBenefit( - cost_analysis, *use.instruction, - cost_analysis.GetInstructionElapsedDueToMemory(*use.instruction, - use.operand_number)); + *use.instruction, + GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number)); // If the benefit is positive (memory bound), add it to this buffer's // benefit. If the benefit is negative (compute bound), calculate the // maximum. @@ -77,7 +76,7 @@ float GetMemoryBoundedness( // Get performance slowdown in seconds of prefetching current BufferInterval // causing to other BufferIntervals. float alternate_mem_slowdown = - cost_analysis.GetInstructionElapsedDueToMemorySlowdown(interval.size); + GetInstructionElapsedDueToMemorySlowdown(interval.size); // Scale the slowdown based on the time of this buffer. We would want earlier // buffers have lower slowdown values, because they are less likely to overlap @@ -86,13 +85,28 @@ float GetMemoryBoundedness( // for early HLOs, and full slowdown for mid-to-late HLOs. // TODO(yuemmawang): Further in a smarter way, we want buffers overlapped with // more HLOs have higher slowdown, and vice versa. - float scale = interval.start * 1.0 / cost_analysis.GetScheduleEndTime(); + float scale = interval.start * 1.0 / GetScheduleEndTime(); alternate_mem_slowdown *= scale; return alternate_mem_benefit - alternate_mem_slowdown; } -} // namespace +int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel( + const HloInstruction* instruction) const { + int nest_level = 0; + const HloComputation* computation = instruction->parent(); + while (!computation->IsEntryComputation()) { + auto node = call_graph_.GetNode(computation); + auto callsites = node.caller_callsites(); + CHECK_EQ(callsites.size(), 1) << "The module is not flattened!"; + auto callsite = callsites[0]; + if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + ++nest_level; + } + computation = callsite.instruction()->parent(); + } + return nest_level; +} float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute( const HloInstruction& instruction) const { @@ -207,29 +221,30 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( const MemorySpaceAssignmentCostAnalysis& cost_analysis, float min_async_copy_to_overlap_ratio, float max_async_copy_to_overlap_ratio) - : cost_analysis_(cost_analysis), + : elapsed_time_( + cost_analysis.hlo_live_range().instruction_schedule().size(), 0.0), + while_nest_level_( + cost_analysis.hlo_live_range().instruction_schedule().size(), 0), + cost_analysis_(cost_analysis), min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio), max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio) { instruction_schedule_ = &cost_analysis_.hlo_live_range().instruction_schedule(); - // First create a vector of elapsed times of HLO instructions. - std::vector<float> instructions_elapsed_time(instruction_schedule_->size(), - 0.0); + // Create a vector of elapsed times and while nesting levels of HLO + // instructions. for (const auto& instruction_and_logical_time : *instruction_schedule_) { float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds( *instruction_and_logical_time.first); int64 logical_time = instruction_and_logical_time.second; - if (logical_time >= instructions_elapsed_time.size()) { - instructions_elapsed_time.resize(logical_time + 1, 0.0); + if (logical_time >= elapsed_time_.size()) { + elapsed_time_.resize(logical_time + 1, 0.0); + while_nest_level_.resize(logical_time + 1, 0); } - instructions_elapsed_time[logical_time] = elapsed_time; - } - // As an optimization, create a cumulative sum vector of elapsed time. - float cumsum = 0.0; - for (float elapsed_time : instructions_elapsed_time) { - cumsum += elapsed_time; - elapsed_time_cumsum_.push_back(cumsum); + elapsed_time_[logical_time] = elapsed_time; + while_nest_level_[logical_time] = + cost_analysis_.CalculateWhileLoopNestLevel( + instruction_and_logical_time.first); } } @@ -275,7 +290,7 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, end_logical_time_ = end_time; // Find the earliest time we're allowed to start prefetching. for (current_logical_prefetch_time_ = start_time; - current_logical_prefetch_time_ <= end_logical_time_ && + current_logical_prefetch_time_ < end_logical_time_ && max_async_copy_to_overlap_ratio_ * async_copy_elapsed_ < GetLogicalIntervalElapsed(current_logical_prefetch_time_, end_logical_time_); @@ -290,9 +305,9 @@ int64 CostAnalysisPrefetchIntervalPicker::Next() { } bool CostAnalysisPrefetchIntervalPicker::Done() const { - // The end time is inclusive, so we're done if the prefetch time is greater - // than that. - if (current_logical_prefetch_time_ > end_logical_time_) { + // The end time is exclusive, so we're done if the prefetch time is greater + // than or equal to the end time. + if (current_logical_prefetch_time_ >= end_logical_time_) { return true; } float logical_interval_elapsed = GetLogicalIntervalElapsed( @@ -303,7 +318,17 @@ bool CostAnalysisPrefetchIntervalPicker::Done() const { float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed( int64 start_time, int64 end_time) const { - return elapsed_time_cumsum_[end_time - 1] - elapsed_time_cumsum_[start_time]; + int interval_nest_level = + std::min(while_nest_level_[start_time], while_nest_level_[end_time]); + float total_elapsed = 0; + for (int i = start_time + 1; i < end_time; ++i) { + total_elapsed += + elapsed_time_[i] * + tensorflow::MathUtil::IPow<float>( + kWhileExecutionCount, + std::max(0, while_nest_level_[i] - interval_nest_level)); + } + return total_elapsed; } std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const { @@ -328,7 +353,7 @@ std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString( absl::optional<float> CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit( const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { - return GetMemoryBoundedness(cost_analysis_, interval); + return cost_analysis_.GetMemoryBoundedness(interval); } std::string MemorySpaceAssignment::AllocationValue::ToString() const { @@ -502,7 +527,8 @@ bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory( } bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( - const HloUse& use) const { + const AllocationValue& value, const HloUse& use) const { + const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); if (use.instruction->opcode() == HloOpcode::kWhile) { HloComputation* while_body = use.instruction->while_body(); @@ -512,7 +538,6 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( HloValue* parameter_value = &alias_analysis_.dataflow_analysis().GetUniqueValueAt( while_body->parameter_instruction(0), use.operand_index); - const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); int64 parameter_time = instruction_schedule.at(while_body->parameter_instruction(0)); int64 root_time = instruction_schedule.at(while_body->root_instruction()); @@ -567,7 +592,54 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( "there is a required default memory assignment."; return false; } + } else if (use.instruction->opcode() == HloOpcode::kConditional) { + // For any use of this conditional (the same value might be passed into + // multiple called computations), determine if the parameter->first use + // dependency is short. + int64 conditional_time = instruction_schedule.at(use.instruction); + for (const HloUse& other_use : value.uses()) { + if (other_use.instruction != use.instruction) { + continue; + } + HloComputation* called_computation = + use.instruction->called_computations().at(other_use.operand_number - + 1); + const HloInstruction* parameter_instruction = + called_computation->parameter_instruction(0); + HloValue* parameter_value = + &alias_analysis_.dataflow_analysis().GetUniqueValueAt( + parameter_instruction, other_use.operand_index); + int64 parameter_time = instruction_schedule.at(parameter_instruction); + int64 min_use_time = conditional_time; + for (const HloUse& parameter_use : parameter_value->uses()) { + if (parameter_use.instruction->parent() == called_computation && + parameter_use.instruction->opcode() != + HloOpcode::kGetTupleElement && + parameter_use.instruction->opcode() != HloOpcode::kTuple && + parameter_use.instruction->opcode() != HloOpcode::kBitcast) { + min_use_time = std::min( + min_use_time, instruction_schedule.at(parameter_use.instruction)); + } + } + if (options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( + parameter_value->shape(), parameter_time, min_use_time)) { + VLOG(4) << "Conditional allocation allowed in alternate memory for " + "computation = " + << called_computation->name() + << ", parameter time = " << parameter_time + << ", min use time = " << min_use_time; + return true; + } else { + VLOG(4) << "Conditional allocation not allowed in alternate memory for " + "computation = " + << called_computation->name() + << ", parameter time = " << parameter_time + << ", min use time = " << min_use_time; + } + } + return false; } + return true; } @@ -758,8 +830,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { } const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); - global_max_time_ = instruction_schedule.at( - module->entry_computation()->root_instruction()); // TODO(berkin): For now, place the phi values due to conditionals in // default memory. @@ -769,20 +839,12 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { if (position.instruction->opcode() == HloOpcode::kConditional) { VLOG(3) << "Adding required assignment for condition output: " << value->ToShortString(); - required_assignments_[value].push_back( - {MemorySpace::kDefault, - instruction_schedule.at(position.instruction), - /*chunk=*/absl::nullopt}); + AddRequiredAssignment(position.instruction, position.index, + MemorySpace::kDefault); for (const HloComputation* called_computation : position.instruction->called_computations()) { - HloValue* root_value = - &alias_analysis_.dataflow_analysis().GetUniqueValueAt( - called_computation->root_instruction(), position.index); - required_assignments_[root_value].push_back( - {MemorySpace::kDefault, - instruction_schedule.at( - called_computation->root_instruction()), - /*chunk=*/absl::nullopt}); + AddRequiredAssignment(called_computation->root_instruction(), + position.index, MemorySpace::kDefault); } } } @@ -808,9 +870,13 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { } // Iterate over the uses. - for (HloUse use : allocation_value.uses()) { + for (int use_idx = 0; use_idx < allocation_value.uses().size(); + ++use_idx) { + const HloUse& use = allocation_value.uses().at(use_idx); int64 use_time = instruction_schedule.at(use.instruction); int64 latest_prefetch_time = use_time; + bool allow_no_copy_alternate_mem_allocation = true; + absl::optional<int64> earliest_prefetch_time = absl::nullopt; // Sequential calls include kWhile, kCall, and kConditional opcodes. bool is_sequential_call = @@ -857,14 +923,41 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { // when we look at uses within the while loop body. use_time = instruction_schedule.at(while_body->parameter_instruction(0)); + } else if (use.instruction->opcode() == HloOpcode::kConditional) { + // Replace the use time with the earliest parameter of called + // computations. + for (const HloComputation* called_computation : + use.instruction->called_computations()) { + use_time = std::min( + use_time, instruction_schedule.at( + called_computation->parameter_instruction(0))); + } } } // Add a required assignment in default memory if the use not allowed in // alternate memory. - if (!IsUseAllowedInAlternateMemory(use)) { - required_assignments_[allocation_value.value()].push_back( - {MemorySpace::kDefault, use_time, /*chunk=*/absl::nullopt}); + if (!IsUseAllowedInAlternateMemory(allocation_value, use)) { + AddRequiredAssignment(allocation_value.value(), use.instruction, + MemorySpace::kDefault, use_time); + } else if (use_idx > 0) { + // We allow buffers in alternate memory that are passed into + // conditionals to give up their alternate memory allocation inside + // the called computation. This means that if a conditional operator + // has an alternate memory allocation, subsequent uses cannot use the + // same alternate memory allocation in order not to clobber data. So + // we force default memory allocation for these subsequent uses. + const HloUse& previous_use = allocation_value.uses().at(use_idx - 1); + if (previous_use.instruction->opcode() == HloOpcode::kConditional && + previous_use.instruction != use.instruction) { + allow_no_copy_alternate_mem_allocation = false; + earliest_prefetch_time = + instruction_schedule.at(previous_use.instruction); + VLOG(3) << "Previous use (" << previous_use.ToString() + << ") of use (" << use.ToString() + << ") is a conditional, so this use will need to evict. " + << "Earliest prefetch time = " << *earliest_prefetch_time; + } } // Bitcasts don't define buffers and don't directly consume buffers. @@ -872,10 +965,16 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { // bitcasts will be handled specially. if (use.instruction->opcode() != HloOpcode::kBitcast) { AllocationRequest request; - request.start_time = definition_time; + // Rarely, (e.g., when conditional true and false parameters are the + // same), definition time can be the time of the conditional and use + // time is the parameter use, which is less. + request.start_time = std::min(definition_time, use_time); request.end_time = use_time; request.latest_prefetch_time = latest_prefetch_time; request.size = interval.size; + request.allow_no_copy_alternate_mem_allocation = + allow_no_copy_alternate_mem_allocation; + request.earliest_prefetch_time = earliest_prefetch_time; request.preferred_offset = preferred_offset; request.use = use; request.allocation_value = &allocation_value; @@ -1061,35 +1160,42 @@ void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment( if (aliased_allocation->memory_space() == MemorySpace::kAlternate) { chunk = aliased_allocation->chunk(); } - const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); - HloValue* value = - &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index); - int64 instruction_time = instruction_schedule.at(instruction); + AddRequiredAssignment(instruction, index, aliased_allocation->memory_space(), + chunk); +} + +void AlternateMemoryBestFitHeap::AddRequiredAssignment( + const HloValue* value, const HloInstruction* instruction, + MemorySpaceAssignment::MemorySpace memory_space, int64 time, + absl::optional<HeapSimulator::Chunk> chunk) { // Check for existing required assignment at this time and make sure it is the // same as this if there is one. - auto existing_required_assignment = - RequiredMemoryAssignmentAt(value, instruction_time); + auto existing_required_assignment = RequiredMemoryAssignmentAt(value, time); if (existing_required_assignment) { - CHECK(aliased_allocation->memory_space() == - existing_required_assignment->memory_space); + CHECK(memory_space == existing_required_assignment->memory_space) + << "inst = " << instruction->ToString() << " at " << time; CHECK((!chunk && !existing_required_assignment->chunk) || chunk->offset == existing_required_assignment->chunk->offset); - VLOG(3) << "Not adding aliased required assignment because there is one " - "already: " - << value->ToShortString() << " at " << instruction_time << " at " - << (aliased_allocation->memory_space() == MemorySpace::kDefault - ? "def" - : "alt"); - return; + VLOG(3) << "Not adding required assignment because there is one already: " + << value->ToShortString() << " at " << time << " at " + << (memory_space == MemorySpace::kDefault ? "def" : "alt"); + } else { + VLOG(3) << "Adding required assignment: " << value->ToShortString() + << " at " << time << " at " + << (memory_space == MemorySpace::kDefault ? "def" : "alt"); + required_assignments_[value].push_back({memory_space, time, chunk}); } +} - required_assignments_[value].push_back( - {aliased_allocation->memory_space(), instruction_time, chunk}); - VLOG(3) << "Adding aliased required assignment: " << value->ToShortString() - << " at " << instruction_time << " at " - << (aliased_allocation->memory_space() == MemorySpace::kDefault - ? "def" - : "alt"); +void AlternateMemoryBestFitHeap::AddRequiredAssignment( + const HloInstruction* instruction, ShapeIndex index, + MemorySpace memory_space, absl::optional<Chunk> chunk) { + const HloValue* value = + &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index); + int64 instruction_time = + hlo_live_range_.instruction_schedule().at(instruction); + AddRequiredAssignment(value, instruction, memory_space, instruction_time, + chunk); } void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { @@ -1187,10 +1293,13 @@ void AlternateMemoryBestFitHeap::UncommitPendingChunks() { interval_tree_.Remove(interval.start, interval.end, chunk); } for (const auto& interval : pending_async_copies_) { - async_copy_interval_tree_.Remove(interval.start_time, interval.end_time, - kDummyChunk); if (interval.destination == MemorySpace::kAlternate) { + prefetch_interval_tree_.Remove(interval.start_time, interval.end_time, + kDummyChunk); async_copy_ordering_.RemoveCopy(interval); + } else { + eviction_interval_tree_.Remove(interval.start_time, interval.end_time, + kDummyChunk); } } pending_chunks_.clear(); @@ -1289,6 +1398,7 @@ bool AlternateMemoryBestFitHeap::FindAllocation( // First try keeping the allocation entirely in the alternate memory. if (required_memory_space_at_start != MemorySpace::kDefault && required_memory_space_at_end != MemorySpace::kDefault && + request.allow_no_copy_alternate_mem_allocation && AllocateInAlternateMemoryNoCopy(request)) { return true; } @@ -1363,6 +1473,7 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy( : "alternate") << " memory between " << start_time << " and " << copy_done_schedule_before_time << " keeping until " << end_time; + CHECK_LT(start_time, copy_done_schedule_before_time); allocations->push_back( absl::make_unique<MemorySpaceAssignment::CopyAllocation>( @@ -1373,27 +1484,37 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy( // the limit at any given time. pending_async_copies_.push_back( {start_time, copy_done_schedule_before_time, memory_space}); - async_copy_interval_tree_.Add(start_time, copy_done_schedule_before_time, - kDummyChunk); if (memory_space == MemorySpaceAssignment::MemorySpace::kAlternate) { + prefetch_interval_tree_.Add(start_time, copy_done_schedule_before_time, + kDummyChunk); async_copy_ordering_.AddCopy(pending_async_copies_.back()); + } else { + eviction_interval_tree_.Add(start_time, copy_done_schedule_before_time, + kDummyChunk); } } bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies( - int64 start_time, int64 end_time) const { - if (options_.max_outstanding_async_copies < 0) { + int64 start_time, int64 end_time, bool is_prefetch) const { + if (options_.max_outstanding_prefetches < 0 && is_prefetch) { + return false; + } + if (options_.max_outstanding_evictions < 0 && !is_prefetch) { return false; } - // Count the asynchronous copies in the interval tree for the given interval. - int64 num_async_copies = - async_copy_interval_tree_.ChunksOverlappingInTime(start_time, end_time) - .size(); - - // Add one because we are checking if adding an additional asynchronous copy - // would violate the limit. - return num_async_copies + 1 > options_.max_outstanding_async_copies; + // Count the prefetches/evictions in the interval tree for the given interval. + if (is_prefetch) { + int64 num_prefetches = + prefetch_interval_tree_.ChunksOverlappingInTime(start_time, end_time) + .size(); + return num_prefetches >= options_.max_outstanding_prefetches; + } else { + int64 num_evictions = + eviction_interval_tree_.ChunksOverlappingInTime(start_time, end_time) + .size(); + return num_evictions >= options_.max_outstanding_evictions; + } } bool AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering( @@ -1525,6 +1646,9 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { request.allocation_value->defining_position().shape(), eviction_start_time, request.end_time), eviction_end_time); + // Evictions must complete by the time of this use. + preferred_eviction_end_time = + std::min(preferred_eviction_end_time, request.latest_prefetch_time); BufferInterval eviction_mem_interval; eviction_mem_interval.buffer = request.allocation_value->value(); @@ -1532,8 +1656,7 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { // Try to reserve a buffer from the end of the previous allocation to the // preferred eviction end time. eviction_mem_interval.start = eviction_end_time + 1; - eviction_mem_interval.end = - std::min(preferred_eviction_end_time, global_max_time_); + eviction_mem_interval.end = preferred_eviction_end_time; int64 preferred_offset = prev_allocation->chunk().offset; VLOG(3) << "Eviction (" << eviction_start_time << ", " << eviction_end_time << ") preferred end time = " << eviction_mem_interval.end; @@ -1555,7 +1678,8 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { bool eviction_interval_too_short = (eviction_start_time == eviction_end_time); bool eviction_violates_outstanding_copies = ViolatesMaximumOutstandingAsyncCopies(eviction_start_time, - eviction_end_time); + eviction_end_time, + /*is_prefetch=*/false); // See if this interval would violate the asynchronous copy limit. if (!eviction_interval_too_short && !eviction_violates_outstanding_copies) { @@ -1576,7 +1700,8 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { bool eviction_scheduled = false; for (int64 time = eviction_start_time; time < eviction_end_time; ++time) { VLOG(4) << "Try evicting (" << time << ", " << time + 1 << ")"; - if (!ViolatesMaximumOutstandingAsyncCopies(time, time + 1)) { + if (!ViolatesMaximumOutstandingAsyncCopies(time, time + 1, + /*is_prefetch=*/false)) { VLOG(3) << "Eviction successful."; AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, /*chunk=*/absl::nullopt, time, time + 1, time + 1, @@ -1618,9 +1743,14 @@ bool AlternateMemoryBestFitHeap::Prefetch( // ^ ^ // Copy Copy // Start Done - options_.prefetch_interval_picker->Begin( - request.use, prev_allocation_in_default_mem.earliest_available_time(), - request.latest_prefetch_time); + int64 earliest_prefetch_time = + prev_allocation_in_default_mem.earliest_available_time(); + if (request.earliest_prefetch_time) { + earliest_prefetch_time = + std::max(earliest_prefetch_time, *request.earliest_prefetch_time); + } + options_.prefetch_interval_picker->Begin(request.use, earliest_prefetch_time, + request.latest_prefetch_time); VLOG(3) << "Trying prefetch picker = " << options_.prefetch_interval_picker->ToDebugString(); @@ -1631,12 +1761,14 @@ bool AlternateMemoryBestFitHeap::Prefetch( alternate_mem_interval.size = request.size; while (!options_.prefetch_interval_picker->Done()) { alternate_mem_interval.start = options_.prefetch_interval_picker->Next(); + CHECK_LT(alternate_mem_interval.start, request.latest_prefetch_time); VLOG(4) << "Trying alternate memory allocation (" << alternate_mem_interval.start << ", " << request.end_time << ")"; // If this additional asynchronous copy would violate the limit, try a // different interval. if (ViolatesMaximumOutstandingAsyncCopies(alternate_mem_interval.start, - request.latest_prefetch_time)) { + request.latest_prefetch_time, + /*is_prefetch=*/true)) { VLOG(4) << "This would violate the outstanding async copy limit."; continue; } @@ -1706,28 +1838,48 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate( return absl::nullopt; } -/*static*/ int64 MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies( - const HloModule& module) { - int64 max_copies = 0; +StatusOr<MemorySpaceAssignment::AsyncCopyStats> +MemorySpaceAssignment::CalculateAsyncCopyStats() const { + AsyncCopyStats stats; + stats.max_outstanding_async_copies = 0; + stats.num_prefetches = 0; + stats.prefetch_bytes = 0; + stats.num_evictions = 0; + stats.eviction_bytes = 0; int64 current_copies = 0; - for (HloInstruction* instruction : - module.schedule().sequence(module.entry_computation()).instructions()) { - if (instruction->opcode() == HloOpcode::kCopyStart) { - current_copies++; - } else if (instruction->opcode() == HloOpcode::kCopyDone) { - current_copies--; + TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow_analysis, + HloDataflowAnalysis::Run(*module_)); + for (const HloComputation* computation : + module_->MakeNonfusionComputations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopyStart) { + current_copies++; + } else if (instruction->opcode() == HloOpcode::kCopyDone) { + current_copies--; + int64 size = + options_.size_fn(dataflow_analysis->GetUniqueValueAt(instruction)); + if (instruction->shape().layout().memory_space() == + options_.alternate_memory_space) { + ++stats.num_prefetches; + stats.prefetch_bytes += size; + } else { + ++stats.num_evictions; + stats.eviction_bytes += size; + } + } + stats.max_outstanding_async_copies = + std::max(stats.max_outstanding_async_copies, current_copies); } - max_copies = std::max(max_copies, current_copies); } - return max_copies; + return stats; } /*static*/ MemorySpaceAssignment::BufferIntervalCompare MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( const MemorySpaceAssignmentCostAnalysis& cost_analysis) { return [&](const BufferInterval& x, const BufferInterval& y) { - float x_memory_boundedness = GetMemoryBoundedness(cost_analysis, x); - float y_memory_boundedness = GetMemoryBoundedness(cost_analysis, y); + float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x); + float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y); if (x_memory_boundedness != y_memory_boundedness) { return x_memory_boundedness > y_memory_boundedness; } @@ -1851,8 +2003,13 @@ MemorySpaceAssignment::RunMemorySpaceAssignment( VLOG(3) << "Module after memory space assignment: "; XLA_VLOG_LINES(3, module_->ToString()); TF_CHECK_OK(module_->schedule().Verify()); + TF_ASSIGN_OR_RETURN(AsyncCopyStats stats, CalculateAsyncCopyStats()); VLOG(1) << "Maximum number of outstanding async copies: " - << CountMaximumOutstandingAsyncCopies(*module_); + << stats.max_outstanding_async_copies; + VLOG(1) << "Number of prefetches: " << stats.num_prefetches + << ", in bytes: " << stats.prefetch_bytes; + VLOG(1) << "Number of evictions: " << stats.num_evictions + << ", in bytes: " << stats.eviction_bytes; TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace()); @@ -2411,6 +2568,34 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { std::tuple<const HloValue*, Chunk, HeapSimulatorTrace::Event::Kind>> events; + auto add_allocation_and_verify = [&](int64 start_time, int64 end_time, + const Chunk& chunk, + const HloValue* value) { + events[std::make_tuple(start_time, /*is_free=*/false, value->id())] = + std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC); + events[std::make_tuple(end_time, /*is_free=*/true, value->id())] = + std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE); + + // Get the chunks overlapping in time and search if they overlap in space + // as well. + // TODO(berkin): For now checking against end_time - 1 (exclusive), but we + // really should check against end_time (inclusive) for cases where the + // operand can't share buffer with user (see + // HloDataflowAnalysis::CanShareOperandBufferWithUser). + for (const Chunk& overlapping_chunk : + interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) { + if (chunk.OverlapsWith(overlapping_chunk)) { + return InternalError( + ("Value %s (%d, %d) off: %d size: %d overlaps with another chunk" + " off: %d size: %d"), + value->ToShortString(), start_time, end_time, chunk.offset, + chunk.size, overlapping_chunk.offset, overlapping_chunk.size); + } + } + interval_tree.Add(start_time, end_time - 1, chunk); + return Status::OK(); + }; + // Go through all instructions in the module to ensure CopyStart/CopyDone // instructions copy between alternate memory and default memory. for (const HloComputation* computation : @@ -2446,34 +2631,73 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { for (const HloValue* value : buffer.values()) { const HloLiveRange::TimeBound& time_bound = hlo_live_range->buffer_live_ranges().at(value); - events[std::make_tuple(time_bound.start, /*is_free=*/false, - value->id())] = - std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC); - events[std::make_tuple(time_bound.end, /*is_free=*/true, value->id())] = - std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE); - - VLOG(3) << " buffer: " << buffer.ToString() - << " value: " << value->ToShortString() << ": (" - << time_bound.start << ", " << time_bound.end - << ") off: " << chunk.offset << ", size: " << chunk.size; - // Get the chunks overlapping in time and search if they overlap in space - // as well. - // TODO(berkin): For now checking against end_time - 1 (exclusive), but we - // really should check against end_time (inclusive) for cases where the - // operand can't share buffer with user (see - // HloDataflowAnalysis::CanShareOperandBufferWithUser). - for (const Chunk& overlapping_chunk : - interval_tree.ChunksOverlappingInTime(time_bound.start, - time_bound.end - 1)) { - if (chunk.OverlapsWith(overlapping_chunk)) { - return InternalError( - ("Buffer %s (%d, %d) off: %d size: %d overlaps with another chunk" - " off: %d size: %d"), - buffer.ToString(), time_bound.start, time_bound.end, chunk.offset, - chunk.size, overlapping_chunk.offset, overlapping_chunk.size); + const HloInstruction* last_use_instruction = nullptr; + int64 last_use_time = time_bound.start; + for (const HloUse& use : value->uses()) { + int64 use_time = + hlo_live_range->instruction_schedule().at(use.instruction); + if (use_time > last_use_time) { + last_use_time = use_time; + last_use_instruction = use.instruction; } } - interval_tree.Add(time_bound.start, time_bound.end - 1, chunk); + + if (last_use_instruction && + last_use_instruction->opcode() == HloOpcode::kConditional) { + // Special case when verifying conditional: we internally split the use + // of alternate memory in conditionals, so fish them out from the + // conditionals. + VLOG(3) << " Splitting conditional buffer: " << buffer.ToString() + << " value: " << value->ToShortString() << ": (" + << time_bound.start << ", " << time_bound.end + << ") off: " << chunk.offset << ", size: " << chunk.size; + int64 earliest_computation_start_time = time_bound.end; + for (const HloComputation* called_computation : + last_use_instruction->called_computations()) { + earliest_computation_start_time = + std::min(earliest_computation_start_time, + hlo_live_range->computation_span_times() + .at(called_computation) + .start); + int64 parameter_time = -1; + int64 last_use_time = -1; + for (const HloPosition& position : value->positions()) { + if (position.instruction->opcode() == HloOpcode::kParameter && + position.instruction->parent() == called_computation) { + parameter_time = hlo_live_range->instruction_schedule().at( + position.instruction); + break; + } + } + for (const HloUse& use : value->uses()) { + if (use.instruction->parent() == called_computation) { + last_use_time = std::max( + last_use_time, + hlo_live_range->instruction_schedule().at(use.instruction)); + } + } + if (last_use_time != -1) { + CHECK_NE(parameter_time, -1); + VLOG(3) << " computation: " << called_computation->name() << ": (" + << parameter_time << ", " << last_use_time << ")"; + TF_RETURN_IF_ERROR(add_allocation_and_verify( + parameter_time, last_use_time, chunk, value)); + } + } + VLOG(3) << " from beginning until first computation: (" + << time_bound.start << ", " + << (earliest_computation_start_time - 1) << ")"; + TF_RETURN_IF_ERROR(add_allocation_and_verify( + time_bound.start, earliest_computation_start_time - 1, chunk, + value)); + } else { + VLOG(3) << " buffer: " << buffer.ToString() + << " value: " << value->ToShortString() << ": (" + << time_bound.start << ", " << time_bound.end + << ") off: " << chunk.offset << ", size: " << chunk.size; + TF_RETURN_IF_ERROR(add_allocation_and_verify( + time_bound.start, time_bound.end, chunk, value)); + } } } diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index eb16db90600..3f59abfd28e 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -82,16 +82,31 @@ class MemorySpaceAssignmentCostAnalysis { const HloCostAnalysis& cost_analysis, float async_copy_bandwidth_bytes_per_second, float alternate_mem_bandwidth_bytes_per_second, - const HloLiveRange& hlo_live_range) + const HloLiveRange& hlo_live_range, const CallGraph& call_graph) : cost_analysis_(cost_analysis), async_copy_bandwidth_bytes_per_second_( async_copy_bandwidth_bytes_per_second), alternate_mem_bandwidth_bytes_per_second_( alternate_mem_bandwidth_bytes_per_second), - hlo_live_range_(hlo_live_range) {} + hlo_live_range_(hlo_live_range), + call_graph_(call_graph) {} const HloCostAnalysis& cost_analysis() const { return cost_analysis_; } + // Returns a heuristic value that captures how much putting this tensor to the + // alternate memory would help if the op is memory bound, or otherwise how far + // off is the op to memory boundedness. The larger this number, the higher + // priority it will be placed in the alternate memory. + float GetAlternateMemoryBenefit( + const HloInstruction& instruction, + float elapsed_time_due_to_alternate_mem) const; + + // Returns a heuristic value of memory boundedness for the given + // BufferInterval. The larger this number, the higher priority it will be + // placed in the alternate memory. + float GetMemoryBoundedness( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const; + // Returns the elapsed time in seconds due to compute only. float GetInstructionElapsedDueToCompute( const HloInstruction& instruction) const; @@ -127,6 +142,10 @@ class MemorySpaceAssignmentCostAnalysis { int64 GetScheduleEndTime() const; + // Returns the number of nested while loop levels this instruction resides in. + // 0 means it is not in a while loop. + int CalculateWhileLoopNestLevel(const HloInstruction* instruction) const; + const HloLiveRange& hlo_live_range() const { return hlo_live_range_; } private: @@ -134,6 +153,7 @@ class MemorySpaceAssignmentCostAnalysis { float async_copy_bandwidth_bytes_per_second_; float alternate_mem_bandwidth_bytes_per_second_; const HloLiveRange& hlo_live_range_; + const CallGraph& call_graph_; }; // Abstract base class that memory space assignment uses to pick prefetch @@ -262,10 +282,10 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { // corresponds to the instruction schedule. float GetLogicalIntervalElapsed(int64 start_time, int64 end_time) const; - // For performance reasons, we calculate the prefix sum of the elapsed time so - // that it's efficient to find the elapsed time in seconds in any logical - // interval. - std::vector<float> elapsed_time_cumsum_; + // For each instruction in the flattened schedule, maintain their elapsed time + // and while nesting level. + std::vector<float> elapsed_time_; + std::vector<int> while_nest_level_; const MemorySpaceAssignmentCostAnalysis& cost_analysis_; float min_async_copy_to_overlap_ratio_; @@ -323,9 +343,10 @@ class MemorySpaceAssignment { // the opcode) to be placed on the alternate memory. IsAllowedInAlternateMemoryFunction is_allowed_in_alternate_mem_fn; - // Specifies the upper bound for number of outstanding asynchronous copies, - // -1 for unlimited. - int64 max_outstanding_async_copies = -1; + // Specifies the upper bound for number of outstanding prefetches and + // evictions, -1 for unlimited. + int64 max_outstanding_prefetches = -1; + int64 max_outstanding_evictions = -1; // If true, tries allocating buffers across (e.g., before and inside a while // loop body) sequential calls (kWhile, kCall, and kConditional). @@ -604,6 +625,15 @@ class MemorySpaceAssignment { AllocationSequence allocation_sequence_; }; + // Statistics of asynchronous copies. + struct AsyncCopyStats { + int64 max_outstanding_async_copies; + int64 num_prefetches; + int64 prefetch_bytes; + int64 num_evictions; + int64 eviction_bytes; + }; + virtual ~MemorySpaceAssignment() = default; // Runs the MemorySpaceAssignment pass. @@ -611,9 +641,8 @@ class MemorySpaceAssignment { HloModule* module, const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, const Options& options); - // Returns the maximum number of outstanding asynchronous copies in the - // module. - static int64 CountMaximumOutstandingAsyncCopies(const HloModule& module); + // Calculates asynchronous copy statistics. + StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const; static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare( const MemorySpaceAssignmentCostAnalysis& cost_analysis); @@ -808,11 +837,16 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { // use_times is a sorted sequence of the times of all uses. // latest_prefetch_time is the latest time we can schedule the CopyDone for a // prefetch. + // If allow_no_copy_alternate_mem_allocation is false, an eviction is forced. + // If earliest_prefetch_time is set, prefetches cannot start before this + // value. struct AllocationRequest { int64 start_time; int64 end_time; int64 latest_prefetch_time; int64 size; + bool allow_no_copy_alternate_mem_allocation; + absl::optional<int64> earliest_prefetch_time; absl::optional<int64> preferred_offset; HloUse use; MemorySpaceAssignment::AllocationValue* allocation_value; @@ -833,7 +867,8 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const; // Returns true if the use is allowed in the alternate memory. - bool IsUseAllowedInAlternateMemory(const HloUse& use) const; + bool IsUseAllowedInAlternateMemory(const AllocationValue& value, + const HloUse& use) const; // Given an HloValue, creates AllocationValue objects and corresponding // AllocationSequences and appends them into allocation_sequence_list_. @@ -887,6 +922,16 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { const HloInstruction* instruction, ShapeIndex index, const MemorySpaceAssignment::Allocation* aliased_allocation); + // This sets a required assignment. CHECK fails if there is a conflicting + // required assignment at the same time. + void AddRequiredAssignment(const HloValue* value, + const HloInstruction* instruction, + MemorySpace memory_space, int64 time, + absl::optional<Chunk> chunk = absl::nullopt); + void AddRequiredAssignment(const HloInstruction* instruction, + ShapeIndex index, MemorySpace memory_space, + absl::optional<Chunk> chunk = absl::nullopt); + // Adds input and outputs as required assignments. void AddInputAndOutputRequiredAssignments(); @@ -909,8 +954,8 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { // Returns true if the addition of an asynchronous copy in the given time // interval would violate the maximum number of asynchronous copies. - bool ViolatesMaximumOutstandingAsyncCopies(int64 start_time, - int64 end_time) const; + bool ViolatesMaximumOutstandingAsyncCopies(int64 start_time, int64 end_time, + bool is_prefetch) const; // Return true if the asynchronous copy would violate the pipelining order. bool ViolatesAsyncCopyOrdering(int64 start_time, int64 end_time) const; @@ -953,8 +998,9 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { const HloAliasAnalysis& alias_analysis_; const HloLiveRange& hlo_live_range_; // We use a interval tree to keep track of the number of outstanding - // asynchronous copies. - BufferIntervalTree async_copy_interval_tree_; + // prefetches and evictions. + BufferIntervalTree prefetch_interval_tree_; + BufferIntervalTree eviction_interval_tree_; AsynchronousCopyOrdering async_copy_ordering_; std::vector<std::pair<BufferInterval, ChunkCandidate>> pending_chunks_; std::vector<AsynchronousCopy> pending_async_copies_; @@ -964,7 +1010,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { required_assignments_; // Number of bytes reserved in alternate memory space. int64 reserved_in_bytes_ = 0; - int64 global_max_time_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index b2125d318d0..0a76dd5f31c 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -57,9 +57,10 @@ class MemorySpaceAssignmentTest : public HloTestBase, HloLiveRange::Run(module->schedule(), *alias_analysis, module->entry_computation()) .ValueOrDie(); + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); MemorySpaceAssignmentCostAnalysis cost_analysis( hlo_cost_analysis, kAsyncCopyBandwidth, kAlternateMemBandwidth, - *hlo_live_range); + *hlo_live_range, *call_graph); CostAnalysisPrefetchIntervalPicker prefetch_interval_picker( CostAnalysisPrefetchIntervalPicker( cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8, @@ -126,7 +127,8 @@ class MemorySpaceAssignmentTest : public HloTestBase, options.prefetch_interval_picker = prefetch_interval_picker; options.size_fn = size_fn; options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem; - options.max_outstanding_async_copies = max_outstanding_async_copies; + options.max_outstanding_prefetches = max_outstanding_async_copies; + options.max_outstanding_evictions = max_outstanding_async_copies; options.allocate_across_sequential_calls = GetParam(); options.verify = true; @@ -184,6 +186,47 @@ class MemorySpaceAssignmentTest : public HloTestBase, } } + struct OutstandingAsyncCopies { + int64 max_copies; + int64 max_prefetches; + int64 max_evictions; + }; + + /*static*/ OutstandingAsyncCopies CountMaximumOutstandingAsyncCopies( + const HloModule& module) { + OutstandingAsyncCopies copies{0, 0, 0}; + int64 current_copies = 0; + int64 current_prefetches = 0; + int64 current_evictions = 0; + for (HloInstruction* instruction : module.schedule() + .sequence(module.entry_computation()) + .instructions()) { + if (instruction->opcode() == HloOpcode::kCopyStart) { + current_copies++; + if (ShapeUtil::GetSubshape(instruction->shape(), {0}) + .layout() + .memory_space() == kAlternateMemorySpace) { + current_prefetches++; + } else { + current_evictions++; + } + } else if (instruction->opcode() == HloOpcode::kCopyDone) { + current_copies--; + if (instruction->shape().layout().memory_space() == + kAlternateMemorySpace) { + current_prefetches--; + } else { + current_evictions--; + } + } + copies.max_copies = std::max(copies.max_copies, current_copies); + copies.max_prefetches = + std::max(copies.max_prefetches, current_prefetches); + copies.max_prefetches = std::max(copies.max_evictions, current_evictions); + } + return copies; + } + std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() { HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); @@ -391,8 +434,8 @@ TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) { AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/0); - EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module), - 0); + EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_prefetches, 0); + EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_evictions, 0); } TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) { @@ -400,8 +443,8 @@ TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) { AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1); - EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module), - 1); + EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_prefetches, 1); + EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_evictions, 1); } TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies2) { @@ -409,8 +452,8 @@ TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies2) { AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/2); - EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module), - 2); + EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_prefetches, 2); + EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_evictions, 2); } // TODO(berkin): This test is broken with some prefetch timing improvements. @@ -1650,6 +1693,324 @@ TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) { AssignMemorySpace(module.get()); } +TEST_P(MemorySpaceAssignmentTest, ConditionalShouldBeAllocatedInAlternateMem) { + // Checks if simple conditionals get alternate memory allocations. + absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg1 = f32[3]{0} negate(gte) + } + + false_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg2 = f32[3]{0} negate(gte) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy = f32[3]{0} copy(p0) + tuple = (f32[3]{0}) tuple(copy) + ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + if (GetParam()) { + // Check that copy and gtes got alternate memory allocations. + auto copy = + module->GetComputationWithName("entry")->GetInstructionWithName("copy"); + EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace); + auto neg1 = module->GetComputationWithName("true_computation") + ->GetInstructionWithName("neg1"); + auto neg1_operand = neg1->operand(0); + EXPECT_EQ(neg1_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + auto neg2 = module->GetComputationWithName("false_computation") + ->GetInstructionWithName("neg2"); + auto neg2_operand = neg2->operand(0); + EXPECT_EQ(neg2_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + } +} + +TEST_P(MemorySpaceAssignmentTest, ConditionalAvoidsUnnecessaryPrefetch) { + // Checks if we avoid unnecessary allocation in alternate memory if the input + // won't be used in the computation for a long time. + absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation { + p0 = (f32[3]{0}, f32[3]{0}) parameter(0) + gte0 = f32[3]{0} get-tuple-element(p0), index=0 + neg0 = f32[3]{0} negate(gte0) + neg1 = f32[3]{0} negate(neg0) + neg2 = f32[3]{0} negate(neg1) + neg3 = f32[3]{0} negate(neg2) + neg4 = f32[3]{0} negate(neg3) + neg5 = f32[3]{0} negate(neg4) + neg6 = f32[3]{0} negate(neg5) + neg7 = f32[3]{0} negate(neg6) + neg8 = f32[3]{0} negate(neg7) + neg9 = f32[3]{0} negate(neg8) + gte1 = f32[3]{0} get-tuple-element(p0), index=1 + ROOT add = f32[3]{0} add(neg9, gte1) + } + + false_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg = f32[3]{0} negate(gte) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy0 = f32[3]{0} copy(p0) + copy1 = f32[3]{0} copy(p0) + tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1) + tuple1 = (f32[3]{0}) tuple(copy0) + ROOT conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + if (GetParam()) { + // Check that copy1 doesn't get unnecessarily allocated in alternate mem + // (due to long negate chain in true_computation) but is prefetched before + // add. + auto copy0 = + module->GetComputationWithName("entry")->GetInstructionWithName( + "copy0"); + EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace); + auto copy1 = + module->GetComputationWithName("entry")->GetInstructionWithName( + "copy1"); + EXPECT_EQ(copy1->shape().layout().memory_space(), kDefaultMemorySpace); + auto add = module->GetComputationWithName("true_computation") + ->GetInstructionWithName("add"); + auto add_operand = add->operand(1); + EXPECT_EQ(add_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + } +} + +TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUse) { + // Make sure there is an evict when there is a conditional use followed by + // another use. + absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation { + p0 = (f32[3]{0}, f32[3]{0}) parameter(0) + gte0 = f32[3]{0} get-tuple-element(p0), index=0 + gte1 = f32[3]{0} get-tuple-element(p0), index=1 + add0 = f32[3]{0} add(gte0, gte1) + neg0 = f32[3]{0} negate(add0) + neg1 = f32[3]{0} negate(neg0) + neg2 = f32[3]{0} negate(neg1) + neg3 = f32[3]{0} negate(neg2) + neg4 = f32[3]{0} negate(neg3) + neg5 = f32[3]{0} negate(neg4) + neg6 = f32[3]{0} negate(neg5) + neg7 = f32[3]{0} negate(neg6) + neg8 = f32[3]{0} negate(neg7) + ROOT neg9 = f32[3]{0} negate(neg8) + } + + false_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg = f32[3]{0} negate(gte) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy0 = f32[3]{0} copy(p0) + copy1 = f32[3]{0} copy(p0) + tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1) + tuple1 = (f32[3]{0}) tuple(copy0) + conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation + ROOT add1 = f32[3]{0} add(copy1, conditional) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + if (GetParam()) { + // Make sure the copy1->add edge is in alternate memory. Before conditional, + // this should be evicted to default memory and neg uses the input from + // default memory. + auto copy1 = + module->GetComputationWithName("entry")->GetInstructionWithName( + "copy1"); + EXPECT_EQ(copy1->shape().layout().memory_space(), kAlternateMemorySpace); + auto add0 = module->GetComputationWithName("true_computation") + ->GetInstructionWithName("add0"); + auto add0_operand = add0->operand(1); + EXPECT_EQ(add0_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + auto add1 = + module->GetComputationWithName("entry")->GetInstructionWithName("add1"); + auto add1_operand = add1->operand(0); + EXPECT_EQ(add1_operand->shape().layout().memory_space(), + kDefaultMemorySpace); + EXPECT_EQ(add1_operand->opcode(), HloOpcode::kCopyDone); + } +} + +TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUseInWhile) { + absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg1 = f32[3]{0} negate(gte) + } + + false_computation { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg2 = f32[3]{0} negate(gte) + } + + while_cond { + p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + ROOT gte = pred[] get-tuple-element(p0), index=2 + } + + while_body { + p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + gte0 = f32[3]{0} get-tuple-element(p0), index=0 + gte1 = f32[3]{0} get-tuple-element(p0), index=1 + gte2 = pred[] get-tuple-element(p0), index=2 + cond_tuple = (f32[3]{0}) tuple(gte0) + conditional = f32[3]{0} conditional(gte2, cond_tuple, cond_tuple), true_computation=true_computation, false_computation=false_computation + add = f32[3]{0} add(conditional, gte1) + neg0 = f32[3]{0} negate(add) + neg1 = f32[3]{0} negate(neg0) + ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, neg1, gte2) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy0 = f32[3]{0} copy(p0) + copy1 = f32[3]{0} copy(p0) + tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy1, p1) + while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body + ROOT gte = f32[3]{0} get-tuple-element(while), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + if (GetParam()) { + // Make sure copy1/while{0}/cond_tuple{0} gets alternate memory allocation. + // This will force an eviction and a prefetch for while body root. + auto copy0 = + module->GetComputationWithName("entry")->GetInstructionWithName( + "copy0"); + EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace); + auto conditional = module->GetComputationWithName("while_body") + ->GetInstructionWithName("conditional"); + auto conditional_operand = conditional->operand(1); + EXPECT_EQ(ShapeUtil::GetSubshape(conditional_operand->shape(), {0}) + .layout() + .memory_space(), + kAlternateMemorySpace); + auto while_root = + module->GetComputationWithName("while_body")->root_instruction(); + auto while_root_operand = while_root->operand(0); + EXPECT_THAT( + while_root_operand, + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, + op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace, + op::GetTupleElement(op::Parameter(0))))); + } +} + +TEST_P(MemorySpaceAssignmentTest, NestedConditional) { + absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation2 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg1 = f32[3]{0} negate(gte) + } + + false_computation2 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg2 = f32[3]{0} negate(gte) + } + + true_computation1 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + slice = f32[1]{0} slice(gte), slice={[0:1]} + bitcast = f32[] bitcast(slice) + constant = f32[] constant(0.0) + compare = pred[] compare(bitcast, constant), direction=GT + ROOT conditional = f32[3]{0} conditional(compare, p0, p0), true_computation=true_computation2, false_computation=false_computation2 + } + + false_computation1 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg3 = f32[3]{0} negate(gte) + } + + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy = f32[3]{0} copy(p0) + tuple = (f32[3]{0}) tuple(copy) + ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); + + if (GetParam()) { + // Make sure alternate memory allocation gets propagated into both levels of + // conditional. + auto copy = + module->GetComputationWithName("entry")->GetInstructionWithName("copy"); + EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace); + auto neg1_operand = module->GetComputationWithName("true_computation2") + ->GetInstructionWithName("neg1") + ->operand(0); + auto neg2_operand = module->GetComputationWithName("false_computation2") + ->GetInstructionWithName("neg2") + ->operand(0); + auto neg3_operand = module->GetComputationWithName("false_computation1") + ->GetInstructionWithName("neg3") + ->operand(0); + EXPECT_EQ(neg1_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + EXPECT_EQ(neg2_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + EXPECT_EQ(neg3_operand->shape().layout().memory_space(), + kAlternateMemorySpace); + } +} + TEST_P(MemorySpaceAssignmentTest, RequestIdentifierShouldNotBeAllocatedInAlternateMem) { // Ensure that request identifier returned by Send/Recv HLOs are not allocated @@ -2136,7 +2497,8 @@ TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) { AssignMemorySpace(module.get(), -1, 5); } -TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule4) { +// TODO(berkin): This might be an incorrect input graph, investigate. +TEST_P(MemorySpaceAssignmentTest, DISABLED_NonEntryComputationSchedule4) { auto module = CreateNewVerifiedModule(); Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3}); Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3}); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index a57e4300d6e..07655a61074 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -185,10 +185,10 @@ cc_library( "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgToLLVM", "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:LoopsToGPUPass", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFToGPUPass", "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc index 56684b1f726..d5cad385324 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include <vector> #include "llvm/Support/raw_ostream.h" -#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" // from @llvm-project +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 847ad918308..4645b084eb6 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -19,8 +19,8 @@ limitations under the License. #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project -#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" // from @llvm-project -#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" // from @llvm-project +#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project #include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project @@ -351,7 +351,7 @@ struct FixKernelFunctionSignatures struct MapParallelLoops : public mlir::PassWrapper<MapParallelLoops, mlir::FunctionPass> { void runOnFunction() override { - mlir::greedilyMapParallelLoopsToGPU(getFunction().getBody()); + mlir::greedilyMapParallelSCFToGPU(getFunction().getBody()); } }; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 8d6ef9faba9..0ea7912c95c 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2001,7 +2001,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr<Shape> ShapeInference::InferAllGatherShape( const Shape& operand_shape, int64 all_gather_dimension, int64 shard_count) { - TF_RET_CHECK(all_gather_dimension > 0); + TF_RET_CHECK(all_gather_dimension >= 0); TF_RET_CHECK(all_gather_dimension < operand_shape.rank()); TF_RET_CHECK(shard_count > 0); auto shape = operand_shape; diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc new file mode 100644 index 00000000000..bee2e04fabf --- /dev/null +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -0,0 +1,1478 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/sharding_propagation.h" + +#include <algorithm> +#include <list> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_split.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +using ComputationMap = + absl::flat_hash_map<const HloComputation*, HloInstruction*>; + +// Returns true iff the specified hlo or sharding has a spatially partitioned +// sharding (tiled or replicated) what can be propagated by sharding +// propagation. +bool IsSpatiallyPartitioned(const HloSharding& sharding) { + if (sharding.IsTuple()) { + return absl::c_any_of(sharding.tuple_elements(), IsSpatiallyPartitioned); + } else { + return !sharding.IsTileMaximal() || sharding.IsReplicated(); + } +} +bool IsSpatiallyPartitioned(const HloInstruction* hlo) { + return hlo->has_sharding() && IsSpatiallyPartitioned(hlo->sharding()); +} + +// Returns true if the lhs sharding is preferable over the rhs sharding. +// The most specific sharding is tile maximal followed by single device tile +// maximal and finally replicated. This order aims to primarily reduce memory +// usage and secondly reduce total compute. +// Note: This does NOT provide a total ordering as we can have 2 different +// sharding with same preference level. +bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) { + CHECK_EQ(lhs.IsTuple(), rhs.IsTuple()); + if (lhs.IsTuple()) { + // For tuples we consider lhs to have a better sharding if none of the + // elements are worse and at least one element is better then in rhs + // sharding. + const auto& lhs_shardings = lhs.tuple_elements(); + const auto& rhs_shardings = rhs.tuple_elements(); + CHECK_EQ(lhs_shardings.size(), rhs_shardings.size()); + bool is_better = false; + for (int64 i = 0; i < lhs_shardings.size(); ++i) { + if (IsShardingMoreSpecific(rhs_shardings[i], lhs_shardings[i])) { + return false; + } + if (IsShardingMoreSpecific(lhs_shardings[i], rhs_shardings[i])) { + is_better = true; + } + } + return is_better; + } + if (!rhs.IsTileMaximal()) { + // If we already have a non-tile-maximal sharding then we can't improve + // that. + return false; + } else if (!rhs.IsReplicated()) { + // If we are not replicated then only tiled (not tile maximal) shardings + // can improve us. + return !lhs.IsTileMaximal(); + } else { + // If we are replicated then any non-replicated sharding can improve us. + return !lhs.IsReplicated(); + } +} + +// Returns a sharding where each tuple element is chosen as the more specific +// one of the corresponding elements in a and b. Requires a an b to have the +// same tuple nesting. +HloSharding MergeForMoreSpecificSharding(const HloSharding& a, + const HloSharding& b) { + if (a.IsTuple()) { + HloSharding result = a; + CHECK(b.IsTuple()); + CHECK_EQ(a.tuple_elements().size(), b.tuple_elements().size()); + for (int64 i = 0; i < result.tuple_elements().size(); ++i) { + result.tuple_elements()[i] = MergeForMoreSpecificSharding( + a.tuple_elements()[i], b.tuple_elements()[i]); + } + return result; + } + return IsShardingMoreSpecific(a, b) ? a : b; +} + +// Updates the sharding of the specified instruction with the specified sharding +// if it is better than the current one and returns true if a new sharding have +// been applied. +bool MaybeImproveInstructionSharding(const HloSharding& sharding, + HloInstruction* instruction) { + // We don't want to propagate tile maximal shardings. + if (!IsSpatiallyPartitioned(sharding)) { + return false; + } + // Any sharding is better then no sharding. + if (!instruction->has_sharding()) { + instruction->set_sharding(sharding); + return true; + } + if (IsShardingMoreSpecific(sharding, instruction->sharding())) { + instruction->set_sharding(sharding); + return true; + } + return false; +} + +// Sets the sharding for every element within a tuple to replicated (default +// sharding). This is necessary because there is no way to represent a tuple +// sharding when only some of the elements are sharded. +void SetDefaultTupleSharding(HloInstruction* instruction) { + instruction->set_sharding( + HloSharding::SingleTuple(instruction->shape(), HloSharding::Replicate())); +} + +// We consider a convolution kernel to be small iff it is smaller along all +// spatial dimensions then the output of the convolution. The rational is that +// we can either shard the kernel or the output and we want to shard the larger +// one for better efficiency. +bool IsConvolutionKernelSmall(const HloInstruction* instruction) { + CHECK_EQ(instruction->opcode(), HloOpcode::kConvolution); + const HloInstruction* rhs = instruction->operand(1); + const auto& dnums = instruction->convolution_dimension_numbers(); + for (int64 i = 0; i < dnums.input_spatial_dimensions().size(); ++i) { + int64 kernel_dim = + rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)); + int64 output_dim = + instruction->shape().dimensions(dnums.output_spatial_dimensions(i)); + if (kernel_dim >= output_dim) { + return false; + } + } + return true; +} + +// Return the operand which is the most suitable for determining the sharding +// for the specified instruction or nullptr if there isn't any suitable operand. +const HloInstruction* PickRepresentativeOperand( + const HloInstruction* instruction) { + switch (instruction->opcode()) { + case HloOpcode::kMap: + case HloOpcode::kPad: + case HloOpcode::kPower: + case HloOpcode::kReverse: + case HloOpcode::kSlice: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + // For these opcodes the output sharding has to be determined by the + // sharding of the first operand but we can only determine sharding based + // on it if it already has a sharding. + if (instruction->operand(0)->has_sharding()) { + return instruction->operand(0); + } + return nullptr; + case HloOpcode::kAbs: + case HloOpcode::kAdd: + case HloOpcode::kAnd: + case HloOpcode::kAtan2: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCeil: + case HloOpcode::kClamp: + case HloOpcode::kClz: + case HloOpcode::kCompare: + case HloOpcode::kComplex: + case HloOpcode::kConcatenate: + case HloOpcode::kConvert: + case HloOpcode::kCopy: + case HloOpcode::kCos: + case HloOpcode::kAllGather: + case HloOpcode::kAllReduce: + case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: + case HloOpcode::kDivide: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFloor: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNegate: + case HloOpcode::kNot: + case HloOpcode::kOr: + case HloOpcode::kPopulationCount: + case HloOpcode::kReal: + case HloOpcode::kReducePrecision: + case HloOpcode::kRemainder: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kRsqrt: + case HloOpcode::kSelect: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kSort: + case HloOpcode::kSqrt: + case HloOpcode::kCbrt: + case HloOpcode::kSubtract: + case HloOpcode::kTanh: + case HloOpcode::kTupleSelect: + case HloOpcode::kWhile: + case HloOpcode::kXor: { + // For these opcodes the output sharding can be determined by any operand + // so we find the operand with the most specific sharding. + const HloInstruction* best_operand = nullptr; + for (const HloInstruction* operand : instruction->operands()) { + if (operand->has_sharding() && + (best_operand == nullptr || + IsShardingMoreSpecific(operand->sharding(), + best_operand->sharding()))) { + best_operand = operand; + } + } + return best_operand; + } + + // There is no suitable operand for the rest of the opcodes. + case HloOpcode::kAddDependency: + case HloOpcode::kAfterAll: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBitcast: + case HloOpcode::kBroadcast: + case HloOpcode::kCall: + case HloOpcode::kCholesky: + case HloOpcode::kCollectivePermuteDone: + case HloOpcode::kCollectivePermuteStart: + case HloOpcode::kConditional: + case HloOpcode::kConstant: + case HloOpcode::kConvolution: + case HloOpcode::kCopyDone: + case HloOpcode::kCopyStart: + case HloOpcode::kCustomCall: + case HloOpcode::kDomain: + case HloOpcode::kDot: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kFft: + case HloOpcode::kFusion: + case HloOpcode::kGather: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kIota: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kPartitionId: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kReplicaId: + case HloOpcode::kReshape: + case HloOpcode::kRng: + case HloOpcode::kRngGetAndUpdateState: + case HloOpcode::kRngBitGenerator: + case HloOpcode::kScatter: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kTrace: + case HloOpcode::kTranspose: + case HloOpcode::kTriangularSolve: + case HloOpcode::kTuple: + case HloOpcode::kGetDimensionSize: + case HloOpcode::kSetDimensionSize: + return nullptr; + } +} + +bool SupportSpatialPartitioning(const HloInstruction* instruction, + const ComputationMap& computation_map, + bool is_spmd) { + if (instruction->parent()->root_instruction() == instruction && + computation_map.find(instruction->parent()) == computation_map.end()) { + // We don't support sharding the root instruction of a computation yet, + // unless the computation is a while body. + return false; + } + + if (instruction->IsElementwise() && + (instruction->opcode() != HloOpcode::kRng || is_spmd)) { + return true; + } + switch (instruction->opcode()) { + case HloOpcode::kBroadcast: + case HloOpcode::kConcatenate: + case HloOpcode::kConditional: + case HloOpcode::kConstant: + case HloOpcode::kConvolution: + case HloOpcode::kDot: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kGather: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kIota: + case HloOpcode::kPad: + case HloOpcode::kReduceWindow: + case HloOpcode::kReshape: + case HloOpcode::kScatter: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kSlice: + case HloOpcode::kSort: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + case HloOpcode::kReduce: + return true; + case HloOpcode::kAllReduce: + // Only if channel_id is not specified. + return instruction->channel_id() == absl::nullopt; + case HloOpcode::kParameter: + return computation_map.find(instruction->parent()) != + computation_map.end(); + case HloOpcode::kReverse: + return is_spmd; + default: + return false; + } +} + +// Tries to update the sharding of the specified instruction based on its +// operands and returns true if the sharding of the instruction have been +// changed and false otherwise. +bool InferShardingFromOperands(HloInstruction* instruction, + const ComputationMap& computation_map, + bool is_spmd, bool aggressive_prop) { + if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) { + // If an array shaped HLO doesn't support spatial partitioning but at least + // one of its operand is replicated then we make the HLO replicated as well. + if (instruction->shape().IsTuple() || instruction->operand_count() == 0 || + instruction == instruction->parent()->root_instruction() || + instruction->HasSideEffect()) { + return false; + } + if (absl::c_any_of(instruction->operands(), [](const HloInstruction* op) { + return op->has_sharding() && op->sharding().IsReplicated(); + })) { + return MaybeImproveInstructionSharding(HloSharding::Replicate(), + instruction); + } + return false; + } + + switch (instruction->opcode()) { + case HloOpcode::kGetTupleElement: { + const HloInstruction* operand = instruction->operand(0); + if (!IsSpatiallyPartitioned(operand)) { + return false; + } + HloSharding new_sharding = operand->sharding().GetSubSharding( + operand->shape(), {instruction->tuple_index()}); + return MaybeImproveInstructionSharding(new_sharding, instruction); + } + case HloOpcode::kTuple: { + if (absl::c_none_of(instruction->operands(), + [](const HloInstruction* hlo) { + return IsSpatiallyPartitioned(hlo); + })) { + // None of the operands have a spatially partitioned sharding. + return false; + } + bool changed = false; + if (!instruction->has_sharding()) { + // Set the sharding for all elements in the tuple because it isn't + // possible to set a partial sharding. + SetDefaultTupleSharding(instruction); + changed = true; + } + // Go through each operand and if the operand has a sharding that is + // better than the current sharding for that tuple element then update + // it. + const Shape& shape = instruction->shape(); + std::vector<HloSharding> sub_shardings = + instruction->sharding().tuple_elements(); + int64 sub_sharding_index = 0; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const HloInstruction* operand = instruction->operand(i); + if (operand->has_sharding()) { + if (operand->shape().IsTuple()) { + for (int64 i = 0, e = ShapeUtil::GetLeafCount(operand->shape()); + i < e; ++i) { + if (IsShardingMoreSpecific( + operand->sharding().tuple_elements()[i], + sub_shardings[sub_sharding_index + i])) { + sub_shardings[sub_sharding_index + i] = + operand->sharding().tuple_elements()[i]; + } + } + } else { + if (IsShardingMoreSpecific(operand->sharding(), + sub_shardings[sub_sharding_index])) { + sub_shardings[sub_sharding_index] = operand->sharding(); + } + } + } + sub_sharding_index += ShapeUtil::GetLeafCount(operand->shape()); + } + + HloSharding new_sharding = HloSharding::Tuple(shape, sub_shardings); + if (new_sharding != instruction->sharding()) { + instruction->set_sharding(new_sharding); + return true; + } + return changed; + } + case HloOpcode::kReduce: { + // Reduce could have a tuple shape, where the first half of operands are + // the arrays to reduce, and the second half of operands are the init + // values. + bool changed = false; + for (int64 operand_id = 0; operand_id < instruction->operand_count() / 2; + ++operand_id) { + const HloInstruction* operand = instruction->operand(operand_id); + if (!IsSpatiallyPartitioned(operand)) { + continue; + } + auto get_maybe_tuple_sharding = [&](const HloSharding& sharding) { + if (instruction->operand_count() == 2) { + return sharding; + } + std::vector<HloSharding> tuple(instruction->operand_count() / 2, + sharding); + return HloSharding::Tuple(instruction->shape(), tuple); + }; + if (operand->sharding().IsReplicated()) { + changed |= MaybeImproveInstructionSharding( + get_maybe_tuple_sharding(HloSharding::Replicate()), instruction); + continue; + } + if (absl::c_any_of(instruction->dimensions(), [operand](int64 dim) { + return operand->sharding().tile_assignment().dim(dim) > 1; + })) { + // We are reducing along one of the sharded dimensions. We don't + // support tiled sharding in this case. + changed |= MaybeImproveInstructionSharding( + get_maybe_tuple_sharding(HloSharding::Replicate()), instruction); + } else { + // We are reducing along some of the non-sharded dimensions. The + // result sharding should be the same as the operand sharding with the + // reduction dimensions removed as they are removed from the result + // shape. + std::vector<int64> target_tile_assignment_dimensions; + const auto& dimensions = instruction->dimensions(); + for (int64 i = 0; i < operand->shape().rank(); ++i) { + if (absl::c_find(dimensions, i) == dimensions.end()) { + target_tile_assignment_dimensions.push_back( + operand->sharding().tile_assignment().dim(i)); + } + } + Array<int64> new_tile_assignment = + operand->sharding().tile_assignment(); + new_tile_assignment.Reshape(target_tile_assignment_dimensions); + // Use the same sharding for all tuple elements, because they are part + // of the same reduce instruction. + HloSharding new_sharding = + get_maybe_tuple_sharding(HloSharding::Tile(new_tile_assignment)); + changed |= MaybeImproveInstructionSharding(new_sharding, instruction); + } + } + return changed; + } + case HloOpcode::kBroadcast: { + const HloInstruction* op = instruction->operand(0); + if (!IsSpatiallyPartitioned(op) || op->sharding().IsReplicated()) { + return false; + } + // Heuristic: If an operand is more than 8 times fewer elements than its + // output, do not propagate sharding. + if (ShapeUtil::ElementsIn(instruction->shape()) > + 8 * ShapeUtil::ElementsIn(op->shape())) { + return false; + } + // The output will be tiled along the broadcasted dimension the same way + // as the input for the broadcast while the other dimensions are kept + // non-tiled. + std::vector<int64> target_tile_assignment_dimensions; + const auto& dimensions = instruction->dimensions(); + for (int64 i = 0; i < instruction->shape().rank(); ++i) { + auto it = absl::c_find(dimensions, i); + if (it == dimensions.end()) { + target_tile_assignment_dimensions.push_back(1); + } else { + const int64 source_dim = std::distance(dimensions.begin(), it); + target_tile_assignment_dimensions.push_back( + op->sharding().tile_assignment().dim(source_dim)); + } + } + Array<int64> new_tile_assignment = op->sharding().tile_assignment(); + new_tile_assignment.Reshape(target_tile_assignment_dimensions); + HloSharding new_sharding = HloSharding::Tile(new_tile_assignment); + return MaybeImproveInstructionSharding(new_sharding, instruction); + } + case HloOpcode::kConvolution: { + const auto& dnums = instruction->convolution_dimension_numbers(); + const HloInstruction* lhs = instruction->operand(0); + const HloInstruction* rhs = instruction->operand(1); + auto get_tiled_sharding_based_on_lhs = [&] { + CHECK(!lhs->sharding().IsTileMaximal()); + std::vector<int64> output_to_lhs_indices(instruction->shape().rank()); + output_to_lhs_indices[dnums.output_batch_dimension()] = + dnums.input_batch_dimension(); + output_to_lhs_indices[dnums.output_feature_dimension()] = + dnums.input_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + output_to_lhs_indices[dnums.output_spatial_dimensions(i)] = + dnums.input_spatial_dimensions(i); + } + return hlo_sharding_util::TransposeSharding(lhs->sharding(), + output_to_lhs_indices); + }; + auto get_tiled_sharding_based_on_rhs = [&] { + CHECK(!rhs->sharding().IsTileMaximal()); + std::vector<int64> output_to_rhs_indices(instruction->shape().rank()); + output_to_rhs_indices[dnums.output_batch_dimension()] = + dnums.kernel_input_feature_dimension(); + output_to_rhs_indices[dnums.output_feature_dimension()] = + dnums.kernel_output_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + output_to_rhs_indices[dnums.output_spatial_dimensions(i)] = + dnums.kernel_spatial_dimensions(i); + } + return hlo_sharding_util::TransposeSharding(rhs->sharding(), + output_to_rhs_indices); + }; + if (auto dot_dims = + dot_as_convolution_util::ParseDotGeneralFromConvolution( + instruction)) { + // lhs_or_rhs: lhs is 0 and rhs is 1. + auto partitioned_only_along = + [&](const HloSharding& sharding, + std::vector<dot_as_convolution_util:: + DotGeneralAsConvolutionDimsInfo::DimNums>& dims, + int64 lhs_or_rhs) { + if (sharding.IsTileMaximal()) { + return false; + } + int64 partition_count = 1; + for (const auto& dim : dims) { + if (lhs_or_rhs == 0) { + partition_count *= sharding.tile_assignment().dim(dim.lhs); + } else { + CHECK_EQ(lhs_or_rhs, 1); + partition_count *= sharding.tile_assignment().dim(dim.rhs); + } + } + return partition_count == + sharding.tile_assignment().num_elements(); + }; + // If LHS/RHS is partitioned only along the batch dimensions, propagate + // the sharding to the output, since batch dimensions are the easiest to + // partition. + if (IsSpatiallyPartitioned(lhs) && + partitioned_only_along(lhs->sharding(), dot_dims->batch_dims, 0)) { + return MaybeImproveInstructionSharding( + get_tiled_sharding_based_on_lhs(), instruction); + } + if (IsSpatiallyPartitioned(rhs) && + partitioned_only_along(rhs->sharding(), dot_dims->batch_dims, 1)) { + return MaybeImproveInstructionSharding( + get_tiled_sharding_based_on_rhs(), instruction); + } + if (aggressive_prop) { + // If LHS/RHS is partitioned only along the non-contracting + // dimensions, propagate the sharding to the output. + const bool can_propagate_from_lhs = + IsSpatiallyPartitioned(lhs) && + partitioned_only_along(lhs->sharding(), + dot_dims->lhs_non_contracting_dims, 0); + const bool can_propagate_from_rhs = + IsSpatiallyPartitioned(rhs) && + partitioned_only_along(rhs->sharding(), + dot_dims->rhs_non_contracting_dims, 1); + // If we can propagate from both operands, choose the larger one which + // should help us reduce communications. + if (can_propagate_from_lhs && can_propagate_from_rhs) { + if (Product(lhs->shape().dimensions()) >= + Product(rhs->shape().dimensions())) { + return MaybeImproveInstructionSharding( + get_tiled_sharding_based_on_lhs(), instruction); + } else { + return MaybeImproveInstructionSharding( + get_tiled_sharding_based_on_rhs(), instruction); + } + } + if (can_propagate_from_lhs) { + return MaybeImproveInstructionSharding( + get_tiled_sharding_based_on_lhs(), instruction); + } + if (can_propagate_from_rhs) { + return MaybeImproveInstructionSharding( + get_tiled_sharding_based_on_rhs(), instruction); + } + } + } + + if (!IsSpatiallyPartitioned(lhs)) { + return false; + } + if (lhs->sharding().IsReplicated()) { + return MaybeImproveInstructionSharding(HloSharding::Replicate(), + instruction); + } + + if (IsConvolutionKernelSmall(instruction)) { + // If the kernel is small compared to the input then we can generate an + // output what is sharded the same way as the input. + const auto& tile_assignment = lhs->sharding().tile_assignment(); + if (tile_assignment.dim(dnums.input_feature_dimension()) > 1) { + return false; + } + return MaybeImproveInstructionSharding( + get_tiled_sharding_based_on_lhs(), instruction); + } + // If the kernel is large (e.g backward convolution) then we only support + // replicated output. + return MaybeImproveInstructionSharding(HloSharding::Replicate(), + instruction); + } + case HloOpcode::kTranspose: { + const HloInstruction* input = instruction->operand(0); + if (!IsSpatiallyPartitioned(input)) { + return false; + } + HloSharding sharding = hlo_sharding_util::TransposeSharding( + input->sharding(), instruction->dimensions()); + return MaybeImproveInstructionSharding(sharding, instruction); + } + case HloOpcode::kReduceWindow: { + const HloInstruction* lhs = instruction->operand(0); + if (!IsSpatiallyPartitioned(lhs)) { + return false; + } + + auto has_dilation = [](const WindowDimension& dimensions) { + return dimensions.base_dilation() > 1 || + dimensions.window_dilation() > 1; + }; + if (absl::c_any_of(instruction->window().dimensions(), has_dilation)) { + VLOG(2) << "Not applying sharding to reduce window because dilatation " + "isn't supported yet: " + << instruction->ToString(); + return false; + } + return MaybeImproveInstructionSharding(lhs->sharding(), instruction); + } + case HloOpcode::kSelectAndScatter: { + // Shard according to first operand, as output keeps the same shape. + const HloInstruction* lhs = instruction->operand(0); + if (!IsSpatiallyPartitioned(lhs)) { + return false; + } + + auto has_base_dilation = [](const WindowDimension& dimensions) { + return dimensions.base_dilation() > 1; + }; + if (absl::c_any_of(instruction->window().dimensions(), + has_base_dilation)) { + VLOG(2) << "Not applying sharding to select-and-scatter because " + "base dilation isn't supported yet: " + << instruction->ToString(); + return false; + } + return MaybeImproveInstructionSharding(lhs->sharding(), instruction); + } + case HloOpcode::kReshape: { + if (!IsSpatiallyPartitioned(instruction->operand(0))) { + return false; + } + absl::optional<HloSharding> new_sharding = + hlo_sharding_util::ReshapeSharding( + instruction->operand(0)->shape(), instruction->shape(), + instruction->operand(0)->sharding()); + if (new_sharding.has_value()) { + return MaybeImproveInstructionSharding(new_sharding.value(), + instruction); + } + return false; + } + case HloOpcode::kDot: { + auto& dot_dim_numbs = instruction->dot_dimension_numbers(); + // Batch dimensions are the same for lhs and rhs on dot operations. + int64 num_batch_dims = dot_dim_numbs.lhs_batch_dimensions_size(); + std::vector<int64> contracting_dims(2); + contracting_dims[0] = dot_dim_numbs.lhs_contracting_dimensions(0); + contracting_dims[1] = dot_dim_numbs.rhs_contracting_dimensions(0); + std::vector<const HloSharding*> ops_sharding(2, nullptr); + for (int64 op_num = 0; op_num < 2; ++op_num) { + const HloInstruction* op = instruction->operand(op_num); + if (IsSpatiallyPartitioned(op)) { + ops_sharding[op_num] = &op->sharding(); + } + } + if (ops_sharding[0] == nullptr && ops_sharding[1] == nullptr) { + return false; + } + + // Select representative operand. + int64 representative_op = -1; + if (ops_sharding[0] == nullptr) { + representative_op = 1; + } else if (ops_sharding[1] == nullptr) { + representative_op = 0; + } else if (ops_sharding[0]->IsReplicated() && + ops_sharding[1]->IsReplicated()) { + // Both replicated -> replicate + return MaybeImproveInstructionSharding(HloSharding::Replicate(), + instruction); + } else if (!ops_sharding[0]->IsReplicated() && + !ops_sharding[1]->IsReplicated()) { + // Both tile sharded. The dot spatial partitioning implementation + // replicates the operand corresponding to the non-tiled dimension: + // dot(lhs, rhs), sharding={devices=[1, ..., n, 1]} replicates rhs + // dot(lhs, rhs), sharding={devices=[1, ..., 1, n]} replicates lhs + // so set sharding in order to replicate the smaller of lhs and rhs + representative_op = + ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) < + ShapeUtil::ByteSizeOf(instruction->operand(1)->shape()) + ? 1 + : 0; + } else { + // One is replicated and the other is tiled - pick the tiled one. + representative_op = ops_sharding[0]->IsReplicated() ? 1 : 0; + } + + if (ops_sharding[representative_op]->IsReplicated()) { + return MaybeImproveInstructionSharding(HloSharding::Replicate(), + instruction); + } else { + // Tile-shard instruction according to representative op. + auto sharding = *ops_sharding[representative_op]; + if (instruction->shape().dimensions_size() != + sharding.tile_assignment().num_dimensions()) { + // It is necessarily the case of a matrix x vector, with + // representative_op being the matrix, because the vector op has the + // same shape as instruction. + CHECK_EQ(sharding.tile_assignment().num_dimensions(), + instruction->shape().dimensions_size() + 1); + // Reshape sharding so that last dimension is 1, and then remove + // last dimension. + std::vector<int64> non_batch_dims( + sharding.tile_assignment().num_dimensions() - num_batch_dims); + absl::c_iota(non_batch_dims, num_batch_dims); + sharding = hlo_sharding_util::ReshapeToTileDimension( + sharding, num_batch_dims, non_batch_dims); + auto tile_assignment = sharding.tile_assignment(); + auto dimensions = tile_assignment.dimensions(); + CHECK_EQ(dimensions.back(), 1); + dimensions.pop_back(); + tile_assignment.Reshape(dimensions); + sharding = HloSharding::Tile(tile_assignment); + } + return MaybeImproveInstructionSharding(sharding, instruction); + } + } + case HloOpcode::kParameter: { + auto parent_it = computation_map.find(instruction->parent()); + if (parent_it == computation_map.end()) { + return false; + } + const HloInstruction* parent = parent_it->second; + switch (parent->opcode()) { + case HloOpcode::kConditional: { + for (int64 i = 1; i < parent->operand_count(); ++i) { + if (parent->called_computations()[i - 1] == instruction->parent()) { + if (parent->operand(i)->has_sharding()) { + return MaybeImproveInstructionSharding( + parent->operand(i)->sharding(), instruction); + } + return false; + } + } + return false; + } + default: + return false; + } + } + case HloOpcode::kSort: { + const HloInstruction* operand = PickRepresentativeOperand(instruction); + if (!operand || !IsSpatiallyPartitioned(operand)) { + return false; + } + + if (!operand->sharding().IsTileMaximal() && + operand->sharding().tile_assignment().dim( + instruction->dimensions(0)) != 1) { + // Doesn't support sharding the sorting dimension. + return false; + } + + if (instruction->shape().IsTuple()) { + return MaybeImproveInstructionSharding( + HloSharding::SingleTuple(instruction->shape(), operand->sharding()), + instruction); + } else { + return MaybeImproveInstructionSharding(operand->sharding(), + instruction); + } + } + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: { + auto propagate_slicing = [instruction]() { + const HloInstruction* operand = + instruction->opcode() == HloOpcode::kDynamicSlice + ? instruction->operand(0) + : instruction->operand(1); + if (!IsSpatiallyPartitioned(operand)) { + return false; + } + + if (operand->sharding().IsReplicated()) { + return MaybeImproveInstructionSharding(HloSharding::Replicate(), + instruction); + } + + const auto& tile_assignment = operand->sharding().tile_assignment(); + for (int64 i = 0; i < instruction->shape().rank(); ++i) { + if (tile_assignment.dim(i) > 1 && + instruction->shape().dimensions(i) != + operand->shape().dimensions(i)) { + return false; + } + } + return MaybeImproveInstructionSharding(operand->sharding(), + instruction); + }; + auto propagate_base = [instruction]() { + if (instruction->opcode() != HloOpcode::kDynamicUpdateSlice) { + return false; + } + if (!IsSpatiallyPartitioned(instruction->operand(0))) { + return false; + } + return MaybeImproveInstructionSharding( + instruction->operand(0)->sharding(), instruction); + }; + return propagate_slicing() || propagate_base(); + } + case HloOpcode::kGather: { + if (!IsSpatiallyPartitioned(instruction->operand(1))) { + return false; + } + HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding( + instruction->operand(1)->sharding(), instruction); + return MaybeImproveInstructionSharding(new_sharding, instruction); + } + case HloOpcode::kScatter: { + if (!IsSpatiallyPartitioned(instruction->operand(1)) && + !IsSpatiallyPartitioned(instruction->operand(2))) { + return false; + } + return MaybeImproveInstructionSharding(HloSharding::Replicate(), + instruction); + } + case HloOpcode::kWhile: { + if (!instruction->operand(0)->has_sharding()) { + return false; + } + auto sharding = instruction->operand(0)->sharding(); + if (instruction->has_sharding()) { + sharding = + MergeForMoreSpecificSharding(sharding, instruction->sharding()); + } + return MaybeImproveInstructionSharding(sharding, instruction); + } + default: { + const HloInstruction* operand = PickRepresentativeOperand(instruction); + if (!operand || !IsSpatiallyPartitioned(operand)) { + return false; + } + return MaybeImproveInstructionSharding(operand->sharding(), instruction); + } + } + return false; +} + +// Return the sharding that should be propagated from user to instruction. +absl::optional<HloSharding> GetShardingFromUser( + const HloInstruction& instruction, const HloInstruction& user, + bool aggressive_prop, bool is_spmd) { + if (!IsSpatiallyPartitioned(&user)) { + return absl::nullopt; + } + switch (user.opcode()) { + case HloOpcode::kBroadcast: { + if (user.sharding().IsReplicated()) { + return user.sharding(); + } + // Only support when none of the partitioned dimensions in the broadcast + // output belong to new dimensions. + for (int64 i = 0; i < user.shape().rank(); ++i) { + if (user.sharding().tile_assignment().dim(i) > 1 && + absl::c_count(user.dimensions(), i) == 0) { + return absl::nullopt; + } + } + + // The instruction (operand of broadcast) will be tiled the same way + // as the output. + std::vector<int64> target_tile_assignment_dimensions; + for (int64 output_dim : user.dimensions()) { + target_tile_assignment_dimensions.push_back( + user.sharding().tile_assignment().dim(output_dim)); + } + Array<int64> new_tile_assignment = user.sharding().tile_assignment(); + new_tile_assignment.Reshape(target_tile_assignment_dimensions); + return HloSharding::Tile(new_tile_assignment); + } + case HloOpcode::kConcatenate: { + if (user.sharding().IsReplicated()) { + return user.sharding(); + } + + const int64 cdim = user.concatenate_dimension(); + const Array<int64>& tile_assignment = user.sharding().tile_assignment(); + if (tile_assignment.dim(cdim) == 1) { + // If we are concatenating along a non-sharded dimension then the + // operands should have the same sharding as the result. + return user.sharding(); + } + + if (is_spmd) { + // SPMD doesn't support tiling with part of the devices. Return the same + // sharding. + return user.sharding(); + } + + // If we are concatenating along a sharded dimension then we want the + // operands to be distributed among the devices their data is used. + int64 start_offset = 0; + for (HloInstruction* op : user.operands()) { + if (op == &instruction) { + break; + } + start_offset += op->shape().dimensions(cdim); + } + const int64 tile_shape = CeilOfRatio(user.shape().dimensions(cdim), + tile_assignment.dimensions()[cdim]); + std::vector<int64> start_indices(tile_assignment.num_dimensions()); + std::vector<int64> end_indices = tile_assignment.dimensions(); + start_indices[cdim] = start_offset / tile_shape; + end_indices[cdim] = CeilOfRatio( + start_offset + instruction.shape().dimensions(cdim), tile_shape); + auto new_tile_assignment = + tile_assignment.Slice(start_indices, end_indices); + if (new_tile_assignment.num_elements() == 1) { + return HloSharding::AssignDevice(*new_tile_assignment.begin()); + } + return HloSharding::Tile(new_tile_assignment); + } + case HloOpcode::kConvolution: { + if (auto dot_dims = + dot_as_convolution_util::ParseDotGeneralFromConvolution(&user)) { + const auto& dnums = user.convolution_dimension_numbers(); + auto partitioned_only_along = + [&](const HloSharding& sharding, + std::vector<dot_as_convolution_util:: + DotGeneralAsConvolutionDimsInfo::DimNums>& + dims) { + if (sharding.IsTileMaximal()) { + return false; + } + int64 partition_count = 1; + for (const auto& dim : dims) { + partition_count *= sharding.tile_assignment().dim(dim.output); + } + return partition_count == + sharding.tile_assignment().num_elements(); + }; + // If output is partitioned only along the batch dimensions, or only + // along the non-contracting dimensions, propagate the sharding to the + // operand. + if (&instruction == user.operand(0) && + (partitioned_only_along(user.sharding(), dot_dims->batch_dims) || + partitioned_only_along(user.sharding(), + dot_dims->lhs_non_contracting_dims))) { + std::vector<int64> lhs_to_output_indices(user.shape().rank()); + lhs_to_output_indices[dnums.input_batch_dimension()] = + dnums.output_batch_dimension(); + lhs_to_output_indices[dnums.input_feature_dimension()] = + dnums.output_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + lhs_to_output_indices[dnums.input_spatial_dimensions(i)] = + dnums.output_spatial_dimensions(i); + } + return hlo_sharding_util::TransposeSharding(user.sharding(), + lhs_to_output_indices); + } + if (&instruction == user.operand(1) && + (partitioned_only_along(user.sharding(), dot_dims->batch_dims) || + partitioned_only_along(user.sharding(), + dot_dims->rhs_non_contracting_dims))) { + std::vector<int64> rhs_to_output_indices(user.shape().rank()); + rhs_to_output_indices[dnums.kernel_input_feature_dimension()] = + dnums.output_batch_dimension(); + rhs_to_output_indices[dnums.kernel_output_feature_dimension()] = + dnums.output_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + rhs_to_output_indices[dnums.kernel_spatial_dimensions(i)] = + dnums.output_spatial_dimensions(i); + } + return hlo_sharding_util::TransposeSharding(user.sharding(), + rhs_to_output_indices); + } + } + return absl::nullopt; + } + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: { + if (user.sharding().IsReplicated()) { + return user.sharding(); + } + if (user.opcode() == HloOpcode::kDynamicUpdateSlice && + &instruction == user.operand(0)) { + return user.sharding(); + } + const HloInstruction* operand = user.opcode() == HloOpcode::kDynamicSlice + ? user.operand(0) + : user.operand(1); + if (&instruction != operand) { + return absl::nullopt; + } + + const auto& tile_assignment = user.sharding().tile_assignment(); + for (int64 i = 0; i < user.shape().rank(); ++i) { + if (tile_assignment.dim(i) > 1 && + user.shape().dimensions(i) != operand->shape().dimensions(i)) { + return absl::nullopt; + } + } + return user.sharding(); + } + case HloOpcode::kReduceWindow: { + if (&instruction != user.operand(0)) { + return absl::nullopt; + } + return user.sharding(); + } + case HloOpcode::kReshape: { + return hlo_sharding_util::ReshapeSharding( + user.shape(), instruction.shape(), user.sharding()); + } + case HloOpcode::kTranspose: { + // Calculate the dimension numbers for reversing the current transpose + // and then use TransposeSharding to convert the output sharding to an + // input sharding. + std::vector<int64> reverse_dimensions(user.dimensions().size()); + for (int64 i = 0; i < user.dimensions().size(); ++i) { + reverse_dimensions[user.dimensions(i)] = i; + } + return hlo_sharding_util::TransposeSharding(user.sharding(), + reverse_dimensions); + } + case HloOpcode::kTuple: { + return user.sharding().GetSubSharding(user.shape(), + {user.operand_index(&instruction)}); + } + case HloOpcode::kGetTupleElement: { + HloSharding new_sharding = + instruction.has_sharding() + ? instruction.sharding() + : HloSharding::SingleTuple(instruction.shape(), + HloSharding::Replicate()); + int64 sharding_index = 0; + for (int64 i = 0; i < instruction.shape().tuple_shapes_size(); ++i) { + if (i == user.tuple_index()) { + break; + } + if (instruction.shape().tuple_shapes(i).IsArray()) { + sharding_index += 1; + } else { + sharding_index += + instruction.shape().tuple_shapes(i).tuple_shapes_size(); + } + } + if (user.shape().IsArray()) { + new_sharding.tuple_elements()[sharding_index] = user.sharding(); + } + for (int64 i = 0; i < user.sharding().tuple_elements().size(); ++i) { + new_sharding.tuple_elements()[sharding_index + i] = + user.sharding().tuple_elements()[i]; + } + return new_sharding; + } + case HloOpcode::kDot: { + if (user.sharding().IsReplicated()) { + return user.sharding(); + } + auto& dim_numbers = user.dot_dimension_numbers(); + int64 op_idx = user.operand_index(&instruction); + // Batch dimensions are the same on lhs and rhs for dot operations. + int64 num_batch_dims = dim_numbers.lhs_batch_dimensions_size(); + int64 num_spatial_dims = + instruction.shape().dimensions_size() - num_batch_dims; + if (num_spatial_dims == 1) { + // This is the vector of a matrix x vector operation -> replicate, + // since tiling on the vector would necessarily be on the contracting + // dimension, which we don't support. + CHECK_EQ(op_idx, 1); + return HloSharding::Replicate(); + } + // Instruction is necessarily a matrix because it is one of the operands + // of a matrix x matrix operation. + CHECK_EQ(num_spatial_dims, 2); + // Propagate tile sharding to the bigger operand, and replicate the other. + auto other_op = user.operand(op_idx ^ 1); + if (ShapeUtil::ByteSizeOf(instruction.shape()) > + ShapeUtil::ByteSizeOf(other_op->shape())) { + return user.sharding(); + } else { + return HloSharding::Replicate(); + } + } + case HloOpcode::kReduce: { + if (instruction.shape().rank() == 0) { + return absl::nullopt; + } + auto user_sharding = + user.shape().IsTuple() + ? user.sharding().GetSubSharding( + user.shape(), {user.operand_index(&instruction)}) + : user.sharding(); + if (user_sharding.IsTileMaximal()) { + return user_sharding; + } + std::vector<int64> target_tile_assignment_dimensions( + instruction.shape().rank()); + const auto& dimensions = user.dimensions(); + int64 next_output_dim = 0; + for (int64 i = 0; i < instruction.shape().rank(); ++i) { + if (absl::c_find(dimensions, i) == dimensions.end()) { + target_tile_assignment_dimensions[i] = + user_sharding.tile_assignment().dim(next_output_dim++); + } else { + target_tile_assignment_dimensions[i] = 1; + } + } + auto tile_assignment = user_sharding.tile_assignment(); + tile_assignment.Reshape(target_tile_assignment_dimensions); + return HloSharding::Tile(tile_assignment); + } + case HloOpcode::kSort: { + if (user.sharding().IsTuple()) { + return user.sharding().GetSubSharding( + user.shape(), {user.operand_index(&instruction)}); + } else { + return user.sharding(); + } + } + default: { + // If the user output shape is compatible with the current instruction + // shape excluding element type and the current instruction is supported + // by spatial partitioning, then the user sharding can be used for + // propagation to the current instruction. + if (ShapeUtil::CompatibleIgnoringElementType(instruction.shape(), + user.shape())) { + return user.sharding(); + } + return absl::nullopt; + } + } +} + +// Tries to update the sharding of the specified instruction based on its users +// and returns true if the sharding of the instruction have been changed and +// false otherwise. +bool InferShardingFromUsers(HloInstruction* instruction, + const ComputationMap& computation_map, + bool aggressive_prop, bool is_spmd) { + if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) { + return false; + } + bool improved_sharding = false; + for (const HloInstruction* user : instruction->users()) { + absl::optional<HloSharding> user_sharding = + GetShardingFromUser(*instruction, *user, aggressive_prop, is_spmd); + if (user_sharding) { + improved_sharding |= + MaybeImproveInstructionSharding(*user_sharding, instruction); + } + } + return improved_sharding; +} + +// Remove Sharding custom-call instruction by folding the sharding attribute +// to its operand. If the operand alreayd has a different sharding, insert a +// copy node for reshard. +StatusOr<bool> ProcessShardingInstruction(HloModule* module) { + bool changed = false; + + for (HloComputation* computation : module->computations()) { + auto instructions = computation->MakeInstructionPostOrder(); + std::reverse(instructions.begin(), instructions.end()); + for (HloInstruction* instruction : instructions) { + if (instruction->opcode() != HloOpcode::kCustomCall) { + continue; + } + if (instruction->custom_call_target() != "Sharding") { + continue; + } + TF_RET_CHECK(instruction->has_sharding()) + << "Sharding instruction must have a sharding attribute"; + const HloSharding& sharding = instruction->sharding(); + + // If the operand has a different sharding from the current sharding + // instruction, create a copy node. Otherwise, just remove the sharding + // instruction and set the operand sharding. + if (instruction->operand(0)->has_sharding() && + instruction->operand(0)->sharding() != sharding) { + auto copy = computation->AddInstruction( + HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kCopy, + instruction->mutable_operand(0))); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(instruction, copy)); + copy->set_sharding(sharding); + } else { + instruction->mutable_operand(0)->set_sharding(sharding); + TF_RETURN_IF_ERROR( + instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); + } + changed = true; + } + } + return changed; +} + +} // namespace + +/*static*/ Status ShardingPropagation::NormalizeDomain( + const DomainMetadata::Domain& domain, const DomainMetadata* metadata) { + if (metadata != nullptr) { + TF_ASSIGN_OR_RETURN(const auto& sharding_metadata, + ShardingMetadata::ToShardingMetadata(metadata)); + const auto& sharding = sharding_metadata->sharding(); + if (sharding != nullptr) { + bool is_spatially_partitioned = !sharding->HasUniqueDevice(); + if (sharding->IsTuple()) { + is_spatially_partitioned = absl::c_any_of( + sharding->tuple_elements(), + [](const HloSharding& s) { return !s.HasUniqueDevice(); }); + } + if (is_spatially_partitioned) { + for (HloInstruction* domain : domain.exit_domains) { + domain->mutable_operand(0)->set_sharding(*sharding); + } + return Status::OK(); + } + } + } + return ShardingMetadata::NormalizeShardingDomain(domain, metadata); +} + +StatusOr<bool> ShardingPropagation::Run(HloModule* module) { + TF_ASSIGN_OR_RETURN(bool any_changed, ProcessShardingInstruction(module)); + + // Association of partitionable embedded computations with their parent + // instruction. + ComputationMap computation_map; + + // Instructions that are related through a computation and need to share the + // same sharding. + auto get_related_instructions = [](HloInstruction* inst) { + if (inst->opcode() == HloOpcode::kWhile) { + return std::vector<HloInstruction*>{ + inst, inst->while_body()->root_instruction(), + inst->while_body()->parameter_instruction(0), + inst->while_condition()->parameter_instruction(0)}; + } else if (inst->opcode() == HloOpcode::kConditional) { + std::vector<HloInstruction*> comps{inst}; + for (HloComputation* c : inst->called_computations()) { + comps.push_back(c->root_instruction()); + } + return comps; + } else { + CHECK(false); + } + }; + + // If instruction is a while, or the root or a parameter of a while body, + // then propagate its sharding to the while instruction, to its body root, + // and to its condition parameter. + std::function<void(HloInstruction*)> maybe_computation_propagation = + [&](HloInstruction* instruction) { + auto propagate_to_instruction = [&](HloInstruction* search_inst) { + auto related_instructions = get_related_instructions(search_inst); + if (absl::c_count(related_instructions, instruction)) { + for (HloInstruction* inst : related_instructions) { + if (!inst->has_sharding() || + inst->sharding() != instruction->sharding()) { + VLOG(2) << "Add computation sharding: " << inst->name(); + inst->set_sharding(instruction->sharding()); + maybe_computation_propagation(inst); + } + } + } + }; + + if (instruction->opcode() == HloOpcode::kConditional || + instruction->opcode() == HloOpcode::kWhile) { + propagate_to_instruction(instruction); + } + + if (instruction->opcode() == HloOpcode::kParameter || + instruction->parent()->root_instruction() == instruction) { + auto it = computation_map.find(instruction->parent()); + if (it != computation_map.end()) { + propagate_to_instruction(it->second); + } + } + }; + + // Populate computation_map in order to associate while bodies to their + // while instructions. + for (auto computation : module->computations()) { + for (auto instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kWhile || + instruction->opcode() == HloOpcode::kConditional) { + // Check if any of the related instructions has sharding, in which case + // propagate it to the other instructions, so they all share the same + // sharding, in case the user didn't shard all of them. We don't check + // that user shardings are consistent, because such check is already + // done by HloShardingVerifier. + const HloInstruction* sharded_inst = nullptr; + auto related_instructions = get_related_instructions(instruction); + for (auto inst : related_instructions) { + if (inst->has_sharding()) { + sharded_inst = inst; + break; + } + } + if (sharded_inst != nullptr) { + // Set the same sharding to all the other related instructions. + for (auto inst : related_instructions) { + inst->set_sharding(sharded_inst->sharding()); + } + } + } + if (instruction->opcode() == HloOpcode::kWhile) { + computation_map[instruction->while_body()] = instruction; + } else if (instruction->opcode() == HloOpcode::kConditional) { + for (HloComputation* c : instruction->called_computations()) { + computation_map[c] = instruction; + } + } + } + } + + // Collect all pre-sharded instructions as we aren't allowed to modify their + // sharding. + absl::flat_hash_set<const HloInstruction*> provided_shardings; + for (const HloComputation* computation : module->computations()) { + for (const HloInstruction* inst : computation->instructions()) { + if (inst->has_sharding()) { + provided_shardings.insert(inst); + } + } + } + + // Consider the root instruction of the entry module as one with provided + // sharding as its sharding have to match with the one expected by the host. + provided_shardings.insert(module->entry_computation()->root_instruction()); + + // Iterate to a fixpoint that is guaranteed to be reached because we only + // strictly improve the sharding of the graph and it can't be improved + // indefinitely. + int64 iterations = 0; + auto run_to_fix_point = [&](bool aggressive_prop) { + bool changed = true; + while (changed) { + changed = false; + int64 inferred_from_operand_counter = 0; + int64 inferred_from_user_counter = 0; + int64 instruction_counter = 0; + int64 already_sharded_counter = 0; + for (const HloComputation* computation : module->computations()) { + std::vector<HloInstruction*> instructions = + computation->MakeInstructionPostOrder(); + + instruction_counter += instructions.size(); + for (const HloInstruction* instruction : instructions) { + already_sharded_counter += (instruction->has_sharding() ? 1 : 0); + } + + // Remove the instructions where the sharding was provided from the + // outside so we don't modify them. + instructions.erase( + std::remove_if(instructions.begin(), instructions.end(), + [&](HloInstruction* instruction) { + return provided_shardings.contains(instruction); + }), + instructions.end()); + + // First iterate the HLO graph in post order taking shardings from + // operands. + for (HloInstruction* instruction : instructions) { + if (InferShardingFromOperands(instruction, computation_map, is_spmd_, + aggressive_prop)) { + ++inferred_from_operand_counter; + changed = true; + VLOG(2) << "Add sharding (forward-pass): " + << instruction->ToString(); + maybe_computation_propagation(instruction); + } + } + + // Then iterate the HLO graph in reverse post order taking shardings + // from users. + for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) { + if (InferShardingFromUsers(*it, computation_map, aggressive_prop, + is_spmd_)) { + ++inferred_from_user_counter; + changed = true; + VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString(); + maybe_computation_propagation(*it); + } + } + } + any_changed |= changed; + VLOG(1) << "Sharding propagation iteration " << iterations << ";"; + VLOG(1) << " total instructions: " << instruction_counter; + VLOG(1) << " instructions already sharded: " << already_sharded_counter; + VLOG(1) << " shardings inferred from operands: " + << inferred_from_operand_counter; + VLOG(1) << " shardings inferred from users: " + << inferred_from_user_counter; + ++iterations; + } + }; + run_to_fix_point(false); + run_to_fix_point(true); + + VLOG(1) << "Sharding propagation completed after " << iterations + << " iterations"; + return any_changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/sharding_propagation.h b/tensorflow/compiler/xla/service/sharding_propagation.h new file mode 100644 index 00000000000..2c07a4a6a31 --- /dev/null +++ b/tensorflow/compiler/xla/service/sharding_propagation.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_ + +#include <memory> +#include <vector> + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Propagates sharding information around the graph. HLOs that have shardings +// are kept as-is, those that do not have shardings are given shardings based on +// a simple local greedy heuristic. +class ShardingPropagation : public HloModulePass { + public: + explicit ShardingPropagation(bool is_spmd = false) : is_spmd_(is_spmd) {} + absl::string_view name() const override { return "sharding-propagation"; } + StatusOr<bool> Run(HloModule* module) override; + + // Function which can be used to apply a spatially partitioned sharding onto a + // given domain. It will apply the sharding into the exit edges of the domain + // and then rely on the rest of sharding propagation to ensure that the + // intermediate nodes get the correct sharding. + static Status NormalizeDomain(const DomainMetadata::Domain& domain, + const DomainMetadata* metadata); + + private: + bool is_spmd_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SHARDING_PROPAGATION_H_ diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc new file mode 100644 index 00000000000..a9d685a7a93 --- /dev/null +++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc @@ -0,0 +1,1329 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/sharding_propagation.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +using ShardingPropagationTest = HloTestBase; + +TEST_F(ShardingPropagationTest, ElementwiseOperationForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %elementwise { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0), + sharding={devices=[1,2,2,1]0,1,2,3} + %param1 = f32[5,7,11,13]{3,2,1,0} parameter(1) + %add = f32[5,7,11,13]{3,2,1,0} add(%param0, %param1) + ROOT %copy = f32[5,7,11,13]{3,2,1,0} copy(%add) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "add"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, ElementwiseOperationBackwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %elementwise { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0) + %param1 = f32[5,7,11,13]{3,2,1,0} parameter(1) + %add = f32[5,7,11,13]{3,2,1,0} add(%param0, %param1) + ROOT %copy = f32[5,7,11,13]{3,2,1,0} copy(%add), + sharding={devices=[1,2,2,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "add"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, BroadcastForwardPassNoSharding) { + const char* const hlo_string = R"( +HloModule module +ENTRY %broadcast { + %param0 = f32[7,11]{1,0} parameter(0), + sharding={devices=[2,2]0,1,2,3} + %broadcast = f32[5,7,11,13]{3,2,1,0} broadcast(%param0), dimensions={1,2} + ROOT %copy = f32[5,7,11,13]{3,2,1,0} copy(%broadcast) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_FALSE(changed); +} + +// Regression Test for b/129569657. +TEST_F(ShardingPropagationTest, BroadcastForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %broadcast { + %param0 = f32[3,2048,2048]{2,1,0} parameter(0), + sharding={devices=[1,2,2]0,1,2,3} + %broadcast = f32[3,2048,2048,3]{3,2,1,0} broadcast(%param0), dimensions={0,1,2} + ROOT %copy = f32[3,2048,2048,3]{3,2,1,0} copy(%broadcast) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "broadcast"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, BroadcastBackwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %broadcast { + %param0 = f32[13]{0} parameter(0) + %broadcast = f32[5,7,11,13]{3,2,1,0} broadcast(%param0), dimensions={3} + ROOT %copy = f32[5,7,11,13]{3,2,1,0} copy(%broadcast), + sharding={devices=[1,2,2,1]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "broadcast"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, BroadcastUser) { + const char* const hlo_string = R"( +HloModule module +ENTRY %broadcast { + %param0 = f32[24,8]{0,1} parameter(0) + %copy = f32[24,8]{0,1} copy(%param0) + ROOT %broadcast = f32[4,24,6,8]{3,2,1,0} broadcast(%copy), dimensions={1,3}, + sharding={devices=[1,2,1,4]0,1,2,3,4,5,6,7} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "copy"), + op::Sharding("{devices=[2,4]0,1,2,3,4,5,6,7}")); +} + +TEST_F(ShardingPropagationTest, MaximalReduceForwardPass) { + const char* const hlo_string = R"( +HloModule module +%add { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(%lhs, %rhs) +} +ENTRY %reduce { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0), + sharding={devices=[1,2,2,1]0,1,2,3} + %init = f32[] parameter(1) + %reduce = f32[5,7]{1,0} reduce(%param0, %init), dimensions={2,3}, to_apply=%add + ROOT %copy = f32[5,7]{0,1} copy(%reduce) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "reduce"), + op::Sharding("{replicated}")); +} + +TEST_F(ShardingPropagationTest, ShardedReduceForwardPass) { + const char* const hlo_string = R"( +HloModule module +%add { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(%lhs, %rhs) +} +ENTRY %reduce { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0), + sharding={devices=[1,2,2,1]0,1,2,3} + %init = f32[] parameter(1) + %reduce = f32[7,11]{1,0} reduce(%param0, %init), dimensions={0,3}, to_apply=%add + ROOT %copy = f32[7,11]{0,1} copy(f32[7,11]{1,0} %reduce) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "reduce"), + op::Sharding("{devices=[2,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, ShardedTupleReduceForwardAndBackwardPass) { + const char* const hlo_string = R"( +HloModule module + +%minmax_func { + %lhs_value = f32[] parameter(0) + %rhs_value = f32[] parameter(2) + %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT + %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value) + %lhs_index = s32[] parameter(1) + %rhs_index = s32[] parameter(3) + %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index) + ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5) +} + +ENTRY %main { + %param0 = f32[28,10] parameter(0) + %param1 = s32[28,10] parameter(1), sharding={devices=[2,1]0,1} + %copy_param0 = f32[28,10] copy(%param0) + %init0 = f32[] parameter(2) + %init1 = s32[] parameter(3) + %reduce = (f32[28], s32[28]) reduce(%copy_param0, %param1, %init0, %init1), + dimensions={1}, to_apply=%minmax_func + %gte0 = f32[28] get-tuple-element(%reduce), index=0 + %gte1 = s32[28] get-tuple-element(%reduce), index=1 + %copy0 = f32[28] copy(%gte0) + %copy1 = s32[28] copy(%gte1) + ROOT %tuple = (f32[28], s32[28]) tuple(%copy0, %copy1) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "reduce"), + op::Sharding("{{devices=[2]0,1},{devices=[2]0,1}}")); + EXPECT_THAT(FindInstruction(module.get(), "copy_param0"), + op::Sharding("{devices=[2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, GetTupleElementForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %gte { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0) + %tuple = (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0}) tuple( + %param0, %param0) + %tuple.1 = (f32[5,7,11,13]{3,2,1,0}, + (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0})) tuple( + %param0, %tuple), + sharding={{devices=[1,2,2,1]0,1,2,3}, + {replicated}, + {devices=[1,2,2,1]0,1,2,3}} + %gte = f32[5,7,11,13]{3,2,1,0} get-tuple-element(%tuple.1), index=0 + %gte.1 = (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0}) get-tuple-element( + %tuple.1), index=1 + %gte.2 = f32[5,7,11,13]{3,2,1,0} get-tuple-element(%gte.1), index=0 + ROOT %copy = f32[5,7,11,13]{3,2,1,0} copy(%gte.2) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "gte"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); + EXPECT_THAT(FindInstruction(module.get(), "gte.1"), + op::Sharding("{{replicated}," + " {devices=[1,2,2,1]0,1,2,3}}")); + EXPECT_THAT(FindInstruction(module.get(), "gte.2"), + op::Sharding("{replicated}")); +} + +TEST_F(ShardingPropagationTest, TupleForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %tuple { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0), + sharding={replicated} + %param1 = f32[5,7,11,13]{3,2,1,0} parameter(1), + sharding={devices=[1,2,2,1]0,1,2,3} + %param2 = f32[5,7,11,13]{3,2,1,0} parameter(2) + %tuple = (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0}) tuple( + %param1, %param2) + %tuple.1 = (f32[5,7,11,13]{3,2,1,0}, + (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0})) tuple( + %param0, %tuple) + ROOT %copy = (f32[5,7,11,13]{3,2,1,0}, + (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0})) copy( + %tuple.1) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "tuple"), + op::Sharding("{{devices=[1,2,2,1]0,1,2,3}," + " {replicated}}")); + EXPECT_THAT(FindInstruction(module.get(), "tuple.1"), + op::Sharding("{{replicated}," + " {devices=[1,2,2,1]0,1,2,3}," + " {replicated}}")); +} + +TEST_F(ShardingPropagationTest, ForwardConvolutionForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %lhs = f32[5,7,11,13]{3,2,1,0} parameter(0), + sharding={devices=[2,2,2,1]0,1,2,3,4,5,6,7} + %rhs = f32[3,3,13,17]{3,2,1,0} parameter(1) + %convolution = f32[5,7,11,17]{3,2,1,0} convolution(%lhs, %rhs), + window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f + ROOT %copy = f32[5,7,11,17]{3,2,1,0} copy(%convolution) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "convolution"), + op::Sharding("{devices=[2,2,2,1]0,1,2,3,4,5,6,7}")); +} + +TEST_F(ShardingPropagationTest, ForwardConvolutionLargeDilationForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %lhs = f32[8,64,2]{2,1,0} parameter(0), + sharding={devices=[1,4,1]0,1,2,3} + %rhs = f32[3,2,2]{2,1,0} parameter(1) + %convolution = f32[8,32,2]{2,1,0} convolution(%lhs, %rhs), + window={size=3 rhs_dilate=16}, dim_labels=b0f_0io->b0f + ROOT %copy = f32[8,32,2]{2,1,0} copy(%convolution) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "convolution"), + op::Sharding("{devices=[1,4,1]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, TransposeForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %transpose { + %param = f32[7,11,13]{2,1,0} parameter(0), + sharding={devices=[2,1,2]0,1,2,3} + %transpose = f32[11,13,7]{2,1,0} transpose(%param), dimensions={1,2,0} + ROOT %copy = f32[11,13,7]{2,1,0} copy(%transpose) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "transpose"), + op::Sharding("{devices=[1,2,2]0,2,1,3}")); +} + +TEST_F(ShardingPropagationTest, TransposeBackwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %transpose { + %param = f32[7,11,13]{2,1,0} parameter(0) + %copy = f32[7,11,13]{2,1,0} copy(%param) + ROOT %transpose = f32[11,13,7]{2,1,0} transpose(%copy), dimensions={1,2,0}, + sharding={devices=[1,2,2]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "copy"), + op::Sharding("{devices=[2,1,2]0,2,1,3}")); +} + +TEST_F(ShardingPropagationTest, ReshapeForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %reshape { + %param0 = f32[1430,1]{1,0} parameter(0), + sharding={devices=[2,1]0,1} + %reshape = f32[10,11,13]{2,1,0} reshape(%param0) + ROOT %copy = f32[10,11,13]{2,1,0} copy(%reshape) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "reshape"), + op::Sharding("{devices=[2,1,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, ReshapeBackwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %reshape { + %param0 = f32[2002,1]{1,0} parameter(0) + %copy = f32[2002,1]{1,0} copy(f32[2002,1]{1,0} %param0) + ROOT %reshape = f32[14,11,13]{2,1,0} reshape(%copy), + sharding={devices=[2,1,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "copy"), + op::Sharding("{devices=[2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, PadForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %pad { + %input = f32[11,17]{1,0} parameter(0), + sharding={devices=[2,2]0,1,2,3} + %pad_value = f32[] parameter(1) + %pad = f32[27,51]{1,0} pad(%input, %pad_value), padding=2_4_1x1_1_2 + ROOT %copy = f32[27,51]{1,0} copy(%pad) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "pad"), + op::Sharding("{devices=[2,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, ShardedPreferredOverReplicated) { + const char* const hlo_string = R"( +HloModule module +ENTRY %replicated { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0), + sharding={replicated} + %copy = f32[5,7,11,13]{3,2,1,0} copy(%param0) + %param1 = f32[5,7,11,13]{3,2,1,0} parameter(1), + sharding={devices=[1,2,2,1]0,1,2,3} + %copy.1 = f32[5,7,11,13]{3,2,1,0} copy(%param1) + %add = f32[5,7,11,13]{3,2,1,0} add(%copy, %copy.1) + ROOT %copy.2 = f32[5,7,11,13]{3,2,1,0} copy(%add) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "copy"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); + EXPECT_THAT(FindInstruction(module.get(), "copy.1"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); + EXPECT_THAT(FindInstruction(module.get(), "add"), + op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, DontShardTuplesIfAllInputIsMaximal) { + const char* const hlo_string = R"( +HloModule module +ENTRY %tuple { + %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0), + sharding={maximal device=0} + %param1 = f32[5,7,11,13]{3,2,1,0} parameter(1), + sharding={maximal device=1} + %tuple = (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0}) tuple( + %param0, %param1) + ROOT %copy = (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0}) copy(%tuple) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_FALSE(changed); + EXPECT_THAT(FindInstruction(module.get(), "tuple"), op::NoSharding()); +} + +TEST_F(ShardingPropagationTest, ValidConvolution) { + const char* const hlo_string = R"( +HloModule module + +ENTRY conv { + %lhs = f32[13,17,19]{2,1,0} parameter(0), + sharding={devices=[1,2,1]0,1} + %rhs = f32[19,5,19]{2,1,0} parameter(1) + %conv = f32[13,13,19]{2,1,0} convolution(%lhs, %rhs), + window={size=5}, dim_labels=b0f_i0o->b0f + ROOT %tuple = (f32[13,13,19]{2,1,0}) tuple(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[1,2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, StridedSlice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %slice { + %param = f32[17,13]{1,0} parameter(0), + sharding={devices=[2,1]0,1} + %slice = f32[7,5]{1,0} slice(%param), slice={[1:15:2], [5:10:1]} + ROOT %tuple = (f32[7,5]{1,0}) tuple(%slice) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "slice"), + op::Sharding("{devices=[2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, ReduceWindowBackwardPass) { + const char* const hlo_string = R"( +HloModule module +%add (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(%lhs, %rhs) +} +ENTRY %reduce_window { + %param = f32[13,17]{1,0} parameter(0) + %param.copy = f32[13,17]{1,0} copy(%param) + %init = f32[] parameter(1) + ROOT %reduce-window = f32[7,17]{1,0} reduce-window(%param.copy, %init), + window={size=3x2 stride=2x1 pad=1_1x0_1}, to_apply=%add, + sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "param.copy"), + op::Sharding("{devices=[2,1]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "reduce-window"), + op::Sharding("{devices=[2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, ReplicatedConvolutionLhs) { + const char* const hlo_string = R"( +HloModule module + +ENTRY conv { + %lhs = f32[3,2,3]{2,1,0} parameter(0), sharding={replicated} + %rhs = f32[2,2,1]{2,1,0} parameter(1) + %conv = f32[3,2,3]{2,1,0} convolution(%lhs, %rhs), + window={size=1}, dim_labels=bf0_oi0->bf0 + ROOT %tuple = f32[3,2,3]{2,1,0} tuple(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "lhs"), + op::Sharding("{replicated}")); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{replicated}")); +} + +TEST_F(ShardingPropagationTest, ConvolutionShardedFeature) { + const char* const hlo_string = R"( +HloModule module + +ENTRY conv { + %lhs = f32[3,2,3]{2,1,0} parameter(0), + sharding={devices=[1,2,1]0,1} + %rhs = f32[2,2,1]{2,1,0} parameter(1) + %conv = f32[3,2,3]{2,1,0} convolution(%lhs, %rhs), + window={size=1}, dim_labels=bf0_oi0->bf0 + ROOT %tuple = f32[3,2,3]{2,1,0} tuple(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(ShardingPropagationTest, ConvolutionDifferentDimensionNumbers) { + const char* const hlo_string = R"( +HloModule module + +ENTRY conv { + %lhs = f32[8,16,512] parameter(0), + sharding={devices=[1,2,1]0,1} + %rhs = f32[8,2,512] parameter(1) + %conv = f32[3,512,512] convolution(%lhs, %rhs), + window={size=2 stride=5}, + dim_labels=f0b_i0o->0bf + ROOT %tuple = f32[3,512,512] tuple(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[2,1,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, Concatenate) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %concat { + %param.0 = f32[5,7] parameter(0), + sharding={devices=[2,1]0,1} + %param.1 = f32[5,9] parameter(1), + sharding={devices=[2,1]0,1} + %concat = f32[5,16] concatenate(%param.0, %param.1), + dimensions={1} + ROOT %tuple = (f32[5,16]) tuple(%concat) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "concat"), + op::Sharding("{devices=[2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, TupleBackwardPass) { + const char* const hlo_string = R"( +HloModule module + +ENTRY %tuple { + %param.0 = f32[1] parameter(0) + %param.1 = f32[3] parameter(1) + %copy.0 = f32[1] copy(%param.0) + %copy.1 = f32[3] copy(param.1) + ROOT %tuple = (f32[1], f32[3]) tuple(%copy.0, %copy.1), + sharding={{replicated}, {devices=[2]0,1}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "copy.0"), + op::Sharding("{replicated}")); + EXPECT_THAT(FindInstruction(module.get(), "copy.1"), + op::Sharding("{devices=[2]0,1}")); +} + +TEST_F(ShardingPropagationTest, AllReduce) { + const char* const hlo_string = R"( +HloModule module + +%add (lhs: f32[], rhs: f32[]) -> f32[] { + %add_lhs = f32[] parameter(0) + %add_rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %add_lhs, f32[] %add_rhs) +} + +ENTRY %entry { + %param.0 = f32[3] parameter(0) + %param.1 = f32[3] parameter(1) + + %copy_f_t = f32[3] copy(%param.1), sharding={devices=[2]0,1} + %crs_f.tiled = f32[3] all-reduce(%copy_f_t), to_apply=%add + %crs_f.none = f32[3] all-reduce(%copy_f_t), to_apply=%add, + channel_id=1 + + %crs_b.replicated = f32[3] all-reduce(%param.0), to_apply=%add + %copy_b_r = f32[3] copy(%crs_b.replicated), sharding={replicated} + + ROOT %tuple = (f32[3], f32[3], f32[3], f32[3]) tuple( + %crs_f.tiled, crs_f.none, %copy_b_r) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "crs_f.tiled"), + op::Sharding("{devices=[2]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "crs_f.none"), op::NoSharding()); + + EXPECT_THAT(FindInstruction(module.get(), "crs_b.replicated"), + op::Sharding("{replicated}")); +} + +TEST_F(ShardingPropagationTest, While) { + const char* const hlo_string = R"( +HloModule module + +%cond { + %vars.cond = (u32[], f32[10]{0}) parameter(0) + %count.cond = u32[] get-tuple-element((u32[], f32[10]{0}) %vars.cond), index=0 + %limit = u32[] constant(10) + ROOT %lt = pred[] compare(u32[] %count.cond, u32[] %limit), direction=LT +} + +%body { + %vars = (u32[], f32[10]{0}) parameter(0) + %count = u32[] get-tuple-element(%vars), index=0 + %acc = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %vars), index=1 + + %one = u32[] constant(1) + %count.1 = u32[] add(u32[] %count, u32[] %one), sharding={replicated} + %acc.1 = f32[10]{0} add(f32[10]{0} %acc, f32[10]{0} %acc) + ROOT %tuple = (u32[], f32[10]{0}) tuple(u32[] %count.1, f32[10]{0} %acc.1) +} + +ENTRY %entry { + %p0 = f32[10]{0} parameter(0) + %p0.copy = f32[10]{0} copy(f32[10]{0} %p0) + %p1 = f32[10]{0} parameter(1) + %zero = u32[] constant(0) + %init = (u32[], f32[10]{0}) tuple(u32[] %zero, f32[10]{0} %p0.copy) + %while = (u32[], f32[10]{0}) while((u32[], f32[10]{0}) %init), + body=%body, condition=%cond + %res = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %while), index=1 + %prev = f32[10]{0} get-tuple-element((u32[], f32[10]{0}) %init), index=1 + %res.1 = f32[10]{0} multiply(f32[10]{0} %res, %prev) + ROOT %res_tuple = (f32[10]{0}) tuple(f32[10]{0} %res.1) +})"; + + auto while_is_sharded = [this](HloModule* module, + const HloSharding& sharding) { + TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardingPropagation().Run(module)); + EXPECT_TRUE(changed); + auto while_instr = FindInstruction(module, "while"); + EXPECT_NE(nullptr, while_instr); + std::vector<const HloInstruction*> instructions{ + while_instr, while_instr->while_body()->root_instruction(), + while_instr->while_body()->parameter_instruction(0), + while_instr->while_condition()->parameter_instruction(0)}; + + for (auto instr : instructions) { + EXPECT_TRUE(instr->has_sharding()); + EXPECT_EQ(sharding, instr->sharding()); + } + }; + { + // Propagation of user-defined partial sharding of while-related instruction + // (body root in this test). + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto body_root = FindInstruction(module.get(), "tuple"); + EXPECT_NE(nullptr, body_root); + auto sharding = + ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie(); + body_root->set_sharding(sharding); + while_is_sharded(module.get(), sharding); + } + { + // Propagation from acc.1 to the rest of the loop. + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto acc_1 = FindInstruction(module.get(), "acc.1"); + EXPECT_NE(nullptr, acc_1); + acc_1->set_sharding(ParseSharding("{devices=[2]0,1}").ConsumeValueOrDie()); + + while_is_sharded( + module.get(), + ParseSharding("{{replicated}, {devices=[2]0,1}}").ConsumeValueOrDie()); + } +} + +TEST_F(ShardingPropagationTest, Dot) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %param.0 = f32[8,256,128] parameter(0) + %param.1 = f32[8,128,512] parameter(1) + %param.2 = f32[8,128] parameter(2) + + %p0_copy_0 = f32[8,256,128] copy(%param.0), + sharding={devices=[1,4,1]0,1,2,3} + %p1_copy_0 = f32[8,128,512] copy(%param.1), + sharding={devices=[1,2,2]0,1,2,3} + %p2_copy = f32[8,128] copy(%param.2) + %dot_prop_rhs = f32[8,256,512] dot(%p0_copy_0, %p1_copy_0), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} + %dot_prop_lhs = f32[8,512,256] dot(%p1_copy_0, %p0_copy_0), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={1}, rhs_contracting_dims={2} + %dot_mat_vec = f32[8,256] dot(%p0_copy_0, %p2_copy), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} + + %p0_copy_1 = f32[8,256,128] copy(%param.0) + %p1_copy_1 = f32[8,128,512] copy(%param.1) + %dot_back_prop_rhs = f32[8,256,512] dot(%p0_copy_1, %p1_copy_1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} + %copy_back_prop_rhs = f32[8,256,512] copy(%dot_back_prop_rhs), + sharding={devices=[1,2,2]0,1,2,3} + + ROOT %tuple = (f32[8,256,256], f32[8,256,256], f32[8,256]) + tuple(%dot_prop_lhs, %dot_prop_rhs, %dot_mat_vec, %copy_back_prop_rhs) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "dot_prop_rhs"), + op::Sharding("{devices=[1,2,2]0,1,2,3}")); + EXPECT_THAT(FindInstruction(module.get(), "dot_prop_lhs"), + op::Sharding("{devices=[1,2,2]0,1,2,3}")); + EXPECT_THAT(FindInstruction(module.get(), "dot_mat_vec"), + op::Sharding("{devices=[1,4]0,1,2,3}")); + + EXPECT_THAT(FindInstruction(module.get(), "p0_copy_1"), + op::Sharding("{replicated}")); + EXPECT_THAT(FindInstruction(module.get(), "p1_copy_1"), + op::Sharding("{devices=[1,2,2]0,1,2,3}")); + EXPECT_THAT(FindInstruction(module.get(), "dot_back_prop_rhs"), + op::Sharding("{devices=[1,2,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, DotTiledBatchDim) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %p0 = f32[8,256,512] parameter(0) + %p1 = f32[8,512,128] parameter(1) + + %add = f32[8,256,512] add(%p0, %p0) + %dot = f32[8,256,128] dot(%add, %p1), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={1} + %res = f32[8,32768] reshape(%dot), sharding={devices=[2,2]0,1,2,3} + + ROOT %tuple = (f32[8,32768]) tuple(%res) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "add"), + op::Sharding("{devices=[2,2,1]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, ConcatFromUserUnshardedDim) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %p0 = f32[8,128] parameter(0) + %p1 = f32[8,128] parameter(1) + %c0 = f32[8,128] copy(%p0) + %c1 = f32[8,128] copy(%p1) + + %concat = f32[16,128] concatenate(%c0, %c1), + dimensions={0}, + sharding={devices=[1,2]0,1} + ROOT %tuple = (f32[16,128]) tuple(%concat) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "c0"), + op::Sharding("{devices=[1,2]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "c1"), + op::Sharding("{devices=[1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, ConcatFromUserShardedDim) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %p0 = f32[8,128] parameter(0) + %p1 = f32[8,128] parameter(1) + %c0 = f32[8,128] copy(%p0) + %c1 = f32[8,128] copy(%p1) + + %concat = f32[16,128] concatenate(%c0, %c1), + dimensions={0}, + sharding={devices=[3,1]0,1,2} + ROOT %tuple = (f32[16,128]) tuple(%concat) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "c0"), + op::Sharding("{devices=[2,1]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "c1"), + op::Sharding("{devices=[2,1]1,2}")); +} + +TEST_F(ShardingPropagationTest, ConcatFromUserShardedDimMaximalOperand) { + const char* const hlo_string = R"( +HloModule module +ENTRY %conv { + %p0 = f32[8,128] parameter(0) + %p1 = f32[24,128] parameter(1) + %c0 = f32[8,128] copy(%p0) + %c1 = f32[24,128] copy(%p1) + + %concat = f32[32,128] concatenate(%c0, %c1), + dimensions={0}, + sharding={devices=[4,1]0,1,2,3} + ROOT %tuple = (f32[32,128]) tuple(%concat) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "c0"), op::NoSharding()); + EXPECT_THAT(FindInstruction(module.get(), "c1"), + op::Sharding("{devices=[3,1]1,2,3}")); +} + +TEST_F(ShardingPropagationTest, ReplicatedToSideEffecting) { + const char* const hlo_string = R"( +HloModule module +ENTRY entry_computation { + %const.0 = s32[] constant(0), sharding={replicated} + %const.1 = s32[] constant(2147483647), sharding={replicated} + %rng = s32[4]{0} rng(%const.0, %const.1), + distribution=rng_uniform + ROOT %root = (s32[4]{0}) tuple(%rng) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_FALSE(changed); + EXPECT_THAT(FindInstruction(module.get(), "rng"), op::NoSharding()); +} + +TEST_F(ShardingPropagationTest, PartReplicatedTupleUser) { + const char* const hlo_string = R"( +HloModule module +ENTRY entry_computation { + %param.0 = f32[5] parameter(0) + %param.1 = f32[7] parameter(1) + %param.2 = f32[9] parameter(2) + %tuple.0 = (f32[5], f32[7]) tuple(%param.0, %param.1) + ROOT %tuple.1 = ((f32[5], f32[7]), f32[9]) tuple(%tuple.0, %param.2), + sharding={{maximal device=0}, {replicated}, {maximal device=1}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "tuple.0"), + op::Sharding("{{maximal device=0}, {replicated}}")); +} + +TEST_F(ShardingPropagationTest, Conditional) { + const char* const hlo_string = R"( +HloModule module + +%true_comp { + %tp = (f32[3,5]) parameter(0) + %tgte = f32[3,5] get-tuple-element(%tp), index=0 + %ttr = f32[5,3] transpose(%tgte), dimensions={1,0} + ROOT %tr = (f32[5,3]) tuple(%ttr) +} + +%false_comp { + %fp = (f32[5,3]) parameter(0) + %fgte = f32[5,3] get-tuple-element(%fp), index=0 + ROOT %fr = (f32[5,3]) tuple(%fgte) +} + +ENTRY entry { + %cond = pred[] parameter(0) + %true_param = (f32[3,5]) parameter(1), sharding={{devices=[1,2]0,1}} + %false_param = (f32[5,3]) parameter(2), sharding={{devices=[1,3]0,1,2}} + %conditional = (f32[5,3]) conditional( + %cond, %true_param, %false_param), + true_computation=%true_comp, + false_computation=%false_comp + ROOT %root = f32[5,3] get-tuple-element(%conditional), index=0 +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "tp"), + op::Sharding("{{devices=[1,2]0,1}}")); + EXPECT_THAT(FindInstruction(module.get(), "tgte"), + op::Sharding("{devices=[1,2]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "ttr"), + op::Sharding("{devices=[2,1]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "tr"), + op::Sharding("{{devices=[2,1]0,1}}")); + EXPECT_THAT(FindInstruction(module.get(), "fp"), + op::Sharding("{{devices=[1,3]0,1,2}}")); + EXPECT_THAT(FindInstruction(module.get(), "fgte"), + op::Sharding("{devices=[1,3]0,1,2}")); + EXPECT_THAT(FindInstruction(module.get(), "fr"), + op::Sharding("{{devices=[2,1]0,1}}")); + EXPECT_THAT(FindInstruction(module.get(), "conditional"), + op::Sharding("{{devices=[2,1]0,1}}")); +} + +TEST_F(ShardingPropagationTest, TupleFromUser) { + const char* const hlo_string = R"( +HloModule module +ENTRY %entry { + %p0 = f32[13] parameter(0) + %p1 = f32[15] parameter(1) + %p2 = f32[17] parameter(2) + %t0 = (f32[13], f32[15]) tuple(%p0, %p1) + %t1 = ((f32[13], f32[15]), f32[17]) tuple(%t0, %p2) + %gte.0 = (f32[13], f32[15]) get-tuple-element(%t1), index=0 + %gte.1 = f32[13] get-tuple-element(%gte.0), index=0 + %gte.2 = f32[15] get-tuple-element(%gte.0), index=1 + %gte.3 = f32[17] get-tuple-element(%t1), index=1 + ROOT %t2 = (f32[13], f32[15], f32[17]) tuple(%gte.1, %gte.2, %gte.3), + sharding={{replicated}, {devices=[2]0,1}, {devices=[3]1,2,3}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "t0"), + op::Sharding("{{replicated}, {devices=[2]0,1}}")); + EXPECT_THAT( + FindInstruction(module.get(), "t1"), + op::Sharding("{{replicated}, {devices=[2]0,1}, {devices=[3]1,2,3}}")); +} + +TEST_F(ShardingPropagationTest, DynamicSliceForwardPass) { + const char* hlo_string = R"( +HloModule module +ENTRY %entry { + %p0 = f32[11,13,15] parameter(0) + %c0 = f32[11,13,15] copy(%p0), sharding={devices=[1,1,2]0,1} + %p1 = s32[] parameter(1) + %i0 = s32[] constant(0) + %ds = f32[11,1,15] dynamic-slice(%c0, %i0, %p1, %i0), + dynamic_slice_sizes={11,1,15} + ROOT %root = (f32[11,1,15]) tuple(%ds) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "ds"), + op::Sharding("{devices=[1,1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, DynamicSliceBackwardPass) { + const char* hlo_string = R"( +HloModule module +ENTRY %entry { + %p0 = f32[11,13,15] parameter(0) + %c0 = f32[11,13,15] copy(%p0) + %p1 = s32[] parameter(1) + %i0 = s32[] constant(0) + %ds = f32[11,1,15] dynamic-slice(%c0, %i0, %p1, %i0), + dynamic_slice_sizes={11,1,15}, + sharding={devices=[1,1,2]0,1} + ROOT %root = (f32[11,1,15]) tuple(%ds) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "ds"), + op::Sharding("{devices=[1,1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, DynamicUpdateSliceForwardPassBase) { + const char* hlo_string = R"( +HloModule module +ENTRY %entry { + %p0 = f32[11,13,15] parameter(0) + %c0 = f32[11,13,15] copy(%p0), sharding={devices=[1,1,2]0,1} + %p1 = f32[11,1,15] parameter(1) + %c1 = f32[11,1,15] copy(%p1) + %p2 = s32[] parameter(2) + %i0 = s32[] constant(0) + %dus = f32[11,13,15] dynamic-update-slice(%c0, %c1, %i0, %p2, %i0) + ROOT %root = (f32[11,13,15]) tuple(%dus) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "dus"), + op::Sharding("{devices=[1,1,2]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "c1"), + op::Sharding("{devices=[1,1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, DynamicUpdateSliceForwardPassUpdate) { + const char* hlo_string = R"( +HloModule module +ENTRY %entry { + %p0 = f32[11,13,15] parameter(0) + %c0 = f32[11,13,15] copy(%p0) + %p1 = f32[11,1,15] parameter(1) + %c1 = f32[11,1,15] copy(%p1), sharding={devices=[1,1,2]0,1} + %p2 = s32[] parameter(2) + %i0 = s32[] constant(0) + %dus = f32[11,13,15] dynamic-update-slice(%c0, %c1, %i0, %p2, %i0) + ROOT %root = (f32[11,13,15]) tuple(%dus) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "dus"), + op::Sharding("{devices=[1,1,2]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "c0"), + op::Sharding("{devices=[1,1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, DynamicUpdateSliceBackwardPass) { + const char* hlo_string = R"( +HloModule module +ENTRY %entry { + %p0 = f32[11,13,15] parameter(0) + %c0 = f32[11,13,15] copy(%p0) + %p1 = f32[11,1,15] parameter(1) + %c1 = f32[11,1,15] copy(%p1) + %p2 = s32[] parameter(2) + %i0 = s32[] constant(0) + %dus = f32[11,13,15] dynamic-update-slice(%c0, %c1, %i0, %p2, %i0), + sharding={devices=[1,1,2]0,1} + ROOT %root = (f32[11,13,15]) tuple(%dus) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "c0"), + op::Sharding("{devices=[1,1,2]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "c1"), + op::Sharding("{devices=[1,1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, EinsumLHSBatchPartitioned) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs), sharding={devices=[2,1,1]0,1} + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs) + %conv = f32[32,24,39296] convolution(%lhs.copy, %rhs.copy), + dim_labels=0bf_0oi->0bf, window={size=32 stride=31 lhs_dilate=32} + ROOT %copy = f32[32,24,39296] copy(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "rhs.copy"), + op::Sharding("{devices=[2,1,1]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[2,1,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, EinsumOutputBatchPartitioned) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64] parameter(0) + %lhs.copy = f32[32,24,64] copy(%lhs) + %rhs = f32[32,39296,64] parameter(1) + %rhs.copy = f32[32,39296,64] copy(%rhs) + %conv = f32[32,24,39296] convolution(%lhs.copy, %rhs.copy), + dim_labels=0bf_0oi->0bf, window={size=32 stride=31 lhs_dilate=32}, + sharding={devices=[2,1,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "lhs.copy"), + op::Sharding("{devices=[2,1,1]0,1}")); + EXPECT_THAT(FindInstruction(module.get(), "rhs.copy"), + op::Sharding("{devices=[2,1,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, EinsumLHSNonContractingPartitioned) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs), sharding={devices=[1,2,1,2]0,1,2,3} + %rhs = f32[32,39296,64,1] parameter(1) + %rhs.copy = f32[32,39296,64,1] copy(%rhs) + %conv = f32[32,24,39296,128] convolution(%lhs.copy, %rhs.copy), + dim_labels=0bf1_0oi1->0bf1, window={size=32x1 stride=31x1 lhs_dilate=32x1} + ROOT %copy = f32[32,24,39296,128] copy(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[1,2,1,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, EinsumOutputLHSNonContractingPartitioned) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,128] parameter(0) + %lhs.copy = f32[32,24,64,128] copy(%lhs) + %rhs = f32[32,39296,64,1] parameter(1) + %rhs.copy = f32[32,39296,64,1] copy(%rhs) + ROOT %conv = f32[32,24,39296,128] convolution(%lhs.copy, %rhs.copy), + dim_labels=0bf1_0oi1->0bf1, window={size=32x1 stride=31x1 lhs_dilate=32x1}, + sharding={devices=[1,2,1,2]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "lhs.copy"), + op::Sharding("{devices=[1,2,1,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, EinsumRHSNonContractingPartitioned) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,1] parameter(0) + %lhs.copy = f32[32,24,64,1] copy(%lhs) + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,2,1,2]0,1,2,3} + %conv = f32[32,24,39296,128] convolution(%lhs.copy, %rhs.copy), + dim_labels=0bf1_0oi1->0bf1, + window={size=32x128 stride=31x1 pad=0_0x127_127 lhs_dilate=32x1 rhs_reversal=0x1} + ROOT %copy = f32[32,24,39296,128] copy(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[1,1,2,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, EinsumOutputRHSNonContractingPartitioned) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,1] parameter(0) + %lhs.copy = f32[32,24,64,1] copy(%lhs) + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs) + ROOT %conv = f32[32,24,39296,128] convolution(%lhs.copy, %rhs.copy), + dim_labels=0bf1_0oi1->0bf1, + window={size=32x128 stride=31x1 pad=0_0x127_127 lhs_dilate=32x1 rhs_reversal=0x1}, + sharding={devices=[1,1,2,2]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "rhs.copy"), + op::Sharding("{devices=[1,2,1,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, EinsumChooseLargerOperand) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,1] parameter(0) + %lhs.copy = f32[32,24,64,1] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[1,2,1,2]0,1,2,3} + %conv = f32[32,24,39296,128] convolution(%lhs.copy, %rhs.copy), + dim_labels=0bf1_0oi1->0bf1, + window={size=32x128 stride=31x1 pad=0_0x127_127 lhs_dilate=32x1 rhs_reversal=0x1} + ROOT %copy = f32[32,24,39296,128] copy(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[1,1,2,2]0,1,2,3}")); +} + +TEST_F(ShardingPropagationTest, EinsumChooseBatchFirst) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,24,64,1] parameter(0) + %lhs.copy = f32[32,24,64,1] copy(%lhs), sharding={devices=[1,2,1,1]0,1} + %rhs = f32[32,39296,64,128] parameter(1) + %rhs.copy = f32[32,39296,64,128] copy(%rhs), sharding={devices=[2,1,1,1]0,1} + %conv = f32[32,24,39296,128] convolution(%lhs.copy, %rhs.copy), + dim_labels=0bf1_0oi1->0bf1, + window={size=32x128 stride=31x1 pad=0_0x127_127 lhs_dilate=32x1 rhs_reversal=0x1} + ROOT %copy = f32[32,24,39296,128] copy(%conv) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "conv"), + op::Sharding("{devices=[2,1,1,1]0,1}")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD index 5be6a04f934..280af2246bb 100644 --- a/tensorflow/compiler/xla/service/spmd/BUILD +++ b/tensorflow/compiler/xla/service/spmd/BUILD @@ -33,6 +33,7 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client/lib:comparators", + "//tensorflow/compiler/xla/service:dot_as_convolution_util", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index b857c8bdbe6..8eee452328e 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -670,26 +671,34 @@ PartitionedHlo PartitionedHlo::Replicate() { } // 'Tiled' to 'Replicated'. + HloInstruction* result = nullptr; + if (state_.collective_ops_creator.create_cross_partition_all_gather) { + result = state_.partitioner->AllGatherShards(state_.b, hlo_, sharding, + NewChannel()); + } Shape padded_base_shape = shape; for (int64 i = 0; i < padded_base_shape.rank(); ++i) { padded_base_shape.set_dimensions( i, shape.dimensions(i) * sharding.tile_assignment().dim(i)); } - auto zero = state_.b->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); - auto zero_bcast = state_.b->AddInstruction( - HloInstruction::CreateBroadcast(padded_base_shape, zero, {})); - auto dus = state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - padded_base_shape, zero_bcast, hlo_, - MakePartitionOffsets(padded_base_shape, sharding, state_.partition_id, - state_.b))); - HloComputation* reduction = - MakeBinaryAdd(shape.element_type(), state_.module); + if (result == nullptr) { + auto zero = state_.b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(shape.element_type()))); + auto zero_bcast = state_.b->AddInstruction( + HloInstruction::CreateBroadcast(padded_base_shape, zero, {})); + auto dus = + state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + padded_base_shape, zero_bcast, hlo_, + MakePartitionOffsets(padded_base_shape, sharding, + state_.partition_id, state_.b))); + HloComputation* reduction = + MakeBinaryAdd(shape.element_type(), state_.module); - auto all_reduce = - state_.collective_ops_creator.create_cross_partition_all_reduce( - state_.b, dus, reduction, NewChannel()); - HloInstruction* result = all_reduce; + auto all_reduce = + state_.collective_ops_creator.create_cross_partition_all_reduce( + state_.b, dus, reduction, NewChannel()); + result = all_reduce; + } if (!ShapeUtil::Compatible(base_shape_, padded_base_shape)) { std::vector<int64> start_indices(shape.rank(), 0); std::vector<int64> strides(shape.rank(), 1); @@ -2897,6 +2906,46 @@ Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs( } Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { + auto dot_dnums = dot_as_convolution_util::ParseDotGeneralFromConvolution(hlo); + if (dot_dnums) { + // Use HandleDotHelper() for convs that are actually einsums. + spmd::DotGeneralDimsMapping mapping; + for (const auto& dims : dot_dnums->batch_dims) { + mapping.batch_dims.emplace_back(); + mapping.batch_dims.back().lhs = dims.lhs; + mapping.batch_dims.back().rhs = dims.rhs; + mapping.batch_dims.back().output = dims.output; + } + for (const auto& dims : dot_dnums->contracting_dims) { + mapping.contracting_dims.emplace_back(); + mapping.contracting_dims.back().lhs = dims.lhs; + mapping.contracting_dims.back().rhs = dims.rhs; + mapping.contracting_dims.back().output = dims.output; + } + for (const auto& dims : dot_dnums->lhs_non_contracting_dims) { + mapping.lhs_non_contracting_dims.emplace_back(); + mapping.lhs_non_contracting_dims.back().lhs = dims.lhs; + mapping.lhs_non_contracting_dims.back().rhs = dims.rhs; + mapping.lhs_non_contracting_dims.back().output = dims.output; + } + for (const auto& dims : dot_dnums->rhs_non_contracting_dims) { + mapping.rhs_non_contracting_dims.emplace_back(); + mapping.rhs_non_contracting_dims.back().lhs = dims.lhs; + mapping.rhs_non_contracting_dims.back().rhs = dims.rhs; + mapping.rhs_non_contracting_dims.back().output = dims.output; + } + auto create_sharded_conv = + [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo, + spmd::SpmdBuilder* b) -> StatusOr<HloInstruction*> { + TF_ASSIGN_OR_RETURN( + auto sharded_conv, + dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution( + *hlo, *dot_dnums, lhs_hlo, rhs_hlo)); + return b->AddInstruction(std::move(sharded_conv)); + }; + return HandleDotHelper(hlo, mapping, create_sharded_conv); + } + auto lhs = GetPartitionedHlo(hlo->operand(0)); auto rhs = GetPartitionedHlo(hlo->operand(1)); const HloSharding& sharding = hlo->sharding(); @@ -4449,42 +4498,133 @@ Status SpmdPartitioningVisitor::HandlePartitionId(HloInstruction* hlo) { "the data is replicated, and if the latter which data is replicated."); } +SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions, + int64 num_replicas) { + return { + [](SpmdBuilder* b) { + return b->AddInstruction(HloInstruction::CreatePartitionId()); + }, + [num_replicas](SpmdBuilder* b, HloInstruction* operand, + HloComputation* reduction, int64 channel_id) { + return b->AddInstruction(HloInstruction::CreateAllReduce( + operand->shape(), {operand}, reduction, + CreateReplicaGroups(num_replicas), + /*constrain_layout=*/false, channel_id, + /*use_global_device_ids=*/false)); + }, + [](SpmdBuilder* b, HloInstruction* operand, + std::vector<std::pair<int64, int64>>& src_dst_pairs, + int64 channel_id) { + return b->AddInstruction(HloInstruction::CreateCollectivePermute( + operand->shape(), operand, src_dst_pairs, channel_id)); + }, + [](SpmdBuilder* b, absl::Span<HloInstruction* const> operands, + const std::vector<ReplicaGroup>& replica_groups, int64 channel_id, + absl::optional<int64> split_dimension) { + std::vector<Shape> shapes(operands.size(), operands[0]->shape()); + const Shape output_shape = (shapes.size() == 1) + ? shapes[0] + : ShapeUtil::MakeTupleShape(shapes); + return b->AddInstruction(HloInstruction::CreateAllToAll( + output_shape, operands, replica_groups, + /*constrain_layout=*/false, channel_id, split_dimension)); + }, + [num_replicas, num_partitions]( + SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape, + const std::vector<std::vector<int64>>& partition_subgroups, + int64 channel_id, int64 all_gather_dimension) { + std::vector<ReplicaGroup> device_groups; + device_groups.reserve(partition_subgroups.size() * num_replicas); + for (int64 i = 0; i < num_replicas; ++i) { + for (const auto& pgroup : partition_subgroups) { + device_groups.emplace_back(); + for (int64 pid : pgroup) { + device_groups.back().add_replica_ids(i * num_partitions + pid); + } + } + } + return b->AddInstruction(HloInstruction::CreateAllGather( + ag_shape, operand, all_gather_dimension, device_groups, + /*constrain_layout=*/false, channel_id, + /*use_global_device_ids=*/true)); + }, + }; +} + SpmdPartitioner::SpmdPartitioner(int64 num_partitions, int64 num_replicas, SpmdPartitionerOptions options) : SpmdPartitioner( num_partitions, num_replicas, std::move(options), - SPMDCollectiveOpsCreator{ - [](SpmdBuilder* b) { - return b->AddInstruction(HloInstruction::CreatePartitionId()); - }, - [num_replicas](SpmdBuilder* b, HloInstruction* operand, - HloComputation* reduction, int64 channel_id) { - return b->AddInstruction(HloInstruction::CreateAllReduce( - operand->shape(), {operand}, reduction, - CreateReplicaGroups(num_replicas), - /*constrain_layout=*/false, channel_id, - /*use_global_device_ids=*/false)); - }, - [](SpmdBuilder* b, HloInstruction* operand, - std::vector<std::pair<int64, int64>>& src_dst_pairs, - int64 channel_id) { - return b->AddInstruction( - HloInstruction::CreateCollectivePermute( - operand->shape(), operand, src_dst_pairs, channel_id)); - }, - [](SpmdBuilder* b, absl::Span<HloInstruction* const> operands, - const std::vector<ReplicaGroup>& replica_groups, - int64 channel_id, absl::optional<int64> split_dimension) { - std::vector<Shape> shapes(operands.size(), - operands[0]->shape()); - const Shape output_shape = - (shapes.size() == 1) ? shapes[0] - : ShapeUtil::MakeTupleShape(shapes); - return b->AddInstruction(HloInstruction::CreateAllToAll( - output_shape, operands, replica_groups, - /*constrain_layout=*/false, channel_id, split_dimension)); - }, - }) {} + GetDefaultCollectiveOpsCreator(num_partitions, num_replicas)) {} + +HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b, + HloInstruction* operand, + const HloSharding& sharding, + int64 channel_id) { + CHECK(!sharding.IsTileMaximal()); + // Add one leading dimension to gather all partitions. + std::vector<int64> shape; + shape.push_back(1); + for (int64 dim : operand->shape().dimensions()) { + shape.push_back(dim); + } + auto reshape = b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand)); + std::vector<std::vector<int64>> partition_subgroups(1); + for (int64 pid : sharding.tile_assignment()) { + partition_subgroups[0].push_back(pid); + } + shape[0] = sharding.tile_assignment().num_elements(); + auto result = collective_ops_creator_.create_cross_partition_all_gather( + b, reshape, ShapeUtil::MakeShape(operand->shape().element_type(), shape), + partition_subgroups, channel_id, /*all_gather_dimension=*/0); + // If n > 1 dimensions are partitioned, split the leading dimension to n. + std::vector<int64> tiled_dims; + for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + if (sharding.tile_assignment().dim(i) > 1) { + tiled_dims.push_back(i); + } + } + if (tiled_dims.size() > 1) { + std::vector<int64> split_dim_shape; + split_dim_shape.reserve(tiled_dims.size() + operand->shape().rank()); + for (int64 i : tiled_dims) { + split_dim_shape.push_back(sharding.tile_assignment().dim(i)); + } + for (int64 dim : operand->shape().dimensions()) { + split_dim_shape.push_back(dim); + } + result = b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(operand->shape().element_type(), split_dim_shape), + result)); + } + // Transpose the gathered dimensions to next to their corresponding + // partitioned dimensions. + std::vector<int64> xpose_permutation(result->shape().rank()); + int64 split_dims_added = 0; + for (int64 i = 0; i < xpose_permutation.size(); ++i) { + if (sharding.tile_assignment().dim(i - split_dims_added) == 1) { + xpose_permutation[i] = i + tiled_dims.size() - split_dims_added; + } else { + xpose_permutation[i] = split_dims_added; + split_dims_added++; + xpose_permutation[i + 1] = i + tiled_dims.size(); + i++; + } + } + result = b->AddInstruction(HloInstruction::CreateTranspose( + ShapeInference::InferTransposeShape(result->shape(), xpose_permutation) + .ValueOrDie(), + result, xpose_permutation)); + // Reshape to the desired shape. + auto ag_shape = operand->shape(); + for (int64 i : tiled_dims) { + ag_shape.set_dimensions( + i, ag_shape.dimensions(i) * sharding.tile_assignment().dim(i)); + } + result = b->AddInstruction(HloInstruction::CreateReshape(ag_shape, result)); + return result; +} StatusOr<bool> SpmdPartitioner::PartitionComputation( HloComputation* computation, const HloSharding& root_sharding, diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h index f22f564be73..2918cd1ef58 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h @@ -99,8 +99,20 @@ struct SPMDCollectiveOpsCreator { const std::vector<ReplicaGroup>& replica_groups, int64 channel_id, absl::optional<int64> split_dimension)> create_cross_partition_all_to_all; + + // Function used to create a cross-partition all-gather HLO. This is optional: + // if it is nullptr, the partitioner will use all-reduce instead. + std::function<HloInstruction*( + SpmdBuilder*, HloInstruction* operand, const Shape& ag_shape, + const std::vector<std::vector<int64>>& partition_subgroups, + int64 channel_id, int64 all_gather_dimension)> + create_cross_partition_all_gather; }; +// Create a default SPMDCollectiveOpsCreator. +SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions, + int64 num_replicas); + // Logger to report memory usage during SPMD partitioning. class SpmdLogger { public: @@ -153,6 +165,15 @@ class SpmdPartitioner : public HloModulePass { int64* next_channel_id, SpmdLogger* logger); + // Creates all-gather based on HloSharding. Can be overridden to customize. + // The default uses a single all-gather even if there are multiple sharded + // dimensions, and adds potential reshapes and transposes to achieve that. + // If it returns false, the partitioner will fall back to all-reduce. + virtual HloInstruction* AllGatherShards(SpmdBuilder* b, + HloInstruction* operand, + const HloSharding& sharding, + int64 channel_id); + protected: virtual std::unique_ptr<SpmdPartitioningVisitor> CreateVisitor( HloComputation* computation, int64 num_partitions, int64 num_replicas, @@ -160,7 +181,6 @@ class SpmdPartitioner : public HloModulePass { int64* next_channel_id, SpmdLogger* logger, SpmdPartitionerOptions options); - private: // Verify that the sharding of instructions in the module are valid, and also // fill in missing sharding information. Status PreprocessSharding(HloModule* module); @@ -205,6 +225,7 @@ class PartitionedHlo { SPMDCollectiveOpsCreator collective_ops_creator; int64* next_channel_id; ReshardCache* reshard_cache; + SpmdPartitioner* partitioner; }; PartitionedHlo(HloInstruction* hlo, Shape base_shape, PartitioningState state) : hlo_(hlo), base_shape_(base_shape), state_(std::move(state)) { @@ -378,6 +399,7 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { state.collective_ops_creator = collective_ops_creator_; state.next_channel_id = next_channel_id_; state.reshard_cache = &reshard_cache_; + state.partitioner = partitioner_; return state; } diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index ca1afc816b0..55d7dc43785 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -41,13 +41,19 @@ class SpmdPartitioningTest : public HloTestBase { SpmdPartitionerOptions options; options.conv_halo_exchange_always_on_lhs = conv_halo_exchange_always_on_lhs; options.allow_module_signature_change = true; + auto collective_ops_creator = + GetDefaultCollectiveOpsCreator(num_devices, /*num_replicas=*/1); + // Do not use all-gather for pattern-matching purpose, as the partitioner + // might create reshape/transposes around it. + collective_ops_creator.create_cross_partition_all_gather = nullptr; TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule( hlo_module, GetModuleConfigForTest())); HloPassPipeline pass("spmd-partitioning"); pass.AddPass<HloVerifier>(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - pass.AddPass<SpmdPartitioner>(num_devices, /*num_replicas=*/1, options); + pass.AddPass<SpmdPartitioner>(num_devices, /*num_replicas=*/1, options, + collective_ops_creator); pass.AddPass<HloVerifier>(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); TF_RETURN_IF_ERROR(pass.Run(module.get()).status()); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 5b83186ffa4..790497f888e 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -76,6 +76,7 @@ class ClientLibraryTestBase : public ::testing::Test { void SetFastMathDisabled(bool disabled) { auto* opts = execution_options_.mutable_debug_options(); opts->set_xla_cpu_enable_fast_math(!disabled); + opts->set_xla_cpu_enable_fast_min_max(!disabled); opts->set_xla_gpu_enable_fast_min_max(!disabled); } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 8eed609a134..7b64be5597b 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -165,6 +165,16 @@ PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) { return precision_config; } +void HloTestBase::SetAotFastMathDebugOptions(DebugOptions* options) { + options->set_xla_cpu_enable_fast_math(true); + options->set_xla_gpu_enable_fast_min_max(true); + options->set_xla_cpu_enable_fast_min_max(true); + options->set_xla_cpu_fast_math_honor_nans(false); + options->set_xla_cpu_fast_math_honor_infs(false); + options->set_xla_cpu_fast_math_honor_functions(false); + options->set_xla_cpu_fast_math_honor_division(false); +} + DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index d05776a0cb9..85b1876dd3c 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -100,6 +100,10 @@ class HloTestBase : public ::testing::Test { static PrecisionConfig DefaultPrecisionConfig(int operands); + // Sets most fath math options to be enabled to model the fast math flags + // generally used for CPU:AOT compilation. + static void SetAotFastMathDebugOptions(DebugOptions* options); + protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. If the diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 3407a68f709..40e226f9902 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -310,8 +310,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { XLA_TEST_F(VecOpsSimpleTest, ClampFloatEdgeCases) { XlaBuilder builder(TestName()); - mutable_debug_options()->set_xla_cpu_enable_fast_math(false); - mutable_debug_options()->set_xla_gpu_enable_fast_min_max(false); + SetFastMathDisabled(true); auto low = ConstantR1<float>(&builder, {NAN, 1, 1}); auto high = ConstantR1<float>(&builder, {3, NAN, 3}); auto x = ConstantR1<float>(&builder, {2, 2, NAN}); diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index f4b08f454b9..9374b1fca6a 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -148,9 +148,20 @@ message DebugOptions { // xla_cpu_enable_fast_math is false. bool xla_cpu_fast_math_honor_functions = 129; + // When false we lower the Minimum and Maximum hlos in the CPU backend such + // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN. In other words, if flag + // this is false we always propagate NaNs through Min and Max. + // + // Note, this does not correspond to the exact same behavior as the gpu flag + // below! + bool xla_cpu_enable_fast_min_max = 140; + // When true we lower the Minimum and Maximum hlos in the GPU backend such // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if flag // this is true we don't propagate NaNs through Min and Max. + // + // Note, this does not correspond to the exact same behavior as the cpu flag + // above! bool xla_gpu_enable_fast_min_max = 100; // Allows xla to increase the output precision of floating point operations. @@ -280,7 +291,7 @@ message DebugOptions { // memory, or have bugs. bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_error = 139; - // Next id: 140 + // Next id: 141 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index fbf9dfd0a17..67647cc4285 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -62,6 +62,20 @@ xla::XlaComputation ReturnDynamicR1() { return builder.Build(pad_sum).ValueOrDie(); } +xla::XlaComputation ReturnDynamicR2() { + xla::XlaBuilder builder("ReturnDynamicR2"); + auto p0 = xla::Parameter(&builder, 0, + xla::ShapeUtil::MakeShape(xla::F32, {2, 4}), "P0"); + auto p1 = xla::Parameter(&builder, 1, + xla::ShapeUtil::MakeShape(xla::F32, {2, 4}), "P1"); + auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}), + "P2"); + auto sum = xla::Add(p0, p1); + auto pad_sum_dim0 = xla::SetDimensionSize(sum, p2, 0); + auto pad_sum_dim1 = xla::SetDimensionSize(pad_sum_dim0, p2, 1); + return builder.Build(pad_sum_dim1).ValueOrDie(); +} + xla::XlaComputation AcceptDynamicR1() { xla::XlaBuilder builder("AcceptDynamicR1"); xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); @@ -72,6 +86,16 @@ xla::XlaComputation AcceptDynamicR1() { return builder.Build(sum).ValueOrDie(); } +xla::XlaComputation AcceptDynamicR2() { + xla::XlaBuilder builder("AcceptDynamicR2"); + xla::Shape dyn_shape; + dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4}); + dyn_shape.set_dynamic_dimension(1, true); + auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0"); + auto negate = xla::Neg(p0); + return builder.Build(negate).ValueOrDie(); +} + xla::XlaComputation ReturnDynamicR1Tuple() { xla::XlaBuilder builder("ReturnDynamicR1Tuple"); auto p0 = xla::Parameter(&builder, 0, @@ -1103,7 +1127,8 @@ TEST(RawApiTest, CompileAndExecute) { TEST(RawApiTest, DynamicR1Test) { if (!SupportDynamicShapes()) { - return; + GTEST_SKIP() + << "Skipping the test if backend doesn't support dynamic shapes"; } xrt::XLAAllocation p0; *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f}); @@ -1156,9 +1181,71 @@ TEST(RawApiTest, DynamicR1Test) { EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } +TEST(RawApiTest, DynamicR2Test) { + if (!SupportDynamicShapes()) { + GTEST_SKIP() + << "Skipping the test if backend doesn't support dynamic shapes"; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, 2.0f, 0.5f, -1.0f}, + {1.5f, 2.5f, 3.0f, -2.0f}}) + .ToProto(); + xrt::XLAAllocation p1; + *p1.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, -1.0f, 2.5f, 1.17f}, + {1.2f, -1.6f, 2.8f, 1.24f}}) + .ToProto(); + xrt::XLAAllocation p2; + *p2.mutable_value() = CreateR0<xla::int32>(2); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2, 4}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {2, 4}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4}); + dyn_shape.set_dynamic_dimension(0, true); + dyn_shape.set_dynamic_dimension(1, true); + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(ReturnDynamicR2(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto p2_value = ops::Const(cpu_root, p2.SerializeAsString()); + auto p2_handle = ops::XRTAllocate(root, p2_value); + auto result = ops::XRTExecute( + root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle), Output(p2_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector<Tensor> outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()())); + auto expected = xla::LiteralUtil::CreateR2<float>({{2.0f, 1.0f}, {2.7, 0.9}}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + TEST(RawApiTest, DynamicR1TupleTest) { if (!SupportDynamicShapes()) { - return; + GTEST_SKIP() + << "Skipping the test if backend doesn't support dynamic shapes"; } xrt::XLAAllocation p0; *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f}); @@ -1221,7 +1308,8 @@ TEST(RawApiTest, DynamicR1TupleTest) { TEST(RawApiTest, AcceptDynamicR1TupleTest) { if (!SupportDynamicShapes()) { - return; + GTEST_SKIP() + << "Skipping the test if backend doesn't support dynamic shapes"; } xrt::XLAAllocation p0; *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f}); @@ -1286,7 +1374,8 @@ TEST(RawApiTest, AcceptDynamicR1TupleTest) { TEST(RawApiTest, AcceptDynamicR1Test) { if (!SupportDynamicShapes()) { - return; + GTEST_SKIP() + << "Skipping the test if backend doesn't support dynamic shapes"; } xrt::XLAAllocation p0; *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f}); @@ -1334,6 +1423,55 @@ TEST(RawApiTest, AcceptDynamicR1Test) { EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } +TEST(RawApiTest, AcceptDynamicR2Test) { + if (!SupportDynamicShapes()) { + GTEST_SKIP() + << "Skipping the test if backend doesn't support dynamic shapes"; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = + xla::LiteralUtil::CreateR2({{-1.0f, 3.0f, 1.0f}, {-2.0f, -1.0f, 3.0f}}) + .ToProto(); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + // Compile time expects ascending layout. + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4}); + dyn_shape.set_dynamic_dimension(1, true); + *shapes->add_parameters() = dyn_shape.ToProto(); + + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(AcceptDynamicR2(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto result = + ops::XRTExecute(root, c_handle.handle, e_config, {Output(p0_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector<Tensor> outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()())); + + auto expected = xla::LiteralUtil::CreateR2<float>( + {{1.0f, -3.0f, -1.0f}, {2.0f, 1.0f, -3.0f}}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { xrt::XLAAllocation p0; *p0.mutable_value() = FloatVector({1.0f, 2.0f}); diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 6b4874a8393..2b16801f6ed 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2254,6 +2254,7 @@ tf_cuda_library( "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/profiler/lib:annotated_traceme", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/tpu:tpu_library_loader", "//tensorflow/core/util:einsum_op_util", "//tensorflow/core/util:padding", "//tensorflow/core/util:port", diff --git a/tensorflow/core/api_def/base_api/api_def_CompressElement.pbtxt b/tensorflow/core/api_def/base_api/api_def_CompressElement.pbtxt new file mode 100644 index 00000000000..17b63e4ab2f --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_CompressElement.pbtxt @@ -0,0 +1,5 @@ +op { + graph_op_name: "CompressElement" + visibility: HIDDEN + summary: "Compresses a dataset element." +} diff --git a/tensorflow/core/api_def/base_api/api_def_ExtractGlimpseV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExtractGlimpseV2.pbtxt new file mode 100644 index 00000000000..aeb87346ab2 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExtractGlimpseV2.pbtxt @@ -0,0 +1,86 @@ +op { + graph_op_name: "ExtractGlimpseV2" + visibility: HIDDEN + in_arg { + name: "input" + description: <<END +A 4-D float tensor of shape `[batch_size, height, width, channels]`. +END + } + in_arg { + name: "size" + description: <<END +A 1-D tensor of 2 elements containing the size of the glimpses +to extract. The glimpse height must be specified first, following +by the glimpse width. +END + } + in_arg { + name: "offsets" + description: <<END +A 2-D integer tensor of shape `[batch_size, 2]` containing +the y, x locations of the center of each window. +END + } + out_arg { + name: "glimpse" + description: <<END +A tensor representing the glimpses `[batch_size, +glimpse_height, glimpse_width, channels]`. +END + } + attr { + name: "centered" + description: <<END +indicates if the offset coordinates are centered relative to +the image, in which case the (0, 0) offset is relative to the center +of the input images. If false, the (0,0) offset corresponds to the +upper left corner of the input images. +END + } + attr { + name: "normalized" + description: <<END +indicates if the offset coordinates are normalized. +END + } + attr { + name: "uniform_noise" + description: <<END +indicates if the noise should be generated using a +uniform distribution or a Gaussian distribution. +END + } + attr { + name: "noise" + description: <<END +indicates if the noise should `uniform`, `gaussian`, or +`zero`. The default is `uniform` which means the the noise type +will be decided by `uniform_noise`. +END + } + summary: "Extracts a glimpse from the input tensor." + description: <<END +Returns a set of windows called glimpses extracted at location +`offsets` from the input tensor. If the windows only partially +overlaps the inputs, the non overlapping areas will be filled with +random noise. + +The result is a 4-D tensor of shape `[batch_size, glimpse_height, +glimpse_width, channels]`. The channels and batch dimensions are the +same as that of the input tensor. The height and width of the output +windows are specified in the `size` parameter. + +The argument `normalized` and `centered` controls how the windows are built: + +* If the coordinates are normalized but not centered, 0.0 and 1.0 + correspond to the minimum and maximum of each height and width + dimension. +* If the coordinates are both normalized and centered, they range from + -1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper + left corner, the lower right corner is located at (1.0, 1.0) and the + center is at (0, 0). +* If the coordinates are not normalized they are interpreted as + numbers of pixels. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_SparseCrossHashed.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseCrossHashed.pbtxt new file mode 100644 index 00000000000..2c4340cb9b7 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_SparseCrossHashed.pbtxt @@ -0,0 +1,104 @@ +op { + graph_op_name: "SparseCrossHashed" + in_arg { + name: "indices" + description: <<END +2-D. Indices of each input `SparseTensor`. +END + } + in_arg { + name: "values" + description: <<END +1-D. values of each `SparseTensor`. +END + } + in_arg { + name: "shapes" + description: <<END +1-D. Shapes of each `SparseTensor`. +END + } + in_arg { + name: "dense_inputs" + description: <<END +2-D. Columns represented by dense `Tensor`. +END + } + in_arg { + name: "num_buckets" + description: <<END +It is used if hashed_output is true. +output = hashed_value%num_buckets if num_buckets > 0 else hashed_value. +END + } + in_arg { + name: "strong_hash" + description: <<END +boolean, if true, siphash with salt will be used instead of farmhash. +END + } + in_arg { + name: "salt" + description: <<END +Specify the salt that will be used by the siphash function. +END + } + out_arg { + name: "output_indices" + description: <<END +2-D. Indices of the concatenated `SparseTensor`. +END + } + out_arg { + name: "output_values" + description: <<END +1-D. Non-empty values of the concatenated or hashed +`SparseTensor`. +END + } + out_arg { + name: "output_shape" + description: <<END +1-D. Shape of the concatenated `SparseTensor`. +END + } + summary: "Generates sparse cross from a list of sparse and dense tensors." + description: <<END +The op takes two lists, one of 2D `SparseTensor` and one of 2D `Tensor`, each +representing features of one feature column. It outputs a 2D `SparseTensor` with +the batchwise crosses of these features. + +For example, if the inputs are + + inputs[0]: SparseTensor with shape = [2, 2] + [0, 0]: "a" + [1, 0]: "b" + [1, 1]: "c" + + inputs[1]: SparseTensor with shape = [2, 1] + [0, 0]: "d" + [1, 0]: "e" + + inputs[2]: Tensor [["f"], ["g"]] + +then the output will be + + shape = [2, 2] + [0, 0]: "a_X_d_X_f" + [1, 0]: "b_X_e_X_g" + [1, 1]: "c_X_e_X_g" + +if hashed_output=true then the output will be + + shape = [2, 2] + [0, 0]: FingerprintCat64( + Fingerprint64("f"), FingerprintCat64( + Fingerprint64("d"), Fingerprint64("a"))) + [1, 0]: FingerprintCat64( + Fingerprint64("g"), FingerprintCat64( + Fingerprint64("e"), Fingerprint64("b"))) + [1, 1]: FingerprintCat64( + Fingerprint64("g"), FingerprintCat64( + Fingerprint64("e"), Fingerprint64("c"))) +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_SparseCrossV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseCrossV2.pbtxt new file mode 100644 index 00000000000..0627d9b3909 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_SparseCrossV2.pbtxt @@ -0,0 +1,91 @@ +op { + graph_op_name: "SparseCrossV2" + in_arg { + name: "indices" + description: <<END +2-D. Indices of each input `SparseTensor`. +END + } + in_arg { + name: "values" + description: <<END +1-D. values of each `SparseTensor`. +END + } + in_arg { + name: "shapes" + description: <<END +1-D. Shapes of each `SparseTensor`. +END + } + in_arg { + name: "dense_inputs" + description: <<END +2-D. Columns represented by dense `Tensor`. +END + } + in_arg { + name: "sep" + description: <<END +string used when joining a list of string inputs, can be used as separator later. +END + } + out_arg { + name: "output_indices" + description: <<END +2-D. Indices of the concatenated `SparseTensor`. +END + } + out_arg { + name: "output_values" + description: <<END +1-D. Non-empty values of the concatenated or hashed +`SparseTensor`. +END + } + out_arg { + name: "output_shape" + description: <<END +1-D. Shape of the concatenated `SparseTensor`. +END + } + summary: "Generates sparse cross from a list of sparse and dense tensors." + description: <<END +The op takes two lists, one of 2D `SparseTensor` and one of 2D `Tensor`, each +representing features of one feature column. It outputs a 2D `SparseTensor` with +the batchwise crosses of these features. + +For example, if the inputs are + + inputs[0]: SparseTensor with shape = [2, 2] + [0, 0]: "a" + [1, 0]: "b" + [1, 1]: "c" + + inputs[1]: SparseTensor with shape = [2, 1] + [0, 0]: "d" + [1, 0]: "e" + + inputs[2]: Tensor [["f"], ["g"]] + +then the output will be + + shape = [2, 2] + [0, 0]: "a_X_d_X_f" + [1, 0]: "b_X_e_X_g" + [1, 1]: "c_X_e_X_g" + +if hashed_output=true then the output will be + + shape = [2, 2] + [0, 0]: FingerprintCat64( + Fingerprint64("f"), FingerprintCat64( + Fingerprint64("d"), Fingerprint64("a"))) + [1, 0]: FingerprintCat64( + Fingerprint64("g"), FingerprintCat64( + Fingerprint64("e"), Fingerprint64("b"))) + [1, 1]: FingerprintCat64( + Fingerprint64("g"), FingerprintCat64( + Fingerprint64("e"), Fingerprint64("c"))) +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_UncompressElement.pbtxt b/tensorflow/core/api_def/base_api/api_def_UncompressElement.pbtxt new file mode 100644 index 00000000000..e2039b674f0 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_UncompressElement.pbtxt @@ -0,0 +1,5 @@ +op { + graph_op_name: "UncompressElement" + visibility: HIDDEN + summary: "Uncompresses a compressed dataset element." +} diff --git a/tensorflow/core/api_def/python_api/api_def_SparseCrossHashed.pbtxt b/tensorflow/core/api_def/python_api/api_def_SparseCrossHashed.pbtxt new file mode 100644 index 00000000000..2c830668733 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_SparseCrossHashed.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "SparseCrossHashed" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_SparseCrossV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_SparseCrossV2.pbtxt new file mode 100644 index 00000000000..dfa0a670c4c --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_SparseCrossV2.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "SparseCrossV2" + visibility: HIDDEN +} diff --git a/tensorflow/core/common_runtime/composite_device.cc b/tensorflow/core/common_runtime/composite_device.cc index 3103fa37941..7fd41e00a04 100644 --- a/tensorflow/core/common_runtime/composite_device.cc +++ b/tensorflow/core/common_runtime/composite_device.cc @@ -24,7 +24,7 @@ const char* const kCompositeDeviceType = "COMPOSITE"; std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice( const std::vector<string>& underlying_devices, const int unique_device_id, - Status* status) { + const DeviceNameUtils::ParsedName& host_name, Status* status) { if (underlying_devices.empty()) { status->Update( errors::InvalidArgument("underlying_devices should not be empty.")); @@ -62,13 +62,15 @@ std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice( return nullptr; } } + + DeviceNameUtils::ParsedName parsed_composite_name = host_name; DeviceAttributes device_attributes; - parsed_name.type = kCompositeDeviceType; - device_attributes.set_device_type(parsed_name.type); - parsed_name.id = unique_device_id; + parsed_composite_name.type = kCompositeDeviceType; + parsed_composite_name.id = unique_device_id; const string composite_name = - DeviceNameUtils::ParsedNameToString(parsed_name); + DeviceNameUtils::ParsedNameToString(parsed_composite_name); device_attributes.set_name(composite_name); + device_attributes.set_device_type(kCompositeDeviceType); return absl::WrapUnique( new CompositeDevice(device_attributes, underlying_devices)); diff --git a/tensorflow/core/common_runtime/composite_device.h b/tensorflow/core/common_runtime/composite_device.h index 127e5b8303a..850eae55e8d 100644 --- a/tensorflow/core/common_runtime/composite_device.h +++ b/tensorflow/core/common_runtime/composite_device.h @@ -42,10 +42,11 @@ class CompositeDevice : public Device { return &underlying_devices_; } - // Helper for creating a CompositeDevice. + // Helper for creating a CompositeDevice on the same task as the given host + // CPU. static std::unique_ptr<CompositeDevice> MakeDevice( const std::vector<string>& underlying_devices, const int unique_device_id, - Status* status); + const DeviceNameUtils::ParsedName& host_name, Status* status); private: CompositeDevice(const DeviceAttributes& device_attributes, diff --git a/tensorflow/core/common_runtime/composite_device_test.cc b/tensorflow/core/common_runtime/composite_device_test.cc index ac2f9108ecb..73a6ae44912 100644 --- a/tensorflow/core/common_runtime/composite_device_test.cc +++ b/tensorflow/core/common_runtime/composite_device_test.cc @@ -20,12 +20,15 @@ limitations under the License. namespace tensorflow { TEST(CompositeDeviceTest, Basic) { + const string host_name = "/job:localhost/replica:0/task:0/device:CPU:0"; + DeviceNameUtils::ParsedName parsed_host_name; + EXPECT_TRUE(DeviceNameUtils::ParseFullName(host_name, &parsed_host_name)); std::vector<string> underlying_devices; { Status status; std::unique_ptr<CompositeDevice> composite_device = CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/0, - &status); + parsed_host_name, &status); EXPECT_EQ(composite_device, nullptr); EXPECT_EQ(error::INVALID_ARGUMENT, status.code()); EXPECT_TRUE(absl::StrContains(status.error_message(), @@ -41,7 +44,7 @@ TEST(CompositeDeviceTest, Basic) { "/job:localhost/replica:0/task:0/device:CPU:1"); std::unique_ptr<CompositeDevice> composite_device = CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/0, - &status); + parsed_host_name, &status); TF_ASSERT_OK(status); EXPECT_EQ(composite_device->device_type(), kCompositeDeviceType); EXPECT_EQ(underlying_devices, *composite_device->underlying_devices()); @@ -53,7 +56,7 @@ TEST(CompositeDeviceTest, Basic) { "/job:localhost/replica:0/task:0/device:CPU:0"); std::unique_ptr<CompositeDevice> composite_device = CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/1, - &status); + parsed_host_name, &status); EXPECT_EQ(composite_device, nullptr); EXPECT_EQ(error::INVALID_ARGUMENT, status.code()); EXPECT_TRUE( @@ -68,7 +71,7 @@ TEST(CompositeDeviceTest, Basic) { "/job:localhost/replica:0/task:0/device:GPU:0"); std::unique_ptr<CompositeDevice> composite_device = CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/1, - &status); + parsed_host_name, &status); EXPECT_EQ(composite_device, nullptr); EXPECT_EQ(error::INVALID_ARGUMENT, status.code()); EXPECT_TRUE(absl::StrContains(status.error_message(), diff --git a/tensorflow/core/common_runtime/device_set.cc b/tensorflow/core/common_runtime/device_set.cc index b062529a3ff..902ca2c2ee2 100644 --- a/tensorflow/core/common_runtime/device_set.cc +++ b/tensorflow/core/common_runtime/device_set.cc @@ -116,12 +116,15 @@ void DeviceSet::SortPrioritizedDeviceVector(PrioritizedDeviceVector* vector) { if (a_type_name != b_type_name) { auto a_priority = DeviceFactory::DevicePriority(a_type_name); auto b_priority = DeviceFactory::DevicePriority(b_type_name); - // First sort by prioritized device type (higher is preferred) and - // then by device name (lexicographically). if (a_priority != b_priority) { return a_priority > b_priority; } } + + if (a.first->IsLocal() != b.first->IsLocal()) { + return a.first->IsLocal(); + } + return StringPiece(a.first->name()) < StringPiece(b.first->name()); }; std::sort(vector->begin(), vector->end(), device_sort); diff --git a/tensorflow/core/common_runtime/device_set.h b/tensorflow/core/common_runtime/device_set.h index 608705c32f7..f59f84c2066 100644 --- a/tensorflow/core/common_runtime/device_set.h +++ b/tensorflow/core/common_runtime/device_set.h @@ -90,8 +90,8 @@ class DeviceSet { // // After a call to this function, the argument vector will be sorted by // explicit priority (the second element in the `std::pair<DeviceType, - // int32>`), then by `DeviceTypeOrder` of the device type, and lastly - // by device name. + // int32>`), then by `DeviceTypeOrder` of the device type, then by device + // locality, and lastly by device name. static void SortPrioritizedDeviceVector(PrioritizedDeviceVector* vector); // Sorts a PrioritizedDeviceTypeVector according to types and explicit diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index d104e0a985f..96938bcbafd 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -349,12 +349,12 @@ DirectSession::DirectSession(const SessionOptions& options, int devices_added = 0; if (options.config.log_device_placement()) { const string mapping_str = device_mgr_->DeviceMappingString(); + string msg; if (mapping_str.empty()) { - printf("Device mapping: no known devices.\n"); + msg = "Device mapping: no known devices."; } else { - printf("Device mapping:\n%s", mapping_str.c_str()); + msg = strings::StrCat("Device mapping:\n", mapping_str); } - string msg = strings::StrCat("Device mapping:\n", mapping_str); if (!logging::LogToListeners(msg)) { LOG(INFO) << msg; } diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index b8dfe92aac6..1024f3caabd 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -81,7 +81,8 @@ EagerContext::EagerContext( bool device_mgr_owned, Rendezvous* rendezvous, const CustomKernelCreator* custom_kernel_creator, DistributedFunctionLibraryRuntime* cluster_flr) - : default_device_placement_policy_(default_device_placement_policy), + : opts_(opts), + default_device_placement_policy_(default_device_placement_policy), default_mirroring_policy_(default_mirroring_policy), local_device_manager_(device_mgr, device_mgr_owned), host_cpu_device_(device_mgr->HostCPU()), @@ -935,8 +936,11 @@ Status EagerContext::FindOrCreateCompositeDevice( } Status s; - auto device = CompositeDevice::MakeDevice(underlying_devices, - composite_devices_.size(), &s); + // Create a CompositeDevice on the same task as the host CPU, in order to + // trigger packed TensorHandle copy from a client to a remote worker. + auto device = + CompositeDevice::MakeDevice(underlying_devices, composite_devices_.size(), + HostCPU()->parsed_name(), &s); TF_RETURN_IF_ERROR(s); *composite_device = device.get(); pflr_->AddCompositeDevice(*composite_device); @@ -1048,7 +1052,7 @@ void EagerContext::IncrementContextViewId() { // Set collective ops related state in the context. Passing nullptr to // `new_server` will reuse the existing GRPC server in context. Status EagerContext::StoreCollectiveOpsServer( - std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr, + std::unique_ptr<ServerInterface> new_server, const DeviceMgr* device_mgr, CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) { collective_executor_mgr_.Reset(rpc_collective_executor_mgr); @@ -1173,7 +1177,7 @@ Status EagerContext::InitializeRemoteMaster( std::unique_ptr<eager::EagerClientCache> remote_eager_workers, std::unique_ptr<DynamicDeviceMgr> remote_device_manager, const std::vector<string>& remote_contexts, uint64 context_id, - Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs, + Rendezvous* r, const DeviceMgr* local_device_mgr, int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr, std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>> remote_mgr) { @@ -1272,7 +1276,7 @@ Status EagerContext::SetMasterContextState( std::shared_ptr<WorkerSession> worker_session, std::unique_ptr<eager::EagerClientCache> remote_eager_workers, std::unique_ptr<DynamicDeviceMgr> remote_device_manager, uint64 context_id, - uint64 context_view_id, Rendezvous* r, DeviceMgr* local_device_mgr, + uint64 context_view_id, Rendezvous* r, const DeviceMgr* local_device_mgr, int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr, std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>> remote_mgr) { @@ -1284,7 +1288,13 @@ Status EagerContext::SetMasterContextState( use_send_tensor_rpc_ = ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true); - local_device_manager_.Reset(local_device_mgr); + if (local_device_mgr != local_device_manager_.Get()) { + if (local_device_manager_.Owned()) { + old_local_device_managers_.push_back( + std::move(local_device_manager_.owned_object)); + } + local_device_manager_.Reset(local_device_mgr); + } host_cpu_device_ = local_device_manager_.Get()->HostCPU(); if (rendezvous_ != nullptr) rendezvous_->Unref(); diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index d034aaf2f9c..cceb883a965 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -295,6 +295,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { // errors, and the error message will be combined from all executors. Status SyncExecutors(); + Status AsyncWait() override { return SyncExecutors(); } + core::RefCountPtr<KernelAndDevice> GetCachedKernel(Fprint128 cache_key); void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel); @@ -397,7 +399,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { std::unique_ptr<eager::EagerClientCache> remote_eager_workers, std::unique_ptr<DynamicDeviceMgr> remote_device_manager, const std::vector<string>& remote_contexts, uint64 context_id, - Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs, + Rendezvous* r, const DeviceMgr* local_device_mgr, int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr, std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>> remote_mgr); @@ -434,7 +436,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { const std::vector<string>& remote_contexts, uint64 context_id); Status StoreCollectiveOpsServer( - std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr, + std::unique_ptr<ServerInterface> new_server, const DeviceMgr* device_mgr, CollectiveExecutorMgrInterface* rpc_collective_executor_mgr); // For the specified remote worker, preprocess and set its device filters. @@ -508,6 +510,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { // Gets the CPU device on the task of device. Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const; + const SessionOptions& session_options() const { return opts_; } + private: ~EagerContext() override; @@ -561,6 +565,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { T* unowned_object_ptr = nullptr; }; + SessionOptions opts_; const ContextDevicePlacementPolicy default_device_placement_policy_; const ContextMirroringPolicy default_mirroring_policy_; @@ -573,6 +578,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { TF_GUARDED_BY(policy_map_mu_); OwnedOrUnownedHelper<const DeviceMgr> local_device_manager_; + // Maintain copy of all previously created local device managers. + std::vector<std::unique_ptr<const DeviceMgr>> old_local_device_managers_; // Unowned DynamicDeviceMgr is set on remote worker to allow running // multi-device function on remote worker. @@ -660,7 +667,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { std::unique_ptr<eager::EagerClientCache> remote_eager_workers, std::unique_ptr<DynamicDeviceMgr> remote_device_manager, uint64 context_id, uint64 context_view_id, Rendezvous* r, - DeviceMgr* local_device_mgr, int keep_alive_secs, + const DeviceMgr* local_device_mgr, int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr, std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>> remote_mgr); diff --git a/tensorflow/core/common_runtime/eager/context_test.cc b/tensorflow/core/common_runtime/eager/context_test.cc index f83e3f0b45d..c6ed61c80c4 100644 --- a/tensorflow/core/common_runtime/eager/context_test.cc +++ b/tensorflow/core/common_runtime/eager/context_test.cc @@ -31,7 +31,7 @@ static Device* CreateDevice(const string& type, int n) { Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; DeviceAttributes attr; - attr.set_name("/job:a/replica:0/task:0/device:" + type + ":" + + attr.set_name("/job:localhost/replica:0/task:0/device:" + type + ":" + std::to_string(n)); attr.set_device_type(type); return new FakeDevice(attr); @@ -179,10 +179,10 @@ TEST_F(EagerContextTest, CompositeDevice) { TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices, &composite_device_0)); EXPECT_EQ(composite_device_0->name(), - "/job:worker/replica:0/task:0/device:COMPOSITE:0"); + "/job:localhost/replica:0/task:0/device:COMPOSITE:0"); CompositeDevice* device = nullptr; TF_EXPECT_OK(context()->FindCompositeDeviceFromName( - "/job:worker/replica:0/task:0/device:COMPOSITE:0", &device)); + "/job:localhost/replica:0/task:0/device:COMPOSITE:0", &device)); EXPECT_EQ(device, composite_device_0); CompositeDevice* composite_device_1 = nullptr; TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices, @@ -193,13 +193,13 @@ TEST_F(EagerContextTest, CompositeDevice) { TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices, &composite_device_2)); EXPECT_EQ(composite_device_2->name(), - "/job:worker/replica:0/task:0/device:COMPOSITE:1"); + "/job:localhost/replica:0/task:0/device:COMPOSITE:1"); TF_EXPECT_OK(context()->FindCompositeDeviceFromName( - "/job:worker/replica:0/task:0/device:COMPOSITE:1", &device)); + "/job:localhost/replica:0/task:0/device:COMPOSITE:1", &device)); EXPECT_EQ(device, composite_device_2); EXPECT_TRUE(errors::IsNotFound(context()->FindCompositeDeviceFromName( - "/job:worker/replica:0/task:0/device:COMPOSITE:2", &device))); + "/job:localhost/replica:0/task:0/device:COMPOSITE:2", &device))); } } // namespace diff --git a/tensorflow/core/common_runtime/eager/eager_executor.cc b/tensorflow/core/common_runtime/eager/eager_executor.cc index 7850978410f..ddfdabf9472 100644 --- a/tensorflow/core/common_runtime/eager/eager_executor.cc +++ b/tensorflow/core/common_runtime/eager/eager_executor.cc @@ -50,7 +50,6 @@ EagerExecutor::~EagerExecutor() { Status EagerExecutor::ShutDown() { { - std::vector<core::RefCountPtr<NodeItem>> items_to_destroy; bool has_thread; Status status; { @@ -72,9 +71,6 @@ Status EagerExecutor::ShutDown() { nodes_pending_.notify_all(); } } - for (auto& item : items_to_destroy) { - item->node->Abort(status); - } if (!has_thread) { return status; } diff --git a/tensorflow/core/common_runtime/eager/execute_node_test.cc b/tensorflow/core/common_runtime/eager/execute_node_test.cc index 99f030322df..83fbcf5017e 100644 --- a/tensorflow/core/common_runtime/eager/execute_node_test.cc +++ b/tensorflow/core/common_runtime/eager/execute_node_test.cc @@ -61,7 +61,8 @@ TEST(ExecuteNodeTest, ExecuteNodeArgs) { Status s; std::unique_ptr<CompositeDevice> composite_device = CompositeDevice::MakeDevice({device0->name(), device1->name()}, - /*unique_device_id=*/0, &s); + /*unique_device_id=*/0, + device_mgr.HostCPU()->parsed_name(), &s); TF_ASSERT_OK(s); auto ctx = new EagerContext( diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc index 779158375de..13b634bbec4 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc @@ -100,6 +100,7 @@ class PackedTensorHandleTest : public ::testing::Test { for (const char* name : device_names_) { devices.emplace_back(CreateDevice("GPU", name)); } + devices.emplace_back(CreateDevice("CPU", host_name_)); device_mgr_ = new StaticDeviceMgr(std::move(devices)); context_ = new EagerContext( @@ -132,6 +133,8 @@ class PackedTensorHandleTest : public ::testing::Test { "/job:worker/replica:0/task:1/device:GPU:0", "/job:worker/replica:0/task:1/device:GPU:1"}; + const char* host_name_ = "/job:worker/replica:0/task:0/device:CPU:0"; + StaticDeviceMgr* device_mgr_; EagerContext* context_; }; diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc index 746930750ad..ae1a2daa788 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.cc +++ b/tensorflow/core/common_runtime/graph_optimizer.cc @@ -42,7 +42,7 @@ void GraphOptimizer::Optimize( const NodePredicate& cse_consider_fn, const NodePredicate& cf_consider_fn, bool inline_multi_device_functions, bool inline_impl_selection_group_functions, - bool inline_with_single_device_body_placer) { + bool inline_with_single_device_body_placer, bool ignore_noinline) { Graph* g = graph->get(); DumpGraph("Initial", g); @@ -116,6 +116,11 @@ void GraphOptimizer::Optimize( .inline_impl_selection_group_functions = true; } + if (ignore_noinline) { + expand_inline_opts.multi_device_options.ignore_noinline = true; + expand_inline_opts.native_options.ignore_noinline = true; + } + bool was_mutated = ExpandInlineFunctions(runtime, g, expand_inline_opts); if (was_mutated) { DumpGraph("ExpandInlineFunctions", g); @@ -138,11 +143,11 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env, const Device* device, std::unique_ptr<Graph>* graph, const Options& options) { - Optimize(runtime, env, device, graph, options.shape_map, - options.cse_consider_fn, options.cf_consider_fn, - options.inline_multi_device_functions, - options.inline_impl_selection_group_functions, - options.inline_with_single_device_body_placer); + Optimize( + runtime, env, device, graph, options.shape_map, options.cse_consider_fn, + options.cf_consider_fn, options.inline_multi_device_functions, + options.inline_impl_selection_group_functions, + options.inline_with_single_device_body_placer, options.ignore_noinline); } void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g, diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h index 099ea8efa12..53bf532bd9c 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.h +++ b/tensorflow/core/common_runtime/graph_optimizer.h @@ -58,6 +58,9 @@ class GraphOptimizer { // If true all functions will be inlined with a single device function // body placer strategy. bool inline_with_single_device_body_placer = false; + + // If true, the _noinline attribute on functions and callers is ignored. + bool ignore_noinline = false; }; explicit GraphOptimizer(const OptimizerOptions& opts); @@ -81,7 +84,8 @@ class GraphOptimizer { const NodePredicate& cf_consider_fn = nullptr, bool inline_multi_device_functions = false, bool inline_impl_selection_group_functions = false, - bool inline_with_single_device_body_placer = false); + bool inline_with_single_device_body_placer = false, + bool ignore_noinline = false); const OptimizerOptions& options() { return opts_; } diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index 2941845a604..fbec7059743 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -268,6 +268,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.dequantize = "Dequantize"; csinfo_.fused_batch_norm = "FusedBatchNorm"; csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad"; + csinfo_.fused_batch_norm_ex = "_FusedBatchNormEx"; csinfo_.fused_batch_norm_v2 = "FusedBatchNormV2"; csinfo_.fused_batch_norm_grad_v2 = "FusedBatchNormGradV2"; csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3"; @@ -295,6 +296,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { "_MklDepthwiseConv2dNativeBackpropInput"; csinfo_.mkl_depthwise_conv2d_grad_filter = "_MklDepthwiseConv2dNativeBackpropFilter"; + csinfo_.mkl_fused_batch_norm_ex = "_MklFusedBatchNormEx"; csinfo_.mkl_fused_conv2d = "_MklFusedConv2D"; csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative"; csinfo_.mkl_fused_matmul = "_MklFusedMatMul"; @@ -478,6 +480,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { {csinfo_.fused_batch_norm_grad_v3, mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3), CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); +#ifdef ENABLE_MKLDNN_V1 + rinfo_.push_back({csinfo_.fused_batch_norm_ex, + csinfo_.mkl_fused_batch_norm_ex, CopyAttrsAll, + FusedBatchNormExRewrite, kRewriteForLayoutPropagation}); +#endif rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d, CopyAttrsFusedConv2D, FusedConv2DRewrite, kRewriteForLayoutPropagation}); @@ -499,7 +506,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { CopyAttrsAll, LrnGradRewrite, kRewriteForLayoutPropagation}); rinfo_.push_back({csinfo_.matmul, mkl_op_registry::GetMklOpName(csinfo_.matmul), - CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange}); + CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange}); rinfo_.push_back( {csinfo_.leakyrelu, mkl_op_registry::GetMklOpName(csinfo_.leakyrelu), CopyAttrsAll, LeakyReluRewrite, kRewriteForLayoutPropagation}); @@ -926,6 +933,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string dequantize; string fused_batch_norm; string fused_batch_norm_grad; + string fused_batch_norm_ex; string fused_batch_norm_v2; string fused_batch_norm_grad_v2; string fused_batch_norm_v3; @@ -951,6 +959,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string mkl_conv2d_with_bias; string mkl_depthwise_conv2d_grad_input; string mkl_depthwise_conv2d_grad_filter; + string mkl_fused_batch_norm_ex; string mkl_fused_conv2d; string mkl_fused_depthwise_conv2d; string mkl_fused_matmul; @@ -1473,6 +1482,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return false; } + static bool MatMulRewrite(const Node* n) { + DataType T; + GetNodeAttr(n->def(), "T", &T); + if ((T == DT_FLOAT) || (T == DT_BFLOAT16)) { + VLOG(2) << "Rewriting MatMul to _MklMatMul"; + return true; + } + return false; + } + static bool DequantizeRewrite(const Node* n) { DCHECK(n); Node* input = nullptr; @@ -1670,6 +1689,31 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return do_rewrite; } + static bool FusedBatchNormExRewrite(const Node* n) { + DCHECK(n); + + int num_side_inputs; + TF_CHECK_OK(GetNodeAttr(n->def(), "num_side_inputs", &num_side_inputs)); + string activation_mode; + TF_CHECK_OK(GetNodeAttr(n->def(), "activation_mode", &activation_mode)); + + // if the num_side_inputs is not 0, don't rewrite the node. + if (num_side_inputs != 0) { + VLOG(1) << "FusedBatchNormExRewrite: The model sets num_side_inputs" + << "larger than 0 is not optimized by Intel MKL."; + return false; + } + + // if the activation_mode is not 'Relu', don't rewrite the node. + if (activation_mode != "Relu") { + VLOG(1) << "FusedBatchNormExRewrite: Only Relu activation mode is" + << "supported by Intel MKL."; + return false; + } + + return true; + } + static bool FusedConv2DRewrite(const Node* n) { // MKL DNN currently doesn't support all fusions that grappler fuses // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if @@ -2168,9 +2212,6 @@ int MklLayoutRewritePass::SetUpContiguousInputs( // Number of input slots to original op // Input slots are represented by .Input() calls in REGISTER_OP. int old_node_input_slots = old_node->op_def().input_arg_size(); - // Actual number of inputs can be greater than or equal to number - // of Input slots because inputs of type list could be unfolded. - CHECK_GE(old_node_inputs.size(), old_node_input_slots); int nn_slot_idx = 0; // slot index for inputs of new node // Let's copy all inputs (TF tensors) of original node to new node. @@ -2178,13 +2219,14 @@ int MklLayoutRewritePass::SetUpContiguousInputs( for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { // An input slot could be a single tensor or a list. We need // to handle this case accordingly. - CHECK_LT(iidx, old_node_inputs.size()); const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); if (ArgIsList(arg)) { std::vector<NodeBuilder::NodeOut> new_node_inputs; - int N = GetTensorListLength(arg, old_node); - GetNodesProducingTFTensorList(old_node_inputs, &iidx, N, - &new_node_inputs); + int tensor_list_length = GetTensorListLength(arg, old_node); + if (tensor_list_length != 0) { + GetNodesProducingTFTensorList(old_node_inputs, &iidx, + tensor_list_length, &new_node_inputs); + } nb->Input(new_node_inputs); nn_slot_idx++; } else { @@ -2217,13 +2259,14 @@ int MklLayoutRewritePass::SetUpContiguousInputs( for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { // An input slot could be a single tensor or a list. We need // to handle this case accordingly. - CHECK_LT(iidx, old_node_inputs.size()); const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); if (ArgIsList(arg)) { std::vector<NodeBuilder::NodeOut> new_node_inputs; - int N = GetTensorListLength(arg, old_node); - GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N, - &new_node_inputs); + int tensor_list_length = GetTensorListLength(arg, old_node); + if (tensor_list_length != 0) { + GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, + tensor_list_length, &new_node_inputs); + } nb->Input(new_node_inputs); nn_slot_idx++; } else { @@ -3739,6 +3782,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { n->type_string() != csinfo_.pad_with_conv2d && n->type_string() != csinfo_.pad_with_fused_conv2d && n->type_string() != csinfo_.conv2d_grad_filter_with_bias && + n->type_string() != csinfo_.fused_batch_norm_ex && n->type_string() != csinfo_.fused_conv2d && n->type_string() != csinfo_.fused_depthwise_conv2d && n->type_string() != csinfo_.fused_matmul && diff --git a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc index c6d5331852e..71ab786f8a5 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc @@ -3216,6 +3216,100 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Negative) { "B->F:1;C->F:2;D->F:3;E->F:4;F->G:1"); } +// clang-format off +#ifdef ENABLE_MKLDNN_V1 +#define REGISTER_TEST(NAME, T, INPUT) \ + TEST_F(MklLayoutPassTest, NAME##_##T) { \ + InitGraph("node { name: 'A' op: '" #INPUT "'}" \ + "node { name: 'B' op: 'Input'}" \ + "node { name: 'C' op: 'Input'}" \ + "node { name: 'D' op: 'Input'}" \ + "node { name: 'E' op: 'Input'}" \ + "node { name: 'F' op: '_FusedBatchNormEx'" \ + " attr { key: 'T' value { type: " #T " } }" \ + " attr { key: 'U' value { type: DT_FLOAT } }" \ + " attr { key: 'data_format' value { s: 'NCHW' } }" \ + " attr { key: 'epsilon' value { f: 0.0001 } }" \ + " attr { key: 'num_side_inputs' value { i: 0 } }" \ + " attr { key: 'is_training' value { b: true } }" \ + " attr { key: 'activation_mode' value { s: 'Relu' } }" \ + " input: ['A', 'B', 'C', 'D', 'E'] }" \ + "node { name: 'G' op: 'Zeta'" \ + " attr { key: 'T' value { type: " #T " } }" \ + " input: ['A', 'F'] }"); \ + EXPECT_EQ(DoMklLayoutOptimizationPass(), \ + "A(" #INPUT ");B(Input);C(Input);D(Input);" \ + "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);" \ + "DMT/_4(Const);E(Input);" \ + "F(_MklFusedBatchNormEx);G(Zeta)|A->F;A->G;" \ + "A:control->DMT/_0:control;A:control->DMT/_1:control;" \ + "A:control->DMT/_2:control;A:control->DMT/_3:control;" \ + "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" \ + "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" \ + "E->F:4;F->G:1"); \ + } +REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Positive); +#undef REGISTER_TEST + +// Rewrite test for _FusedBatchNormEx Op with side input +#define REGISTER_TEST(NAME, T, INPUT) \ + TEST_F(MklLayoutPassTest, NAME##_##T) { \ + InitGraph("node { name: 'A' op: '" #INPUT "'}" \ + "node { name: 'B' op: 'Input'}" \ + "node { name: 'C' op: 'Input'}" \ + "node { name: 'D' op: 'Input'}" \ + "node { name: 'E' op: 'Input'}" \ + "node { name: 'F' op: '" #INPUT "'}" \ + "node { name: 'G' op: '_FusedBatchNormEx'" \ + " attr { key: 'T' value { type: " #T " } }" \ + " attr { key: 'U' value { type: DT_FLOAT } }" \ + " attr { key: 'data_format' value { s: 'NCHW' } }" \ + " attr { key: 'epsilon' value { f: 0.0001 } }" \ + " attr { key: 'num_side_inputs' value { i: 1 } }" \ + " attr { key: 'is_training' value { b: true } }" \ + " attr { key: 'activation_mode' value { s: 'Relu' } }" \ + " input: ['A', 'B', 'C', 'D', 'E', 'F'] }" \ + "node { name: 'H' op: 'Zeta'" \ + " attr { key: 'T' value { type: " #T " } }" \ + " input: ['A', 'G'] }"); \ + EXPECT_EQ(DoMklLayoutOptimizationPass(), \ + "A(" #INPUT ");B(Input);C(Input);D(Input);E(Input);" \ + "F(" #INPUT ");G(_FusedBatchNormEx);H(Zeta)|A->G;A->H;" \ + "B->G:1;C->G:2;D->G:3;E->G:4;F->G:5;G->H:1"); \ + } +REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative1); +#undef REGISTER_TEST + +// Rewrite test for _FusedBatchNormEx Op with Identity activation +#define REGISTER_TEST(NAME, T, INPUT) \ + TEST_F(MklLayoutPassTest, NAME##_##T) { \ + InitGraph("node { name: 'A' op: '" #INPUT "'}" \ + "node { name: 'B' op: 'Input'}" \ + "node { name: 'C' op: 'Input'}" \ + "node { name: 'D' op: 'Input'}" \ + "node { name: 'E' op: 'Input'}" \ + "node { name: 'G' op: '_FusedBatchNormEx'" \ + " attr { key: 'T' value { type: " #T " } }" \ + " attr { key: 'U' value { type: DT_FLOAT } }" \ + " attr { key: 'data_format' value { s: 'NCHW' } }" \ + " attr { key: 'epsilon' value { f: 0.0001 } }" \ + " attr { key: 'num_side_inputs' value { i: 1 } }" \ + " attr { key: 'is_training' value { b: true } }" \ + " attr { key: 'activation_mode' value { s: 'Identity' } }" \ + " input: ['A', 'B', 'C', 'D', 'E'] }" \ + "node { name: 'H' op: 'Zeta'" \ + " attr { key: 'T' value { type: " #T " } }" \ + " input: ['A', 'G'] }"); \ + EXPECT_EQ(DoMklLayoutOptimizationPass(), \ + "A(" #INPUT ");B(Input);C(Input);D(Input);E(Input);" \ + "G(_FusedBatchNormEx);H(Zeta)|A->G;A->H;" \ + "B->G:1;C->G:2;D->G:3;E->G:4;G->H:1"); \ + } +REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormEx_Negative2); +#undef REGISTER_TEST +#endif // ENABLE_MKLDNN_V1 +// clang-format on + TEST_F(MklLayoutPassTest, NodeRewrite_QuantizedDepthwiseConv2D_Positive) { InitGraph( "node { name: 'A' op: 'QuantizedUnsignedInt8Input'}" diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 271169f2a5e..364750b6679 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/common_runtime/rendezvous_util.h" #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h" +#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/op_kernel.h" @@ -230,7 +231,7 @@ FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( Device* device = nullptr; if (device_name != kDefaultFLRDevice) { if (!device_mgr_->LookupDevice(device_name, &device).ok()) { - VLOG(1) << "Could not find device: " << device_name; + VLOG(4) << "Could not find device: " << device_name; return nullptr; } } @@ -1046,7 +1047,37 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice( return; } - auto* refcounted_done = new ReffedStatusCallback(std::move(done)); + // A locally created cancellation manager, used only when the caller does not + // provide one in argument. + std::shared_ptr<CancellationManager> local_cm; + CancellationManager* cm = opts.cancellation_manager; + if (cm == nullptr) { + local_cm = std::make_shared<CancellationManager>(); + cm = local_cm.get(); + } + auto token = cm->get_cancellation_token(); + const auto cancelled_error = errors::Cancelled( + "ProcessFunctionLibraryRuntime::RunMultiDevice was cancelled."); + const bool already_cancelled = !cm->RegisterCallback( + token, + [rendez = opts.rendezvous, n_func = data->glue_.size(), cancelled_error] { + // Abort rendezvous only if there are more than one component functions + // to avoid reporting cancellation error directly to PartitionedCallOps + // that launch a single component function. + if (rendez && n_func > 1) { + rendez->StartAbort(cancelled_error); + } + }); + if (already_cancelled) { + done(cancelled_error); + return; + } + + auto* refcounted_done = new ReffedStatusCallback( + [cm, token, local_cm, done = std::move(done)](const Status& s) { + cm->TryDeregisterCallback(token); + done(s); + }); for (int i = 0; i < data->glue_.size(); ++i) { refcounted_done->Ref(); } @@ -1059,7 +1090,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice( opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs; opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs; - opts_copy.remote_execution = false; + opts_copy.cancellation_manager = cm; InternalArgs comp_args; Status s = get_component_args(comp_data, &comp_args); @@ -1067,13 +1098,39 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice( VLOG(2) << "Failed to get component function arguments: " << s; refcounted_done->UpdateStatus(s); refcounted_done->Unref(); + cm->StartCancel(); continue; } std::vector<Tensor>* comp_rets = new std::vector<Tensor>; rets->resize(data->num_outputs_); + auto component_fn_callback = [comp_rets, rets, comp_data, refcounted_done, + cm, local_cm, data, + target](const Status& status) { + if (!status.ok()) { + VLOG(2) << "Component function execution on target " << target + << " failed: " << status; + const string function_and_msg = strings::StrCat( + errors::FormatFunctionForError(data->function_name_), " ", + status.error_message()); + refcounted_done->UpdateStatus(Status(status.code(), function_and_msg)); + // Cancel the execution of other component functions. + cm->StartCancel(); + } else { + VLOG(2) << "Component function execution on target " << target + << " succeeded."; + for (int i = 0; i < comp_rets->size(); ++i) { + (*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i]; + } + } + delete comp_rets; + // refcounted_done is thread-safe + refcounted_done->Unref(); + }; + FunctionLibraryRuntime* flr = GetFLR(target); if (flr != nullptr) { + opts_copy.remote_execution = false; // When target device has private thread pool, use the target device // runner thread::ThreadPool* pool = flr->device()->tensorflow_device_thread_pool(); @@ -1084,24 +1141,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice( VLOG(4) << " with " << opts_copy.DebugString(); flr->Run(opts_copy, handle, GetLocalArgs(comp_args.args), comp_rets, - [comp_rets, rets, comp_data, refcounted_done, - data](const Status& status) { - if (!status.ok()) { - VLOG(2) << "Component function execution failed: " << status; - const string function_and_msg = strings::StrCat( - errors::FormatFunctionForError(data->function_name_), - " ", status.error_message()); - refcounted_done->UpdateStatus( - Status(status.code(), function_and_msg)); - } else { - for (int i = 0; i < comp_rets->size(); ++i) { - (*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i]; - } - } - delete comp_rets; - // refcounted_done is thread-safe - refcounted_done->Unref(); - }); + std::move(component_fn_callback)); } else { opts_copy.remote_execution = true; @@ -1109,21 +1149,8 @@ void ProcessFunctionLibraryRuntime::RunMultiDevice( << " with handle " << handle; VLOG(4) << " with " << opts_copy.DebugString(); - RunInternal( - opts_copy, handle, comp_args.args, comp_rets, cleanup_items, - [comp_rets, rets, comp_data, refcounted_done](const Status& status) { - if (!status.ok()) { - VLOG(2) << "Component function execution failed: " << status; - refcounted_done->UpdateStatus(status); - } else { - for (int i = 0; i < comp_rets->size(); ++i) { - (*rets)[comp_data.ret_indices[i]] = (*comp_rets)[i]; - } - } - delete comp_rets; - // refcounted_done is thread-safe - refcounted_done->Unref(); - }); + RunInternal(opts_copy, handle, comp_args.args, comp_rets, cleanup_items, + std::move(component_fn_callback)); } } refcounted_done->Unref(); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 247b94dc58c..5bdb4601d37 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -820,7 +820,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) { Status s; std::unique_ptr<CompositeDevice> composite_device = CompositeDevice::MakeDevice({device0_->name(), device1_->name()}, - /*unique_device_id=*/0, &s); + /*unique_device_id=*/0, + device_mgr_->HostCPU()->parsed_name(), &s); TF_ASSERT_OK(s); AddCompositeDevice(composite_device.get()); diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD index 9c58be108fc..1b6e6790559 100644 --- a/tensorflow/core/data/BUILD +++ b/tensorflow/core/data/BUILD @@ -1,5 +1,10 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow/core/platform:build_config.bzl", "tf_protos_all") +load( + "//tensorflow/core/platform:build_config.bzl", + "tf_additional_all_protos", + "tf_proto_library", + "tf_protos_all", +) package( default_visibility = [ @@ -10,6 +15,45 @@ package( exports_files(["LICENSE"]) +cc_library( + name = "compression_utils", + srcs = ["compression_utils.cc"], + hdrs = [ + "compression_utils.h", + ], + deps = [ + ":dataset_proto_cc", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", + ], +) + +tf_cc_test( + name = "compression_utils_test", + srcs = ["compression_utils_test.cc"], + deps = [ + ":compression_utils", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels/data:dataset_test_base", + ], +) + +tf_proto_library( + name = "dataset_proto", + srcs = ["dataset.proto"], + cc_api_version = 2, + protodeps = tf_additional_all_protos(), +) + cc_library( name = "standalone", srcs = ["standalone.cc"], diff --git a/tensorflow/core/data/service/compression_utils.cc b/tensorflow/core/data/compression_utils.cc similarity index 88% rename from tensorflow/core/data/service/compression_utils.cc rename to tensorflow/core/data/compression_utils.cc index c4a47e1b00e..d132bdca8da 100644 --- a/tensorflow/core/data/service/compression_utils.cc +++ b/tensorflow/core/data/compression_utils.cc @@ -12,21 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/data/service/compression_utils.h" +#include "tensorflow/core/data/compression_utils.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/platform/snappy.h" -#include "tensorflow/core/profiler/lib/traceme.h" namespace tensorflow { namespace data { -namespace service_util { - -Status Compress(const std::vector<Tensor>& element, CompressedElement* out) { - tensorflow::profiler::TraceMe activity( - "Compress", tensorflow::profiler::TraceMeLevel::kInfo); +Status CompressElement(const std::vector<Tensor>& element, + CompressedElement* out) { // Step 1: Determine the total uncompressed size. This requires serializing // non-memcopyable tensors, which we save to use again later. std::vector<TensorProto> non_memcpy_components; @@ -51,7 +47,8 @@ Status Compress(const std::vector<Tensor>& element, CompressedElement* out) { char* position = uncompressed.mdata(); int non_memcpy_component_index = 0; for (auto& component : element) { - ComponentMetadata* metadata = out->mutable_component_metadata()->Add(); + CompressedComponentMetadata* metadata = + out->mutable_component_metadata()->Add(); metadata->set_dtype(component.dtype()); component.shape().AsProto(metadata->mutable_tensor_shape()); if (DataTypeCanUseMemcpy(component.dtype())) { @@ -71,13 +68,13 @@ Status Compress(const std::vector<Tensor>& element, CompressedElement* out) { out->mutable_data())) { return errors::Internal("Failed to compress using snappy."); } + VLOG(3) << "Compressed element from " << total_size << " bytes to " + << out->data().size() << " bytes"; return Status::OK(); } -Status Uncompress(const CompressedElement& compressed, - std::vector<Tensor>* out) { - tensorflow::profiler::TraceMe activity( - "Uncompress", tensorflow::profiler::TraceMeLevel::kInfo); +Status UncompressElement(const CompressedElement& compressed, + std::vector<Tensor>* out) { int num_components = compressed.component_metadata_size(); out->clear(); out->reserve(num_components); @@ -92,7 +89,8 @@ Status Uncompress(const CompressedElement& compressed, tensor_proto_strs.reserve(num_components); int64 total_size = 0; for (int i = 0; i < num_components; ++i) { - const ComponentMetadata& metadata = compressed.component_metadata(i); + const CompressedComponentMetadata& metadata = + compressed.component_metadata(i); if (DataTypeCanUseMemcpy(metadata.dtype())) { out->emplace_back(metadata.dtype(), metadata.tensor_shape()); TensorBuffer* buffer = DMAHelper::buffer(&out->back()); @@ -146,6 +144,5 @@ Status Uncompress(const CompressedElement& compressed, return Status::OK(); } -} // namespace service_util } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/compression_utils.h b/tensorflow/core/data/compression_utils.h similarity index 82% rename from tensorflow/core/data/service/compression_utils.h rename to tensorflow/core/data/compression_utils.h index 96698aaaf09..5e033771272 100644 --- a/tensorflow/core/data/service/compression_utils.h +++ b/tensorflow/core/data/compression_utils.h @@ -16,24 +16,23 @@ limitations under the License. #define TENSORFLOW_CORE_DATA_SERVICE_COMPRESSION_UTILS_H_ #include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/dataset.pb.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { namespace data { -namespace service_util { // Compresses the components of `element` into the `CompressedElement` proto. // // In addition to writing the actual compressed bytes, `Compress` fills // out the per-component metadata for the `CompressedElement`. -Status Compress(const std::vector<Tensor>& element, CompressedElement* out); +Status CompressElement(const std::vector<Tensor>& element, + CompressedElement* out); // Uncompresses a `CompressedElement` into a vector of tensor components. -Status Uncompress(const CompressedElement& compressed, - std::vector<Tensor>* out); +Status UncompressElement(const CompressedElement& compressed, + std::vector<Tensor>* out); -} // namespace service_util } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/compression_utils_test.cc b/tensorflow/core/data/compression_utils_test.cc similarity index 89% rename from tensorflow/core/data/service/compression_utils_test.cc rename to tensorflow/core/data/compression_utils_test.cc index b5da13efeed..eb220092f88 100644 --- a/tensorflow/core/data/service/compression_utils_test.cc +++ b/tensorflow/core/data/compression_utils_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/data/service/compression_utils.h" +#include "tensorflow/core/data/compression_utils.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/kernels/data/dataset_test_base.h" @@ -20,7 +20,6 @@ limitations under the License. namespace tensorflow { namespace data { -namespace service_util { class ParameterizedCompressionUtilsTest : public DatasetOpsTestBase, @@ -29,9 +28,9 @@ class ParameterizedCompressionUtilsTest TEST_P(ParameterizedCompressionUtilsTest, RoundTrip) { std::vector<Tensor> element = GetParam(); CompressedElement compressed; - TF_ASSERT_OK(Compress(element, &compressed)); + TF_ASSERT_OK(CompressElement(element, &compressed)); std::vector<Tensor> round_trip_element; - TF_ASSERT_OK(Uncompress(compressed, &round_trip_element)); + TF_ASSERT_OK(UncompressElement(compressed, &round_trip_element)); TF_EXPECT_OK( ExpectEqual(element, round_trip_element, /*compare_order=*/true)); } @@ -50,6 +49,5 @@ std::vector<std::vector<Tensor>> TestCases() { INSTANTIATE_TEST_SUITE_P(Instantiation, ParameterizedCompressionUtilsTest, ::testing::ValuesIn(TestCases())); -} // namespace service_util } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/dataset.proto b/tensorflow/core/data/dataset.proto new file mode 100644 index 00000000000..27a36364e76 --- /dev/null +++ b/tensorflow/core/data/dataset.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package tensorflow.data; + +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +// This file contains protocol buffers for working with tf.data Datasets. + +// Metadata describing a compressed component of a dataset element. +message CompressedComponentMetadata { + // The dtype of the component tensor. + .tensorflow.DataType dtype = 1; + // The shape of the component tensor. + .tensorflow.TensorShapeProto tensor_shape = 2; + // Size of the uncompressed tensor bytes. For tensors serialized as + // TensorProtos, this is TensorProto::BytesAllocatedLong(). For raw Tensors, + // this is the size of the buffer underlying the Tensor. + int64 tensor_size_bytes = 3; +} + +message CompressedElement { + // Compressed tensor bytes for all components of the element. + bytes data = 1; + // Metadata for the components of the element. + repeated CompressedComponentMetadata component_metadata = 2; +} diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 5413493cb78..b76f93c454e 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -44,6 +44,7 @@ tf_proto_library( cc_api_version = 2, protodeps = tf_additional_all_protos() + [ ":common_proto", + "//tensorflow/core/data:dataset_proto", ], ) @@ -84,7 +85,6 @@ cc_library( ], deps = [ ":common_proto_cc", - ":compression_utils", ":credentials_factory", ":grpc_util", ":master_cc_grpc_proto", @@ -98,6 +98,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/data:dataset_proto_cc", "//tensorflow/core/data:standalone", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", @@ -129,39 +130,6 @@ tf_cc_test( ], ) -cc_library( - name = "compression_utils", - srcs = ["compression_utils.cc"], - hdrs = [ - "compression_utils.h", - ], - deps = [ - ":common_proto_cc", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/profiler/lib:traceme", - "@com_google_absl//absl/memory", - ], -) - -tf_cc_test( - name = "compression_utils_test", - srcs = ["compression_utils_test.cc"], - deps = [ - ":compression_utils", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "//tensorflow/core/kernels/data:dataset_test_base", - ], -) - cc_library( name = "credentials_factory", srcs = ["credentials_factory.cc"], @@ -317,7 +285,6 @@ tf_cc_test( srcs = ["data_service_test.cc"], tags = ["no_windows"], deps = [ - ":compression_utils", ":data_service", ":grpc_master_impl", ":grpc_util", @@ -333,6 +300,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/data:compression_utils", "//tensorflow/core/kernels/data:dataset_test_base", "@com_google_absl//absl/strings", tf_grpc_cc_dependency(), diff --git a/tensorflow/core/data/service/common.proto b/tensorflow/core/data/service/common.proto index 6dfa698764b..4bde56fe1ca 100644 --- a/tensorflow/core/data/service/common.proto +++ b/tensorflow/core/data/service/common.proto @@ -3,7 +3,6 @@ syntax = "proto3"; package tensorflow.data; import "tensorflow/core/framework/graph.proto"; -import "tensorflow/core/framework/tensor_shape.proto"; import "tensorflow/core/framework/types.proto"; message DatasetDef { @@ -12,24 +11,6 @@ message DatasetDef { GraphDef graph = 1; } -message ComponentMetadata { - // The dtype of the component tensor. - .tensorflow.DataType dtype = 1; - // The shape of the component tensor. - .tensorflow.TensorShapeProto tensor_shape = 2; - // Size of the uncompressed tensor bytes. For tensors serialized as - // TensorProtos, this is TensorProto::BytesAllocatedLong(). For raw Tensors, - // this is the size of the buffer underlying the Tensor. - int64 tensor_size_bytes = 3; -} - -message CompressedElement { - // Compressed tensor bytes for all components of the element. - bytes data = 1; - // Metadata for the components of the element. - repeated ComponentMetadata component_metadata = 2; -} - message TaskDef { // The dataset to iterate over. // TODO(aaudibert): load the dataset from disk instead of passing it here. diff --git a/tensorflow/core/data/service/data_service.cc b/tensorflow/core/data/service/data_service.cc index 915435d8fcb..d4e08c77f35 100644 --- a/tensorflow/core/data/service/data_service.cc +++ b/tensorflow/core/data/service/data_service.cc @@ -132,6 +132,22 @@ Status DataServiceMasterClient::GetTasks(int64 job_id, return Status::OK(); } +Status DataServiceMasterClient::GetWorkers(std::vector<WorkerInfo>* workers) { + TF_RETURN_IF_ERROR(EnsureInitialized()); + GetWorkersRequest req; + GetWorkersResponse resp; + grpc_impl::ClientContext ctx; + grpc::Status s = stub_->GetWorkers(&ctx, req, &resp); + if (!s.ok()) { + return grpc_util::WrapError("Failed to get workers", s); + } + workers->clear(); + for (auto& worker : resp.workers()) { + workers->push_back(worker); + } + return Status::OK(); +} + Status DataServiceMasterClient::EnsureInitialized() { std::shared_ptr<grpc::ChannelCredentials> credentials; TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/data/service/data_service.h b/tensorflow/core/data/service/data_service.h index d205b4d9ebf..bb5a8a470f0 100644 --- a/tensorflow/core/data/service/data_service.h +++ b/tensorflow/core/data/service/data_service.h @@ -96,6 +96,10 @@ class DataServiceMasterClient : public DataServiceClientBase { Status GetTasks(int64 job_id, std::vector<TaskInfo>* tasks, bool* job_finished); + // Queries the master for its registered workers. The worker info will be + // stored in `*workers`. + Status GetWorkers(std::vector<WorkerInfo>* workers); + protected: Status EnsureInitialized() override; diff --git a/tensorflow/core/data/service/data_service_test.cc b/tensorflow/core/data/service/data_service_test.cc index 73a46bad3d0..19392393eeb 100644 --- a/tensorflow/core/data/service/data_service_test.cc +++ b/tensorflow/core/data/service/data_service_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "grpcpp/create_channel.h" #include "grpcpp/security/credentials.h" #include "absl/strings/str_split.h" -#include "tensorflow/core/data/service/compression_utils.h" +#include "tensorflow/core/data/compression_utils.h" #include "tensorflow/core/data/service/grpc_util.h" #include "tensorflow/core/data/service/master.grpc.pb.h" #include "tensorflow/core/data/service/master.pb.h" @@ -37,6 +37,7 @@ namespace data { namespace { constexpr const char kProtocol[] = "grpc+local"; +} TEST(DataService, ParseParallelEpochsProcessingMode) { ProcessingMode mode; @@ -62,114 +63,13 @@ TEST(DataService, ProcessingModeToString) { EXPECT_EQ("one_epoch", ProcessingModeToString(ProcessingMode::ONE_EPOCH)); } -Status CheckWorkerOutput(const std::string& worker_address, int64 task_id, - std::vector<std::vector<Tensor>> expected_output) { - DataServiceWorkerClient worker(worker_address, kProtocol); - for (std::vector<Tensor>& expected : expected_output) { - bool end_of_sequence; - CompressedElement compressed; - TF_RETURN_IF_ERROR( - worker.GetElement(task_id, &compressed, &end_of_sequence)); - if (end_of_sequence) { - return errors::Internal("Reached end of sequence too early."); - } - std::vector<Tensor> element; - TF_RETURN_IF_ERROR(service_util::Uncompress(compressed, &element)); - TF_RETURN_IF_ERROR(DatasetOpsTestBase::ExpectEqual(element, expected, - /*compare_order=*/true)); - } - // Call GetElement a couple more times to verify tha end_of_sequence keeps - // returning true. - bool end_of_sequence; - CompressedElement compressed; - TF_RETURN_IF_ERROR(worker.GetElement(task_id, &compressed, &end_of_sequence)); - if (!end_of_sequence) { - return errors::Internal("Expected end_of_sequence to be true"); - } - TF_RETURN_IF_ERROR(worker.GetElement(task_id, &compressed, &end_of_sequence)); - if (!end_of_sequence) { - return errors::Internal("Expected end_of_sequence to be true"); - } - return Status::OK(); -} - -} // namespace - -TEST(DataService, IterateDatasetOneWorker) { +TEST(DataService, GetWorkers) { TestCluster cluster(1); TF_ASSERT_OK(cluster.Initialize()); - test_util::GraphDefTestCase test_case; - TF_ASSERT_OK(test_util::map_test_case(&test_case)); DataServiceMasterClient master(cluster.MasterAddress(), kProtocol); - - int64 dataset_id; - TF_ASSERT_OK(master.RegisterDataset(test_case.graph_def, &dataset_id)); - int64 job_id; - TF_ASSERT_OK( - master.CreateJob(dataset_id, ProcessingMode::PARALLEL_EPOCHS, &job_id)); - std::vector<TaskInfo> tasks; - bool job_finished; - TF_ASSERT_OK(master.GetTasks(job_id, &tasks, &job_finished)); - ASSERT_EQ(tasks.size(), 1); - EXPECT_EQ(tasks[0].worker_address(), cluster.WorkerAddress(0)); - EXPECT_FALSE(job_finished); - - TF_EXPECT_OK(CheckWorkerOutput(tasks[0].worker_address(), tasks[0].id(), - test_case.output)); -} - -TEST(DataService, IterateDatasetTwoWorkers) { - TestCluster cluster(2); - TF_ASSERT_OK(cluster.Initialize()); - test_util::GraphDefTestCase test_case; - TF_ASSERT_OK(test_util::map_test_case(&test_case)); - DataServiceMasterClient master(cluster.MasterAddress(), kProtocol); - - int64 dataset_id; - TF_ASSERT_OK(master.RegisterDataset(test_case.graph_def, &dataset_id)); - int64 job_id; - TF_ASSERT_OK( - master.CreateJob(dataset_id, ProcessingMode::PARALLEL_EPOCHS, &job_id)); - std::vector<TaskInfo> tasks; - bool job_finished; - TF_EXPECT_OK(master.GetTasks(job_id, &tasks, &job_finished)); - EXPECT_EQ(tasks.size(), 2); - EXPECT_FALSE(job_finished); - - // Each worker produces the full dataset. - for (TaskInfo task : tasks) { - TF_EXPECT_OK( - CheckWorkerOutput(task.worker_address(), task.id(), test_case.output)); - } -} - -TEST(DataService, AddWorkerMidEpoch) { - TestCluster cluster(1); - TF_ASSERT_OK(cluster.Initialize()); - test_util::GraphDefTestCase test_case; - TF_ASSERT_OK(test_util::map_test_case(&test_case)); - DataServiceMasterClient master(cluster.MasterAddress(), kProtocol); - - int64 dataset_id; - TF_ASSERT_OK(master.RegisterDataset(test_case.graph_def, &dataset_id)); - int64 job_id; - TF_ASSERT_OK( - master.CreateJob(dataset_id, ProcessingMode::PARALLEL_EPOCHS, &job_id)); - std::vector<TaskInfo> tasks; - bool job_finished; - TF_ASSERT_OK(master.GetTasks(job_id, &tasks, &job_finished)); - EXPECT_EQ(tasks.size(), 1); - EXPECT_FALSE(job_finished); - TF_ASSERT_OK(cluster.AddWorker()); - TF_EXPECT_OK(master.GetTasks(job_id, &tasks, &job_finished)); - EXPECT_EQ(tasks.size(), 2); - EXPECT_FALSE(job_finished); - - // Each worker produces the full dataset. - for (TaskInfo task : tasks) { - TF_EXPECT_OK( - CheckWorkerOutput(task.worker_address(), task.id(), test_case.output)); - } + std::vector<WorkerInfo> workers; + TF_EXPECT_OK(master.GetWorkers(&workers)); + EXPECT_EQ(1, workers.size()); } } // namespace data diff --git a/tensorflow/core/data/service/grpc_master_impl.cc b/tensorflow/core/data/service/grpc_master_impl.cc index ba27959fee7..20ad58a0115 100644 --- a/tensorflow/core/data/service/grpc_master_impl.cc +++ b/tensorflow/core/data/service/grpc_master_impl.cc @@ -44,6 +44,7 @@ HANDLER(GetOrRegisterDataset); HANDLER(CreateJob); HANDLER(GetOrCreateJob); HANDLER(GetTasks); +HANDLER(GetWorkers); #undef HANDLER } // namespace data diff --git a/tensorflow/core/data/service/grpc_master_impl.h b/tensorflow/core/data/service/grpc_master_impl.h index 32eb0f3fc6a..d29bb6759f0 100644 --- a/tensorflow/core/data/service/grpc_master_impl.h +++ b/tensorflow/core/data/service/grpc_master_impl.h @@ -48,6 +48,7 @@ class GrpcMasterImpl : public MasterService::Service { HANDLER(CreateJob); HANDLER(GetOrCreateJob); HANDLER(GetTasks); + HANDLER(GetWorkers); #undef HANDLER private: diff --git a/tensorflow/core/data/service/grpc_worker_impl.cc b/tensorflow/core/data/service/grpc_worker_impl.cc index a5d005d6c6e..7884fa063ba 100644 --- a/tensorflow/core/data/service/grpc_worker_impl.cc +++ b/tensorflow/core/data/service/grpc_worker_impl.cc @@ -30,7 +30,6 @@ GrpcWorkerImpl::GrpcWorkerImpl(ServerBuilder* server_builder, const std::string& protocol) : impl_(master_address, protocol) { server_builder->RegisterService(this); - LOG(INFO) << "GrpcWorkerImpl: master address is " << master_address; VLOG(1) << "Registered data service worker"; } diff --git a/tensorflow/core/data/service/master.proto b/tensorflow/core/data/service/master.proto index 005e5affb7d..661264cc41b 100644 --- a/tensorflow/core/data/service/master.proto +++ b/tensorflow/core/data/service/master.proto @@ -98,6 +98,18 @@ message GetTasksResponse { bool job_finished = 2; } +message WorkerInfo { + string address = 1; + int64 id = 2; +} + +message GetWorkersRequest {} + +message GetWorkersResponse { + // A list of all workers. + repeated WorkerInfo workers = 1; +} + service MasterService { // Registers a worker with the master. rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse); @@ -121,4 +133,7 @@ service MasterService { // Reports a list of all tasks for a job. rpc GetTasks(GetTasksRequest) returns (GetTasksResponse); + + // Reports a list of all workers registered with the master. + rpc GetWorkers(GetWorkersRequest) returns (GetWorkersResponse); } diff --git a/tensorflow/core/data/service/master_impl.cc b/tensorflow/core/data/service/master_impl.cc index 336ab068c40..37a884d540e 100644 --- a/tensorflow/core/data/service/master_impl.cc +++ b/tensorflow/core/data/service/master_impl.cc @@ -315,5 +315,19 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request, return Status::OK(); } +Status DataServiceMasterImpl::GetWorkers(const GetWorkersRequest* request, + GetWorkersResponse* response) { + mutex_lock l(mu_); + VLOG(3) << "Enter GetWorkers"; + for (auto& worker : workers_) { + WorkerInfo* info = response->add_workers(); + info->set_address(worker.address()); + info->set_id(worker.worker_id()); + } + VLOG(3) << "Returning list of " << workers_.size() + << " workers from GetWorkers"; + return Status::OK(); +} + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/master_impl.h b/tensorflow/core/data/service/master_impl.h index e8b70e84d0f..0dc049a389c 100644 --- a/tensorflow/core/data/service/master_impl.h +++ b/tensorflow/core/data/service/master_impl.h @@ -60,6 +60,8 @@ class DataServiceMasterImpl { Status GetOrCreateJob(const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response); Status GetTasks(const GetTasksRequest* request, GetTasksResponse* response); + Status GetWorkers(const GetWorkersRequest* request, + GetWorkersResponse* response); private: class Worker { diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc index 66fc1e20603..33c2232f4dc 100644 --- a/tensorflow/core/data/service/server_lib.cc +++ b/tensorflow/core/data/service/server_lib.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/data/service/grpc_master_impl.h" #include "tensorflow/core/data/service/grpc_util.h" #include "tensorflow/core/data/service/grpc_worker_impl.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace data { @@ -31,6 +32,13 @@ GrpcDataServerBase::GrpcDataServerBase(int port, const std::string& protocol) : requested_port_(port), protocol_(protocol) {} Status GrpcDataServerBase::Start() { + if (stopped_) { + return errors::FailedPrecondition( + "Server cannot be started after it has been stopped."); + } + if (started_) { + return Status::OK(); + } ::grpc::ServerBuilder builder; std::shared_ptr<::grpc::ServerCredentials> credentials; TF_RETURN_IF_ERROR( @@ -47,11 +55,18 @@ Status GrpcDataServerBase::Start() { TF_RETURN_IF_ERROR(StartServiceInternal()); + started_ = true; VLOG(1) << "Started tf.data service running at 0.0.0.0:" << BoundPort(); return Status::OK(); } -void GrpcDataServerBase::Stop() { server_->Shutdown(); } +void GrpcDataServerBase::Stop() { + if (stopped_) { + return; + } + server_->Shutdown(); + stopped_ = true; +} void GrpcDataServerBase::Join() { server_->Wait(); } @@ -68,15 +83,15 @@ void MasterGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) { service_ = service.release(); } -Status MasterGrpcDataServer::NumTasks(int* num_tasks) { - GetTasksRequest req; - GetTasksResponse resp; +Status MasterGrpcDataServer::NumWorkers(int* num_workers) { + GetWorkersRequest req; + GetWorkersResponse resp; grpc::ServerContext ctx; - grpc::Status s = service_->GetTasks(&ctx, &req, &resp); + grpc::Status s = service_->GetWorkers(&ctx, &req, &resp); if (!s.ok()) { - return grpc_util::WrapError("Failed to get num tasks", s); + return grpc_util::WrapError("Failed to get workers", s); } - *num_tasks = resp.task_info_size(); + *num_workers = resp.workers_size(); return Status::OK(); } diff --git a/tensorflow/core/data/service/server_lib.h b/tensorflow/core/data/service/server_lib.h index 0ef305db89a..72bec665c8e 100644 --- a/tensorflow/core/data/service/server_lib.h +++ b/tensorflow/core/data/service/server_lib.h @@ -64,6 +64,8 @@ class GrpcDataServerBase { private: int bound_port_; + bool started_ = false; + bool stopped_ = false; std::unique_ptr<grpc::Server> server_; }; @@ -73,8 +75,8 @@ class MasterGrpcDataServer : public GrpcDataServerBase { MasterGrpcDataServer(int requested_port, const std::string& protocol); ~MasterGrpcDataServer() override; - // Returns the number of tasks created by the master. - Status NumTasks(int* num_tasks); + // Returns the number of workers registerd with the master. + Status NumWorkers(int* num_workers); protected: void AddServiceToBuilder(grpc::ServerBuilder* builder) override; diff --git a/tensorflow/core/data/service/worker.proto b/tensorflow/core/data/service/worker.proto index 04b8f03474c..51c6899f540 100644 --- a/tensorflow/core/data/service/worker.proto +++ b/tensorflow/core/data/service/worker.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package tensorflow.data; +import "tensorflow/core/data/dataset.proto"; import "tensorflow/core/data/service/common.proto"; message ProcessTaskRequest { diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc index 8d00825227b..151410bb219 100644 --- a/tensorflow/core/data/service/worker_impl.cc +++ b/tensorflow/core/data/service/worker_impl.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/core/data/service/compression_utils.h" +#include "tensorflow/core/data/dataset.pb.h" #include "tensorflow/core/data/service/credentials_factory.h" #include "tensorflow/core/data/service/grpc_util.h" #include "tensorflow/core/data/service/master.grpc.pb.h" @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/zlib_outputbuffer.h" #include "tensorflow/core/lib/monitoring/gauge.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/snappy.h" #include "tensorflow/core/public/session_options.h" @@ -135,8 +136,33 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request, if (!end_of_sequence) { VLOG(3) << "Producing an element for task " << request->task_id(); - TF_RETURN_IF_ERROR(service_util::Compress( - outputs, response->mutable_compressed_element())); + if (outputs.size() != 1) { + return errors::FailedPrecondition( + "Expected dataset to produce a single scalar variant tensor, but the " + "dataset produced ", + outputs.size(), " outputs"); + } + if (outputs[0].dtype() != DT_VARIANT) { + return errors::FailedPrecondition( + "Expected dataset to produce a single scalar variant tensor, but " + "the dataset produced a tensor with type ", + DataTypeString(outputs[0].dtype())); + } + if (!TensorShapeUtils::IsScalar(outputs[0].shape())) { + return errors::FailedPrecondition( + "Expected dataset to produce a single scalar variant tensor, but " + "the dataset produced a tensor with shape ", + outputs[0].shape()); + } + Variant& variant = outputs[0].scalar<Variant>()(); + CompressedElement* compressed = variant.get<CompressedElement>(); + if (compressed == nullptr) { + return errors::FailedPrecondition( + "Expected dataset to produce a CompressedElement variant tensor, but " + "it produced ", + variant.TypeName()); + } + compressed->Swap(response->mutable_compressed_element()); } response->set_end_of_sequence(end_of_sequence); diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD index c7fdfa176b1..c27758cbb44 100644 --- a/tensorflow/core/distributed_runtime/eager/BUILD +++ b/tensorflow/core/distributed_runtime/eager/BUILD @@ -42,6 +42,7 @@ cc_library( "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/distributed_runtime:call_options", "//tensorflow/core/distributed_runtime:worker_session", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", @@ -68,6 +69,7 @@ cc_library( "//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/distributed_runtime:call_options", ], ) diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc index ec129173833..808188aa36d 100644 --- a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc @@ -20,9 +20,11 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h" #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h" +#include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/lib/core/errors.h" @@ -64,13 +66,7 @@ void EagerClusterFunctionLibraryRuntime::Instantiate( VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target << " (this: " << this << ")"; core::RefCountPtr<eager::EagerClient> eager_client; - Device* device; - s = ctx_->FindDeviceFromName(target.c_str(), &device); - if (!s.ok()) { - done(s); - return; - } - s = ctx_->GetClient(device, &eager_client); + s = ctx_->GetClient(target, &eager_client); if (!s.ok()) { done(s); return; @@ -189,13 +185,31 @@ void EagerClusterFunctionLibraryRuntime::Run( op->Attrs().FillAttrValueMap(remote_op->mutable_attrs()); remote_op->set_device(function_data->target); + CancellationManager* cm = opts.cancellation_manager; + CancellationToken token = 0; + auto call_opts = std::make_shared<CallOptions>(); + if (cm != nullptr) { + token = cm->get_cancellation_token(); + const bool already_cancelled = !cm->RegisterCallback( + token, + [call_opts, request, response, done]() { call_opts->StartCancel(); }); + if (already_cancelled) { + done(errors::Cancelled("EagerClusterFunctionLibraryRuntime::Run")); + return; + } + } + // Execute component function on remote worker using RunComponentFunction RPC. // Different from executing remote functions with Enqueue, this method runs // a function on remote worker without tying up a thread (i.e., pure // asynchronously). eager_client->RunComponentFunctionAsync( - request.get(), response.get(), - [request, response, rets, done = std::move(done)](const Status& s) { + call_opts.get(), request.get(), response.get(), + [request, response, rets, call_opts, cm, token, + done = std::move(done)](const Status& s) { + if (cm != nullptr) { + cm->TryDeregisterCallback(token); + } if (!s.ok()) { done(s); return; diff --git a/tensorflow/core/distributed_runtime/eager/eager_client.h b/tensorflow/core/distributed_runtime/eager/eager_client.h index 9ca802d8a72..d6cf0943176 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_client.h +++ b/tensorflow/core/distributed_runtime/eager/eager_client.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_CLIENT_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_CLIENT_H_ +#include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/env.h" @@ -38,12 +39,15 @@ class EagerClient : public core::RefCounted { CLIENT_METHOD(UpdateContext); CLIENT_METHOD(Enqueue); CLIENT_METHOD(WaitQueueDone); - CLIENT_METHOD(RunComponentFunction); CLIENT_METHOD(KeepAlive); CLIENT_METHOD(CloseContext); #undef CLIENT_METHOD + virtual void RunComponentFunctionAsync( + CallOptions* call_opts, const RunComponentFunctionRequest* request, + RunComponentFunctionResponse* response, StatusCallback done) = 0; + // Feeds `request` into the request stream of EagerService::StreamingEnqueue. // `response` will be filled with the response for this `request`. The // 1-to-1 correspondence between requests and responses is a property diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 6dc03cbc527..5327cbb6480 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -238,7 +238,7 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession( session_name, &worker_session)); - tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr(); + const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr(); // Initialize remote tensor communication based on worker session. TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); @@ -355,7 +355,7 @@ Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request, TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession( session_name, &worker_session)); - tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr(); + const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr(); std::vector<string> remote_workers; worker_session->worker_cache()->ListWorkers(&remote_workers); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 46a6181cfa9..3c537d99a3a 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -90,7 +90,8 @@ class FakeEagerClient : public EagerClient { CLIENT_METHOD(CloseContext); #undef CLIENT_METHOD - void RunComponentFunctionAsync(const RunComponentFunctionRequest* request, + void RunComponentFunctionAsync(CallOptions* call_opts, + const RunComponentFunctionRequest* request, RunComponentFunctionResponse* response, StatusCallback done) override { impl_->RunComponentFunction(request, response, std::move(done)); diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 8b363e66d87..fe353d7d76c 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -55,7 +55,7 @@ limitations under the License. namespace tensorflow { -GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr) +GraphMgr::GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr) : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) { // The default value of sync_on_finish will be flipped soon and this // environment variable will be removed as well. diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index 50190ab337e..e768c0907b6 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -69,7 +69,7 @@ class WorkerSession; // EXPECT_EQ(out["c"], Tensor({4, 6})); class GraphMgr { public: - explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr); + explicit GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr); ~GraphMgr(); // Registers a graph. Fills in "handle". The registered graph retains a @@ -145,7 +145,7 @@ class GraphMgr { }; const WorkerEnv* worker_env_; // Not owned. - DeviceMgr* device_mgr_; + const DeviceMgr* device_mgr_; CostModelManager cost_model_manager_; diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 96e1a63e5a6..60d7172c2fc 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -462,6 +462,8 @@ tf_cuda_cc_tests( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:test_utils", + "//tensorflow/core/platform:blocking_counter", ], ) diff --git a/tensorflow/core/distributed_runtime/rpc/eager/BUILD b/tensorflow/core/distributed_runtime/rpc/eager/BUILD index d7251029d10..ff362c3411f 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/eager/BUILD @@ -1,4 +1,5 @@ load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( default_visibility = [ @@ -29,6 +30,7 @@ cc_library( "//tensorflow/core:eager_service_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/distributed_runtime:call_options", "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", "//tensorflow/core/distributed_runtime/rpc:grpc_client_cq_tag", @@ -56,3 +58,21 @@ cc_library( tf_grpc_cc_dependency(), ], ) + +tf_cc_test( + name = "grpc_eager_client_test", + size = "small", + srcs = [ + "grpc_eager_client_test.cc", + ], + deps = [ + ":grpc_eager_client", + "//tensorflow/c:tf_status_headers", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/platform:blocking_counter", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:strcat", + ], +) diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc index 752bfdf71a1..4e3da8b00e0 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" #include "grpcpp/generic/generic_stub.h" +#include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h" @@ -135,7 +136,6 @@ class GrpcEagerClient : public EagerClient { CLIENT_METHOD(UpdateContext); CLIENT_METHOD(Enqueue); CLIENT_METHOD(WaitQueueDone); - CLIENT_METHOD(RunComponentFunction); CLIENT_METHOD(KeepAlive); #undef CLIENT_METHOD @@ -164,6 +164,18 @@ class GrpcEagerClient : public EagerClient { } } + void RunComponentFunctionAsync(CallOptions* call_opts, + const RunComponentFunctionRequest* request, + RunComponentFunctionResponse* response, + StatusCallback done) override { + StatusCallback done_wrapped = callback_wrapper(std::move(done)); + new RPCState<protobuf::Message>( + &stub_, cq_, "/tensorflow.eager.EagerService/RunComponentFunction", + *request, response, std::move(done_wrapped), call_opts, + /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true, + &target_); + } + void StreamingEnqueueAsync(const EnqueueRequest* request, EnqueueResponse* response, StatusCallback done) override { @@ -228,6 +240,7 @@ class GrpcEagerClientCache : public EagerClientCache { Status GetClient(const string& target, core::RefCountPtr<EagerClient>* client) override { + mutex_lock l(clients_mu_); auto it = clients_.find(target); if (it == clients_.end()) { tensorflow::SharedGrpcChannelPtr shared = @@ -269,7 +282,9 @@ class GrpcEagerClientCache : public EagerClientCache { } std::shared_ptr<tensorflow::GrpcChannelCache> cache_; - std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_; + mutable mutex clients_mu_; + std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_ + TF_GUARDED_BY(clients_mu_); std::vector<core::RefCountPtr<GrpcEagerClientThread>> threads_; }; diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client_test.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client_test.cc new file mode 100644 index 00000000000..a6da56eca13 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" + +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/blocking_counter.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace eager { + +TEST(GrpcEagerClientCache, TestGetClientThreadSafety) { + GrpcChannelSpec spec; + TF_ASSERT_OK(spec.AddHostPortsJob( + "worker", {"a:1", "b:2", "c:3", "d:4", "e:5", "f:6"})); + ChannelCreationFunction channel_func = + ConvertToChannelCreationFunction(NewHostPortGrpcChannel); + auto channel_cache = std::shared_ptr<GrpcChannelCache>( + NewGrpcChannelCache(spec, channel_func)); + std::unique_ptr<EagerClientCache> client_cache( + NewGrpcEagerClientCache(channel_cache)); + const int num_calls = 10; + BlockingCounter counter(num_calls); + + for (int i = 0; i < num_calls; i++) { + Env::Default()->SchedClosure([&client_cache, i, &counter]() { + string target = strings::StrCat("/job:worker/replica:0/task:", i); + core::RefCountPtr<EagerClient> eager_client; + Status s = client_cache->GetClient(target, &eager_client); + // With 6 tasks added to the job, querying client for 0--5 should be OK, + // and querying client for 6+ should give invalid argument error. + error::Code expected_code = i <= 5 ? error::OK : error::INVALID_ARGUMENT; + EXPECT_EQ(expected_code, s.code()); + counter.DecrementCount(); + }); + } + counter.Wait(); +} + +} // namespace eager +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index 85431acdf0c..6e706179863 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -45,7 +45,7 @@ class GrpcRemoteWorker : public WorkerInterface { explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel, ::grpc::CompletionQueue* completion_queue, thread::ThreadPool* callback_threadpool, - WorkerCacheLogger* logger) + WorkerCacheLogger* logger, const string& target) : channel_(std::move(channel)), stub_(channel_), cq_(completion_queue), @@ -66,7 +66,8 @@ class GrpcRemoteWorker : public WorkerInterface { instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)), getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)), markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)), - logger_(logger) {} + logger_(logger), + target_(target) {} ~GrpcRemoteWorker() override {} @@ -273,7 +274,7 @@ class GrpcRemoteWorker : public WorkerInterface { bool fail_fast = true) { new RPCState<protobuf::Message>( &stub_, cq_, method, *request, response, std::move(done), call_opts, - callback_threadpool_, /*max_retries=*/0, fail_fast); + callback_threadpool_, /*max_retries=*/0, fail_fast, &target_); } void IssueRequest(const protobuf::Message* request, TensorResponse* response, @@ -281,7 +282,8 @@ class GrpcRemoteWorker : public WorkerInterface { CallOptions* call_opts = nullptr) { new RPCState<TensorResponse>(&stub_, cq_, method, *request, response, std::move(done), call_opts, - callback_threadpool_); + callback_threadpool_, /*max_retries=*/0, + /*fail_fast=*/true, &target_); } void IssueMarkRecvFinishedRequest(int64 request_id) { @@ -321,6 +323,7 @@ class GrpcRemoteWorker : public WorkerInterface { // Support for logging. WorkerCacheLogger* logger_; + const string target_; TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker); }; @@ -328,9 +331,10 @@ class GrpcRemoteWorker : public WorkerInterface { WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel, ::grpc::CompletionQueue* completion_queue, thread::ThreadPool* callback_threadpool, - WorkerCacheLogger* logger) { + WorkerCacheLogger* logger, + const string& target) { return new GrpcRemoteWorker(std::move(channel), completion_queue, - callback_threadpool, logger); + callback_threadpool, logger, target); } } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h index c0a49ecfc38..97e590e0ad1 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h @@ -29,7 +29,8 @@ class WorkerInterface; WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel, ::grpc::CompletionQueue* completion_queue, thread::ThreadPool* callback_threadpool, - WorkerCacheLogger* logger); + WorkerCacheLogger* logger, + const string& target); } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 25aa5f3480c..6523d2fb4dd 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -130,9 +130,6 @@ GrpcServer::~GrpcServer() { // OpSegments.) if (worker_env_.session_mgr != nullptr) { delete worker_env_.session_mgr; // Deletes graph_mgr's. - } else { - // Note: session_mgr's legacy_session_ deletes device_mgr now. - delete worker_env_.device_mgr; } // Do not delete (as these are not owned by the server): @@ -143,9 +140,11 @@ GrpcServer::~GrpcServer() { void GrpcServer::MaybeMutateBuilder(::grpc::ServerBuilder* builder) {} -// Look up the port that has been requested for this task in `server_def`. -Status GrpcServer::GetPort(const ServerDef& server_def, int* port) const { +// Look up the requested host name and port for this task in `server_def`. +Status GrpcServer::GetHostAndPort(const ServerDef& server_def, + string* host_name, int* port) const { *port = -1; + *host_name = "localhost"; for (const auto& job : server_def.cluster().job()) { if (job.name() == server_def.job_name()) { auto iter = job.tasks().find(server_def.task_index()); @@ -165,6 +164,11 @@ Status GrpcServer::GetPort(const ServerDef& server_def, int* port) const { "Could not parse port for local server from \"", iter->second, "\"."); } + + if (colon_index != string::npos && + !iter->second.substr(0, colon_index).empty()) { + *host_name = iter->second.substr(0, colon_index); + } } break; } @@ -187,7 +191,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) { // otherwise if 'task_index=-1' the program will abort. int requested_port; - TF_RETURN_IF_ERROR(GetPort(server_def_, &requested_port)); + TF_RETURN_IF_ERROR(GetHostAndPort(server_def_, &host_name_, &requested_port)); SessionOptions sess_opts; ConfigProto config = server_def_.default_session_config(); @@ -197,12 +201,18 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) { string name_prefix = strings::StrCat("/job:", server_def_.job_name(), "/replica:0", "/task:", server_def_.task_index()); - std::vector<std::unique_ptr<Device>> devices; - TF_RETURN_IF_ERROR( - DeviceFactory::AddDevices(sess_opts, name_prefix, &devices)); - worker_env_.device_mgr = new StaticDeviceMgr(std::move(devices)); - master_env_.local_devices = worker_env_.device_mgr->ListDevices(); + if (opts.local_device_mgr == nullptr) { + std::vector<std::unique_ptr<Device>> devices; + TF_RETURN_IF_ERROR( + DeviceFactory::AddDevices(sess_opts, name_prefix, &devices)); + worker_env_.device_mgr = new StaticDeviceMgr(std::move(devices)); + owned_device_manager_.reset(worker_env_.device_mgr); + } else { + worker_env_.device_mgr = opts.local_device_mgr; + owned_device_manager_.reset(nullptr); + } worker_env_.local_devices = worker_env_.device_mgr->ListDevices(); + master_env_.local_devices = worker_env_.device_mgr->ListDevices(); worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr ? new RpcRendezvousMgr(&worker_env_) : opts.rendezvous_mgr_func(&worker_env_); @@ -347,7 +357,7 @@ Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options, task.second); } if (job.name() == *options.job_name && task.first == options.task_index) { - host_port = strings::StrCat("localhost:", bound_port_); + host_port = strings::StrCat(host_name_, ":", bound_port_); } else { host_port = task.second; } @@ -500,7 +510,7 @@ Status GrpcServer::Join() { } const string GrpcServer::target() const { - return strings::StrCat("grpc://localhost:", bound_port_); + return strings::StrCat("grpc://", host_name_, ":", bound_port_); } std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials( @@ -520,12 +530,13 @@ std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) { /* static */ Status GrpcServer::Create(const ServerDef& server_def, Env* env, + const DeviceMgr* local_device_mgr, std::unique_ptr<ServerInterface>* out_server) { std::unique_ptr<GrpcServer> ret( new GrpcServer(server_def, env == nullptr ? Env::Default() : env)); - ServiceInitFunction service_func = nullptr; GrpcServerOptions options; options.rendezvous_mgr_func = NewRpcRendezvousMgr; + options.local_device_mgr = local_device_mgr; Status s = ret->Init(options); if (!s.ok()) { LOG(ERROR) << s; @@ -535,19 +546,21 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env, return Status::OK(); } +/* static */ +Status GrpcServer::Create(const ServerDef& server_def, Env* env, + std::unique_ptr<ServerInterface>* out_server) { + return Create(server_def, env, nullptr, out_server); +} + /* static */ Status GrpcServer::Create(const ServerDef& server_def, Env* env, std::unique_ptr<GrpcServer>* out_server) { - std::unique_ptr<GrpcServer> ret( - new GrpcServer(server_def, env == nullptr ? Env::Default() : env)); - GrpcServerOptions options; - options.rendezvous_mgr_func = NewRpcRendezvousMgr; - Status s = ret->Init(options); + std::unique_ptr<ServerInterface> server; + Status s = Create(server_def, env, nullptr, &server); if (!s.ok()) { - LOG(ERROR) << s; return s; } - *out_server = std::move(ret); + out_server->reset(dynamic_cast<GrpcServer*>(server.release())); return Status::OK(); } @@ -559,9 +572,10 @@ class GrpcServerFactory : public ServerFactory { return server_def.protocol() == "grpc"; } - Status NewServer(const ServerDef& server_def, + Status NewServer(const ServerDef& server_def, const Options& options, std::unique_ptr<ServerInterface>* out_server) override { - return GrpcServer::Create(server_def, Env::Default(), out_server); + return GrpcServer::Create(server_def, Env::Default(), + options.local_device_mgr, out_server); } }; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index 8e25b8835eb..0474c5a517f 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -68,11 +68,14 @@ struct GrpcServerOptions { WorkerCreationFunction worker_func = nullptr; StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher; GrpcWorkerServiceOptions worker_service_options; + const DeviceMgr* local_device_mgr = nullptr; }; class GrpcServer : public ServerInterface { protected: GrpcServer(const ServerDef& server_def, Env* env); + GrpcServer(const ServerDef& server_def, DeviceMgr* local_device_mgr, + Env* env); // Allow children classes to override this and provide custom args to the // server before it is constructed. Default behavior is to do nothing. virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder); @@ -82,6 +85,10 @@ class GrpcServer : public ServerInterface { std::unique_ptr<ServerInterface>* out_server); static Status Create(const ServerDef& server_def, Env* env, std::unique_ptr<GrpcServer>* out_server); + // Reuse the local_device_mgr. + static Status Create(const ServerDef& server_def, Env* env, + const DeviceMgr* local_device_mgr, + std::unique_ptr<ServerInterface>* out_server); // Destruction is only supported in the factory method. Clean // shutdown is not currently implemented for this server type. @@ -104,7 +111,8 @@ class GrpcServer : public ServerInterface { Status UpdateServerDef(const ServerDef& server_def); protected: - virtual Status GetPort(const ServerDef& server_def, int* port) const; + virtual Status GetHostAndPort(const ServerDef& server_def, string* host_name, + int* port) const; Status Init(const GrpcServerOptions& opts = GrpcServerOptions()); // A subclass can override this method to support secure credentials. @@ -136,6 +144,9 @@ class GrpcServer : public ServerInterface { // The port to which this server is bound. int bound_port_ = 0; + // The host name of this server + string host_name_; + // Guards server configuration, server, and state. mutex mu_; @@ -159,6 +170,7 @@ class GrpcServer : public ServerInterface { // Implementation of a TensorFlow worker, and RPC polling thread. WorkerEnv worker_env_; + std::unique_ptr<const DeviceMgr> owned_device_manager_; std::unique_ptr<GrpcWorker> worker_impl_; AsyncServiceInterface* worker_service_ = nullptr; std::unique_ptr<Thread> worker_thread_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc index f6b6e15a2ba..1d75728ddd2 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc @@ -69,9 +69,9 @@ class GrpcWorkerCache : public WorkerCachePartial { return nullptr; } size_t index = AssignWorkerToThread(target); - return NewGrpcRemoteWorker(channel, - worker_env_->GetCompletionQueue(index), - worker_env_->GetThreadPool(), &logger_); + return NewGrpcRemoteWorker( + channel, worker_env_->GetCompletionQueue(index), + worker_env_->GetThreadPool(), &logger_, target); } } diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index 5bb61eb8cc1..512c17fcfcf 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -136,7 +137,12 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { // Start the main RecvTensor call, checking for an async abort. void StartRTCall(std::function<void()> recv_done) { resp_.InitAlloc(dst_device_, alloc_attrs_); - auto cb = [this, recv_done = std::move(recv_done)](const Status& s) { + auto abort_checked = std::make_shared<Notification>(); + auto cb = [this, abort_checked, + recv_done = std::move(recv_done)](const Status& s) { + // Make sure the Rendezvous abort checking is finished before running the + // callback, which might destroy the current call object. + abort_checked->WaitForNotification(); if (!s.ok()) { mutex_lock l(mu_); status_.Update(s); @@ -144,6 +150,22 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { recv_done(); }; wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb)); + + // NOTE: Check if the rendezvous was aborted after sending out the RPC. The + // ordering is important because `StartAbort` could be called right before + // the `RecvTensorAsync` request registers its RPC cancellation to `opts_`. + // In that case, the previous `StartAbort` would not trigger the + // cancellation of this call. + Status s; + { + mutex_lock l(mu_); + s = status_; + } + if (!s.ok()) { + opts_.StartCancel(); + } + // Notify that the abort check has finished. + abort_checked->Notify(); } string src_worker_; diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc index 85923542f73..7c5779246bd 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc @@ -16,13 +16,16 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/test_utils.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/control_flow.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/random.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -48,13 +51,34 @@ Rendezvous::ParsedKey MakeKey(const string& s) { } namespace { +// A dummy worker interface implementation that simply triggers the callback +// with OK status for RecvTensor request. +class DummyWorker : public TestWorkerInterface { + public: + void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, + TensorResponse* response, StatusCallback done) override { + SchedClosure([done = std::move(done)]() { + // Simulate a random delay for RPC. This is needed to fill the entire + // object buffer in `RpcRecvTensorFreeList` and trigger the destruction of + // RPC call objects. + const int64 t_us = random::New64() % 100 * 1000; + Env::Default()->SleepForMicroseconds(t_us); + done(Status::OK()); + }); + } +}; + // Fake cache implementation for WorkerEnv. class DummyWorkerCache : public WorkerCacheInterface { void ListWorkers(std::vector<string>* workers) const override {} void ListWorkersInJob(const string& job_name, std::vector<string>* workers) const override {} WorkerInterface* GetOrCreateWorker(const string& target) override { - return nullptr; + if (dummy_remote_worker_ == nullptr) { + // Ownership transferred to WorkerFreeList + dummy_remote_worker_ = new DummyWorker; + } + return dummy_remote_worker_; } Status GetEagerClientCache( std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override { @@ -66,7 +90,31 @@ class DummyWorkerCache : public WorkerCacheInterface { } void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, StatusCallback done) override {} + + private: + DummyWorker* dummy_remote_worker_ = nullptr; }; + +static Device* CreateDevice(const char* type, const char* name) { + class FakeDevice : public Device { + public: + explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} + Status Sync() override { return Status::OK(); } + Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } + }; + DeviceAttributes attr; + attr.set_name(name); + attr.set_device_type(type); + return new FakeDevice(attr); +} + +static DeviceMgr* CreateDeviceMgr() { + std::unique_ptr<Device> d0( + CreateDevice("CPU", "/job:mnist/replica:1/task:2/cpu:1")); + std::vector<std::unique_ptr<Device>> devices; + devices.emplace_back(std::move(d0)); + return new StaticDeviceMgr(std::move(devices)); +} } // namespace class RpcRendezvousMgrTest : public ::testing::Test { @@ -75,7 +123,7 @@ class RpcRendezvousMgrTest : public ::testing::Test { : cache_(new DummyWorkerCache), worker_session_("rpc_session", "/job:mnist/replica:1/task:2", std::unique_ptr<WorkerCacheInterface>(cache_), - std::unique_ptr<DeviceMgr>(), + std::unique_ptr<DeviceMgr>(CreateDeviceMgr()), std::unique_ptr<GraphMgr>(), nullptr), rmgr_(&env) { env.env = Env::Default(); @@ -193,6 +241,7 @@ TEST_F(RpcRendezvousMgrTest, CancelAfterReceived) { delete cm; } +namespace { class DummyDeviceContext : public DeviceContext { public: explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {} @@ -202,6 +251,7 @@ class DummyDeviceContext : public DeviceContext { private: const int stream_id_; }; +} // namespace TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) { DummyDeviceContext* dc = new DummyDeviceContext(123); @@ -237,6 +287,59 @@ TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) { dc->Unref(); } -// NOTE: Remote Send/Recv is better tested in worker_test.cc +TEST_F(RpcRendezvousMgrTest, RemoteRecvOne) { + const int64 step_id = 123; + const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( + "/job:worker/replica:1/task:2/cpu:0", 7890, + "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); + { + RemoteRendezvous* rendez = rmgr_.Find(step_id); + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); + core::ScopedUnref unref(rendez); + Rendezvous::Args args; + + Tensor val(DT_STRING); + bool val_dead = false; + + TF_ASSERT_OK(rendez->Recv(key, args, &val, &val_dead)); + } + rmgr_.Cleanup(step_id); +} + +TEST_F(RpcRendezvousMgrTest, RemoteRecvAsyncMany) { + const int64 step_id = 123; + const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( + "/job:worker/replica:1/task:2/cpu:0", 7890, + "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0))); + { + RemoteRendezvous* rendez = rmgr_.Find(step_id); + TF_ASSERT_OK(rendez->Initialize(&worker_session_)); + core::ScopedUnref unref(rendez); + Rendezvous::Args args; + + // Send a large number of async RPC requests to fill up the buffer in + // `RpcRecvTensorFreeList`, in order to test deleting RPC call objects. + int num_requests = 10000; + Tensor val(DT_STRING); + mutex mu_; + Status status = Status::OK(); + BlockingCounter counter(num_requests); + + for (int i = 0; i < num_requests; i++) { + rendez->RecvAsync( + key, args, + [&mu_, &status, &counter](const Status& s, const Rendezvous::Args&, + const Rendezvous::Args&, const Tensor&, + const bool) { + mutex_lock l(mu_); + status.Update(s); + counter.DecrementCount(); + }); + } + counter.Wait(); + TF_ASSERT_OK(status); + } + rmgr_.Cleanup(step_id); +} } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/server_lib.cc b/tensorflow/core/distributed_runtime/server_lib.cc index 62a2011db39..12baa75976a 100644 --- a/tensorflow/core/distributed_runtime/server_lib.cc +++ b/tensorflow/core/distributed_runtime/server_lib.cc @@ -73,7 +73,17 @@ Status NewServer(const ServerDef& server_def, std::unique_ptr<ServerInterface>* out_server) { ServerFactory* factory; TF_RETURN_IF_ERROR(ServerFactory::GetFactory(server_def, &factory)); - return factory->NewServer(server_def, out_server); + return factory->NewServer(server_def, ServerFactory::Options(), out_server); +} + +// Creates a server based on the given `server_def`, and stores it in +// `*out_server`. Returns OK on success, otherwise returns an error. +Status NewServerWithOptions(const ServerDef& server_def, + const ServerFactory::Options& options, + std::unique_ptr<ServerInterface>* out_server) { + ServerFactory* factory; + TF_RETURN_IF_ERROR(ServerFactory::GetFactory(server_def, &factory)); + return factory->NewServer(server_def, options, out_server); } } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/server_lib.h b/tensorflow/core/distributed_runtime/server_lib.h index 275f526d311..7b4b4892848 100644 --- a/tensorflow/core/distributed_runtime/server_lib.h +++ b/tensorflow/core/distributed_runtime/server_lib.h @@ -24,6 +24,8 @@ limitations under the License. namespace tensorflow { +class DeviceMgr; + // This library supports a registration/factory-based mechanism for // creating TensorFlow server objects. Each server implementation must // have an accompanying implementation of ServerFactory, and create a @@ -63,10 +65,14 @@ class ServerInterface { class ServerFactory { public: + struct Options { + // Local DeviceMgr to use. + const tensorflow::DeviceMgr* local_device_mgr; + }; // Creates a new server based on the given `server_def`, and stores // it in `*out_server`. Returns OK on success, otherwise returns an // error. - virtual Status NewServer(const ServerDef& server_def, + virtual Status NewServer(const ServerDef& server_def, const Options& options, std::unique_ptr<ServerInterface>* out_server) = 0; // Returns true if and only if this factory can create a server @@ -92,6 +98,9 @@ class ServerFactory { // `*out_server`. Returns OK on success, otherwise returns an error. Status NewServer(const ServerDef& server_def, std::unique_ptr<ServerInterface>* out_server); +Status NewServerWithOptions(const ServerDef& server_def, + const ServerFactory::Options& options, + std::unique_ptr<ServerInterface>* out_server); } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/server_lib_test.cc b/tensorflow/core/distributed_runtime/server_lib_test.cc index 77048c24b47..2152ff986d6 100644 --- a/tensorflow/core/distributed_runtime/server_lib_test.cc +++ b/tensorflow/core/distributed_runtime/server_lib_test.cc @@ -26,7 +26,7 @@ class TestServerFactory : public ServerFactory { return server_def.protocol() == "test_protocol"; } - Status NewServer(const ServerDef& server_def, + Status NewServer(const ServerDef& server_def, const Options& options, std::unique_ptr<ServerInterface>* out_server) override { return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc index e2151e068f6..1d9a22a5817 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.cc +++ b/tensorflow/core/distributed_runtime/session_mgr.cc @@ -171,7 +171,7 @@ Status SessionMgr::UpdateSession( std::vector<std::unique_ptr<Device>> cluster_devices; - DeviceMgr* local_device_mgr = worker_session->device_mgr(); + const DeviceMgr* local_device_mgr = worker_session->device_mgr(); DeviceMgr* remote_device_mgr = worker_session->remote_device_mgr(); std::vector<Device*> curr_remote_devices = remote_device_mgr->ListDevices(); std::vector<std::unique_ptr<Device>> added_remote_devices; diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h index a93c78e62fd..cec09775469 100644 --- a/tensorflow/core/distributed_runtime/test_utils.h +++ b/tensorflow/core/distributed_runtime/test_utils.h @@ -70,28 +70,28 @@ class TestWorkerInterface : public WorkerInterface { void CleanupGraphAsync(const CleanupGraphRequest* request, CleanupGraphResponse* response, StatusCallback done) override { - done(errors::Unimplemented("RunGraphAsync")); + done(errors::Unimplemented("CleanupGraphAsync")); } void CleanupAllAsync(const CleanupAllRequest* request, CleanupAllResponse* response, StatusCallback done) override { - done(errors::Unimplemented("RunGraphAsync")); + done(errors::Unimplemented("CleanupAllAsync")); } void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, TensorResponse* response, StatusCallback done) override { - done(errors::Unimplemented("RunGraphAsync")); + done(errors::Unimplemented("RecvTensorAsync")); } void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, StatusCallback done) override { - done(errors::Unimplemented("RunGraphAsync")); + done(errors::Unimplemented("LoggingAsync")); } void TracingAsync(const TracingRequest* request, TracingResponse* response, StatusCallback done) override { - done(errors::Unimplemented("RunGraphAsync")); + done(errors::Unimplemented("TracingAsync")); } void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, @@ -103,20 +103,20 @@ class TestWorkerInterface : public WorkerInterface { const CompleteGroupRequest* request, CompleteGroupResponse* response, StatusCallback done) override { - done(errors::Unimplemented("RunGraphAsync")); + done(errors::Unimplemented("CompleteGroupAsync")); } void CompleteInstanceAsync(CallOptions* ops, const CompleteInstanceRequest* request, CompleteInstanceResponse* response, StatusCallback done) override { - done(errors::Unimplemented("RunGraphAsync")); + done(errors::Unimplemented("CompleteInstanceAsync")); } void GetStepSequenceAsync(const GetStepSequenceRequest* request, GetStepSequenceResponse* response, StatusCallback done) override { - done(errors::Unimplemented("RunGraphAsync")); + done(errors::Unimplemented("GetStepSequenceAsync")); } }; diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 7850ecc46b2..f857a63e64d 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -38,7 +38,7 @@ Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) { void Worker::GetStatusAsync(const GetStatusRequest* request, GetStatusResponse* response, bool fail_fast, StatusCallback done) { - DeviceMgr* dm = env_->device_mgr; + const DeviceMgr* dm = env_->device_mgr; std::vector<DeviceAttributes> devices; dm->ListDeviceAttributes(&devices); response->mutable_device_attributes()->Reserve(devices.size()); diff --git a/tensorflow/core/distributed_runtime/worker_env.h b/tensorflow/core/distributed_runtime/worker_env.h index 93d933bfa60..ecc3313d0ce 100644 --- a/tensorflow/core/distributed_runtime/worker_env.h +++ b/tensorflow/core/distributed_runtime/worker_env.h @@ -53,7 +53,7 @@ struct WorkerEnv { // Note: Please use the device_mgr associated with your session if appropriate // instead of this one. Using this device_mgr does not support ClusterSpec // propagated sessions. - DeviceMgr* device_mgr = nullptr; + const DeviceMgr* device_mgr = nullptr; // A set of rendezvous keyed by step ids. RendezvousMgrInterface* rendezvous_mgr = nullptr; diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc index ca4f25f08f5..3aed73fa358 100644 --- a/tensorflow/core/distributed_runtime/worker_session.cc +++ b/tensorflow/core/distributed_runtime/worker_session.cc @@ -144,7 +144,7 @@ Status WorkerSession::UpdateWorkerCacheAndDevices( std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr( const string& session_name, const string& worker_name, std::unique_ptr<WorkerCacheInterface> worker_cache, - DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr, + const DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr, std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) { return std::shared_ptr<WorkerSession>(new WorkerSession( session_name, worker_name, std::move(worker_cache), borrowed_device_mgr, @@ -154,7 +154,7 @@ std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr( WorkerSession::WorkerSession( const string& session_name, const string& worker_name, std::unique_ptr<WorkerCacheInterface> worker_cache, - DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr, + const DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr, std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) : session_name_(session_name), worker_name_(worker_name), diff --git a/tensorflow/core/distributed_runtime/worker_session.h b/tensorflow/core/distributed_runtime/worker_session.h index 3b2d1122558..f870a8c064b 100644 --- a/tensorflow/core/distributed_runtime/worker_session.h +++ b/tensorflow/core/distributed_runtime/worker_session.h @@ -37,7 +37,7 @@ class WorkerSession { // sessions created with `isolate_session_state == false`. In the // those cases, this method returns a pointer to a borrowed // DeviceMgr (typically the `worker_env.device_mgr`). - DeviceMgr* device_mgr() { + const DeviceMgr* device_mgr() { return device_mgr_ ? device_mgr_.get() : borrowed_device_mgr_; } @@ -65,7 +65,7 @@ class WorkerSession { static std::shared_ptr<WorkerSession> CreateWithBorrowedDeviceMgr( const string& session_name, const string& worker_name, std::unique_ptr<WorkerCacheInterface> worker_cache, - DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr, + const DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr, std::unique_ptr<DynamicDeviceMgr> remote_device_mgr); // In the eager runtime we allow WorkerSession to be updated, where the @@ -90,7 +90,7 @@ class WorkerSession { private: WorkerSession(const string& session_name, const string& worker_name, std::unique_ptr<WorkerCacheInterface> worker_cache, - DeviceMgr* borrowed_device_mgr, + const DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr, std::unique_ptr<DynamicDeviceMgr> remote_device_mgr); @@ -113,8 +113,8 @@ class WorkerSession { std::unique_ptr<ClusterFunctionLibraryRuntime> cluster_flr_; - const std::unique_ptr<DeviceMgr> device_mgr_; - DeviceMgr* const borrowed_device_mgr_; // Not owned. + const std::unique_ptr<const DeviceMgr> device_mgr_; + const DeviceMgr* const borrowed_device_mgr_; // Not owned. std::unique_ptr<DynamicDeviceMgr> remote_device_mgr_; }; diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 113adbdd432..216002ad8e7 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -468,6 +468,25 @@ Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format, return Status::OK(); } +Status DatasetIteratorShape(shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + std::vector<PartialTensorShape> output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as `output_types` (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast<int>(i), output_shape_handle); + } + return Status::OK(); +} + Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, const std::vector<DimensionOrConstant>& spatial, DimensionOrConstant C, ShapeHandle* out, diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index e1984abab7e..218400c2435 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -92,6 +92,9 @@ inline Status MergeBothInputsShapeFn(InferenceContext* c) { return Status::OK(); } +// Shape function for dataset iterators. +Status DatasetIteratorShape(shape_inference::InferenceContext* c); + // Returns a new shape with the specified dims arranged in the specified // format. The returned value is owned by this context. // Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth. diff --git a/tensorflow/core/framework/load_library.cc b/tensorflow/core/framework/load_library.cc index b9e33b148f7..c223eac4722 100644 --- a/tensorflow/core/framework/load_library.cc +++ b/tensorflow/core/framework/load_library.cc @@ -21,6 +21,9 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mem.h" +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/tpu/tpu_library_loader.h" +#endif // IS_MOBILE_PLATFORM namespace tensorflow { @@ -97,6 +100,17 @@ Status LoadLibrary(const char* library_filename, void** result, *buf = str_buf; *len = str.length(); +#if !defined(IS_MOBILE_PLATFORM) + // Determine if this library is a TPU library, and if so, calls the TPU + // initialization functions to populate function tables, etc... + void* unused_symbol; + if (env->GetSymbolFromLibrary(library.handle, "TfTpu_Initialize", + &unused_symbol) + .ok()) { + TF_RETURN_IF_ERROR(tensorflow::tpu::InitializeTPULibrary(library.handle)); + } +#endif // IS_MOBILE_PLATFORM + *result = library.handle; return Status::OK(); } diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index b4a54029a4f..658be94b9bb 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -25,10 +25,6 @@ namespace data { namespace model { namespace { -// Key of the derivative w.r.t. the last input time in the gradient of -// `OutputTime`. -constexpr char kInputTimeDerivativeKey[] = "last_input_time"; - // Wrapper for the square function to reduce verbosity. inline double Square(double x) { return x * x; } @@ -50,34 +46,60 @@ class InterleaveMany : public Node { Args{id_, name_, std::move(output)}); } + void InputTimeLocked(absl::flat_hash_map<string, double>* input_times) + const override TF_SHARED_LOCKS_REQUIRED(mu_) { + double old_input_time; + if (output_) { + old_input_time = (*input_times)[output_->long_name()]; + } else { + old_input_time = gtl::FindWithDefault(*input_times, kInputTimeKey, 0.0L); + } + + if (num_inputs() <= 1) { + (*input_times)[long_name()] = old_input_time; + return; + } + double new_input_time = + old_input_time + + SelfProcessingTimeLocked() * static_cast<double>(num_inputs() - 1); + (*input_times)[long_name()] = new_input_time; + } + // The output time is the sum of the self processing time and the average // output time of inputs comprising the interleave "cycle". - double OutputTimeLocked(std::vector<double>* input_times, - absl::flat_hash_map<string, double>* gradient) - const override TF_SHARED_LOCKS_REQUIRED(mu_) { + void OutputTimeLocked( + const absl::flat_hash_map<string, double>& input_times, + absl::flat_hash_map<string, double>* gradients, + absl::flat_hash_map<string, double>* output_times, + absl::flat_hash_map<string, double>* output_time_gradients) const override + TF_SHARED_LOCKS_REQUIRED(mu_) { + double self_processing_time = SelfProcessingTimeLocked(); if (num_inputs() <= 1) { - return SelfProcessingTimeLocked(); - } - double delta = SelfProcessingTimeLocked() * (num_inputs() - 1); - input_times->back() += delta; - auto cleanup = gtl::MakeCleanup( - [input_times, delta]() { input_times->back() -= delta; }); - double output_time; - if (gradient) { - absl::flat_hash_map<string, double> inputs_gradient; - output_time = - (OutputTimeForInputs(input_times, &inputs_gradient) - - inputs_.front()->OutputTime(input_times, /*gradient=*/nullptr)) / - static_cast<double>(num_inputs() - 1); - for (auto& pair : inputs_gradient) { - (*gradient)[pair.first] = - pair.second / static_cast<double>(num_inputs() - 1); + (*output_times)[long_name()] = self_processing_time; + if (gradients) { + for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) { + gradients->erase(node->long_name()); + } } - auto last_input_time_der = - gtl::FindWithDefault(*gradient, kInputTimeDerivativeKey, 0.0L); - (*gradient)[kInputTimeDerivativeKey] = - last_input_time_der + inputs_gradient[kInputTimeDerivativeKey] / - static_cast<double>(num_inputs() - 1); + return; + } + + double output_time = (OutputTimeForInputs(*output_times) - + (*output_times)[inputs_.front()->long_name()]) / + static_cast<double>(num_inputs() - 1); + if (gradients) { + for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) { + auto* gradient = gtl::FindOrNull(*gradients, node->long_name()); + if (gradient) { + *gradient /= static_cast<double>(num_inputs() - 1); + } + } + + (*output_time_gradients)[long_name()] = + (OutputTimeGradientsForInputs(*output_time_gradients) - + (*output_time_gradients)[inputs_.front()->long_name()]) / + static_cast<double>(num_inputs() - 1); + // Set derivatives w.r.t. tunable parameters of the subtree rooted in the // first input equal to 0 since its output time is excluded from // computations. @@ -85,15 +107,10 @@ class InterleaveMany : public Node { first_input_parameters; inputs_.front()->CollectTunableParameters(&first_input_parameters); for (auto& pair : first_input_parameters) { - (*gradient)[pair.first] = 0.0L; + (*gradients)[pair.first] = 0.0L; } - } else { - output_time = - (OutputTimeForInputs(input_times, /*gradient=*/nullptr) - - inputs_.front()->OutputTime(input_times, /*gradient=*/nullptr)) / - static_cast<double>(num_inputs() - 1); } - return SelfProcessingTimeLocked() + output_time; + (*output_times)[long_name()] = self_processing_time + output_time; } // The processing time is the sum of the self processing time and the average @@ -107,16 +124,15 @@ class InterleaveMany : public Node { (*processing_times)[long_name()] = self_processing_time; } if (num_inputs() <= 1) { - total_processing_times->insert( - std::make_pair(long_name(), self_processing_time)); + (*total_processing_times)[long_name()] = self_processing_time; return; } double processing_time = (TotalProcessingTimeForInputs(*total_processing_times) - (*total_processing_times)[inputs_.front()->long_name()]) / static_cast<double>(num_inputs() - 1); - total_processing_times->insert( - std::make_pair(long_name(), self_processing_time + processing_time)); + (*total_processing_times)[long_name()] = + self_processing_time + processing_time; } }; @@ -148,55 +164,85 @@ class AsyncInterleaveMany : public Node { Args{id_, name_, std::move(output)}, parameters); } + void InputTimeLocked(absl::flat_hash_map<string, double>* input_times) + const override TF_SHARED_LOCKS_REQUIRED(mu_) { + double input_time; + + if (num_inputs() <= 1) { + if (output_) { + input_time = (*input_times)[output_->long_name()]; + } else { + input_time = gtl::FindWithDefault(*input_times, kInputTimeKey, 0.0L); + } + } else { + input_time = + SelfProcessingTimeLocked() * static_cast<double>(num_inputs() - 1); + } + (*input_times)[long_name()] = input_time; + } + // The output time is estimated using `ComputeWaitTime(output_time, // input_time, parallelism, ...)`, where `output_time` is the sum of the // self-processing time and the average output time of inputs comprising the // interleave "cycle", `input_time` is specified through `input_times` and // `buffer_size` is derived from parallelism. - double OutputTimeLocked(std::vector<double>* input_times, - absl::flat_hash_map<string, double>* gradient) - const override TF_SHARED_LOCKS_REQUIRED(mu_) { + void OutputTimeLocked( + const absl::flat_hash_map<string, double>& input_times, + absl::flat_hash_map<string, double>* gradients, + absl::flat_hash_map<string, double>* output_times, + absl::flat_hash_map<string, double>* output_time_gradients) const override + TF_SHARED_LOCKS_REQUIRED(mu_) { + double self_processing_time = SelfProcessingTimeLocked(); if (num_inputs() <= 1) { - return SelfProcessingTimeLocked(); + (*output_times)[long_name()] = self_processing_time; + if (gradients) { + for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) { + gradients->erase(node->long_name()); + } + } + return; } - double old_input_time = input_times->back(); - double new_input_time = - SelfProcessingTimeLocked() * static_cast<double>(num_inputs() - 1); - input_times->push_back(new_input_time); - auto cleanup = - gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + + double input_time; + if (output_) { + input_time = input_times.at(output_->long_name()); + } else { + input_time = gtl::FindWithDefault(input_times, kInputTimeKey, 0.0L); + } + double parallelism = num_inputs() - 1; // default to cycle length auto* parameter = gtl::FindOrNull(parameters_, kParallelism); if (parameter) { parallelism = std::min(parallelism, (*parameter)->value); } - if (gradient) { - absl::flat_hash_map<string, double> inputs_gradient; - double output_time_for_inputs = - OutputTimeForInputs(input_times, &inputs_gradient) - - inputs_.front()->OutputTime(input_times, /*gradient=*/nullptr); - double output_time = output_time_for_inputs / - static_cast<double>(num_inputs() - 1) / parallelism; + + double output_time_for_inputs = + OutputTimeForInputs(*output_times) - + (*output_times)[inputs_.front()->long_name()]; + double output_time = output_time_for_inputs / + static_cast<double>(num_inputs() - 1) / parallelism; + double result; + + if (gradients) { double output_time_der = 0.0L; double input_time_der = 0.0L; double buffer_size_der = 0.0L; - double result = ComputeWaitTime( - SelfProcessingTimeLocked() + output_time, old_input_time, parallelism, - &output_time_der, &input_time_der, &buffer_size_der); - auto last_input_time_der = - gtl::FindWithDefault(*gradient, kInputTimeDerivativeKey, 0.0L); - (*gradient)[kInputTimeDerivativeKey] = - last_input_time_der + input_time_der; + result = ComputeWaitTime(self_processing_time + output_time, input_time, + parallelism, &output_time_der, &input_time_der, + &buffer_size_der); + (*output_time_gradients)[long_name()] = input_time_der; double parallelism_der = -output_time_for_inputs / static_cast<double>(num_inputs() - 1) / Square(parallelism); - for (auto& pair : inputs_gradient) { - if (pair.first != kInputTimeDerivativeKey) { - (*gradient)[pair.first] = output_time_der * pair.second / - static_cast<double>(num_inputs() - 1) / - parallelism; + + for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) { + auto* gradient = gtl::FindOrNull(*gradients, node->long_name()); + if (gradient) { + *gradient *= (output_time_der / + static_cast<double>(num_inputs() - 1) / parallelism); } } + // Set derivatives w.r.t. tunable parameters of the subtree rooted in the // first input equal to 0 since its output time is excluded from // computations. @@ -204,23 +250,21 @@ class AsyncInterleaveMany : public Node { first_input_parameters; inputs_.front()->CollectTunableParameters(&first_input_parameters); for (auto& pair : first_input_parameters) { - (*gradient)[pair.first] = 0.0L; + (*gradients)[pair.first] = 0.0L; } // Add derivative w.r.t. own parallelism parameter. if (parameter && (*parameter)->state->tunable) { - (*gradient)[long_name()] = + (*gradients)[long_name()] = output_time_der * parallelism_der + buffer_size_der; } - return result; + } else { + result = ComputeWaitTime(self_processing_time + output_time, input_time, + parallelism, + /*output_time_derivative=*/nullptr, + /*input_time_derivative=*/nullptr, + /*buffer_size_derivative=*/nullptr); } - double output_time = - (OutputTimeForInputs(input_times, /*gradient=*/nullptr) - - inputs_.front()->OutputTime(input_times, /*gradient=*/nullptr)) / - static_cast<double>(num_inputs() - 1) / parallelism; - return ComputeWaitTime( - SelfProcessingTimeLocked() + output_time, old_input_time, parallelism, - /*output_time_derivative=*/nullptr, - /*input_time_derivative=*/nullptr, /*buffer_size_derivative=*/nullptr); + (*output_times)[long_name()] = result; } // The processing time is the sum of the self processing time and the average @@ -234,16 +278,15 @@ class AsyncInterleaveMany : public Node { (*processing_times)[long_name()] = self_processing_time; } if (num_inputs() <= 1) { - total_processing_times->insert( - std::make_pair(long_name(), self_processing_time)); + (*total_processing_times)[long_name()] = self_processing_time; return; } double processing_time = (TotalProcessingTimeForInputs(*total_processing_times) - (*total_processing_times)[inputs_.front()->long_name()]) / static_cast<double>(num_inputs() - 1); - total_processing_times->insert( - std::make_pair(long_name(), self_processing_time + processing_time)); + (*total_processing_times)[long_name()] = + self_processing_time + processing_time; } }; @@ -260,41 +303,55 @@ class KnownRatio : public Node { ratio_); } + void InputTimeLocked(absl::flat_hash_map<string, double>* input_times) + const override TF_SHARED_LOCKS_REQUIRED(mu_) { + double old_input_time; + if (output_) { + old_input_time = (*input_times)[output_->long_name()]; + } else { + old_input_time = gtl::FindWithDefault(*input_times, kInputTimeKey, 0.0L); + } + + if (ratio_ == 0) { + (*input_times)[long_name()] = old_input_time; + return; + } + double new_input_time = + (old_input_time + SelfProcessingTimeLocked()) / ratio_; + (*input_times)[long_name()] = new_input_time; + } + // The output time is the sum of the self processing time and the product of // `ratio_` and the sum of output times of inputs. - double OutputTimeLocked(std::vector<double>* input_times, - absl::flat_hash_map<string, double>* gradient) - const override TF_SHARED_LOCKS_REQUIRED(mu_) { + void OutputTimeLocked( + const absl::flat_hash_map<string, double>& input_times, + absl::flat_hash_map<string, double>* gradients, + absl::flat_hash_map<string, double>* output_times, + absl::flat_hash_map<string, double>* output_time_gradients) const override + TF_SHARED_LOCKS_REQUIRED(mu_) { + double self_processing_time = SelfProcessingTimeLocked(); if (ratio_ == 0) { - return SelfProcessingTimeLocked(); - } - double old_input_time = input_times->back(); - input_times->back() = - (old_input_time + SelfProcessingTimeLocked()) / ratio_; - auto cleanup = gtl::MakeCleanup([input_times, old_input_time]() { - input_times->back() = old_input_time; - }); - double result; - if (gradient) { - absl::flat_hash_map<string, double> inputs_gradient; - result = SelfProcessingTimeLocked() + - ratio_ * OutputTimeForInputs(input_times, &inputs_gradient); - auto last_input_time_der = - gtl::FindWithDefault(*gradient, kInputTimeDerivativeKey, 0.0L); - (*gradient)[kInputTimeDerivativeKey] = - last_input_time_der + ratio_ * - inputs_gradient[kInputTimeDerivativeKey] * - (1.0L + 1.0L / ratio_); - for (auto& pair : inputs_gradient) { - if (pair.first != kInputTimeDerivativeKey) { - (*gradient)[pair.first] = pair.second * ratio_; + (*output_times)[long_name()] = self_processing_time; + if (gradients) { + for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) { + gradients->erase(node->long_name()); } } - } else { - result = SelfProcessingTimeLocked() + - ratio_ * OutputTimeForInputs(input_times, /*gradient=*/nullptr); + return; } - return result; + double result = + self_processing_time + ratio_ * OutputTimeForInputs(*output_times); + if (gradients) { + for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) { + auto* gradient = gtl::FindOrNull(*gradients, node->long_name()); + if (gradient) { + *gradient *= ratio_; + } + } + (*output_time_gradients)[long_name()] = + OutputTimeGradientsForInputs(*output_time_gradients); + } + (*output_times)[long_name()] = result; } // The processing time is the sum of the self processing time and the product @@ -309,8 +366,8 @@ class KnownRatio : public Node { } double processing_time = ratio_ * TotalProcessingTimeForInputs(*total_processing_times); - total_processing_times->insert( - std::make_pair(long_name(), self_processing_time + processing_time)); + (*total_processing_times)[long_name()] = + self_processing_time + processing_time; } private: @@ -340,6 +397,29 @@ class AsyncKnownRatio : public Node { Args{id_, name_, std::move(output)}, ratio_, parameters); } + void InputTimeLocked(absl::flat_hash_map<string, double>* input_times) + const override TF_SHARED_LOCKS_REQUIRED(mu_) { + double input_time; + + if (ratio_ == 0.0) { + if (output_) { + input_time = (*input_times)[output_->long_name()]; + } else { + input_time = gtl::FindWithDefault(*input_times, kInputTimeKey, 0.0L); + } + (*input_times)[long_name()] = input_time; + return; + } + + double parallelism = 1.0; + auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism); + if (parallelism_parameter) { + parallelism = (*parallelism_parameter)->value; + } + input_time = SelfProcessingTimeLocked() / ratio_ / parallelism; + (*input_times)[long_name()] = input_time; + } + // The output time is estimated using `ComputeWaitTime(output_time, // input_time, parallelism, ...)`, where `output_time` is the sum of the self // processing time and the product of `ratio_` and the sum of output times of @@ -347,9 +427,12 @@ class AsyncKnownRatio : public Node { // has parallelism parameter, then `buffer_size` is derived from parallelism. // // Current implementation assumes that there is at most 1 parameter per node. - double OutputTimeLocked(std::vector<double>* input_times, - absl::flat_hash_map<string, double>* gradient) - const override TF_SHARED_LOCKS_REQUIRED(mu_) { + void OutputTimeLocked( + const absl::flat_hash_map<string, double>& input_times, + absl::flat_hash_map<string, double>* gradients, + absl::flat_hash_map<string, double>* output_times, + absl::flat_hash_map<string, double>* output_time_gradients) const override + TF_SHARED_LOCKS_REQUIRED(mu_) { double parallelism = 1.0; double buffer_size = 0.0; auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism); @@ -361,80 +444,85 @@ class AsyncKnownRatio : public Node { buffer_size = (*buffer_size_parameter)->value; } double self_processing_time = SelfProcessingTimeLocked(); + double result; + double input_time; + if (output_) { + input_time = input_times.at(output_->long_name()); + } else { + input_time = gtl::FindWithDefault(input_times, kInputTimeKey, 0.0L); + } + if (ratio_ == 0.0) { double output_time = self_processing_time / parallelism; - if (gradient) { + if (gradients) { + for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) { + gradients->erase(node->long_name()); + } + double output_time_der = 0.0L; double input_time_der = 0.0L; double buffer_size_der = 0.0L; - double result = ComputeWaitTime(output_time, input_times->back(), - buffer_size, &output_time_der, - &input_time_der, &buffer_size_der); - auto last_input_time_der = - gtl::FindWithDefault(*gradient, kInputTimeDerivativeKey, 0.0L); - (*gradient)[kInputTimeDerivativeKey] = - last_input_time_der + input_time_der; + result = ComputeWaitTime(output_time, input_time, buffer_size, + &output_time_der, &input_time_der, + &buffer_size_der); + (*output_time_gradients)[long_name()] = input_time_der; // Add derivative w.r.t. own parameter if it's tunable. if (parallelism_parameter && (*parallelism_parameter)->state->tunable) { - (*gradient)[long_name()] = + (*gradients)[long_name()] = -output_time_der * self_processing_time / Square(parallelism) + buffer_size_der; } else if (buffer_size_parameter && (*buffer_size_parameter)->state->tunable) { - (*gradient)[long_name()] = buffer_size_der; + (*gradients)[long_name()] = buffer_size_der; } - return result; + } else { + result = ComputeWaitTime(output_time, input_time, buffer_size, + /*output_time_derivative=*/nullptr, + /*input_time_derivative=*/nullptr, + /*buffer_size_derivative=*/nullptr); } - return ComputeWaitTime(output_time, input_times->back(), buffer_size, - /*output_time_derivative=*/nullptr, - /*input_time_derivative=*/nullptr, - /*buffer_size_derivative=*/nullptr); + (*output_times)[long_name()] = result; + return; } - double old_input_time = input_times->back(); - double new_input_time = self_processing_time / ratio_ / parallelism; - input_times->push_back(new_input_time); - auto cleanup = - gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); - if (gradient) { - absl::flat_hash_map<string, double> inputs_gradient; + + double output_time = self_processing_time / parallelism + + ratio_ * OutputTimeForInputs(*output_times); + if (gradients) { double output_time_der = 0.0L; double input_time_der = 0.0L; double buffer_size_der = 0.0L; - double output_time = - self_processing_time / parallelism + - ratio_ * OutputTimeForInputs(input_times, &inputs_gradient); - double result = - ComputeWaitTime(output_time, old_input_time, buffer_size, + result = + ComputeWaitTime(output_time, input_time, buffer_size, &output_time_der, &input_time_der, &buffer_size_der); - auto last_input_time_der = - gtl::FindWithDefault(*gradient, kInputTimeDerivativeKey, 0.0L); - (*gradient)[kInputTimeDerivativeKey] = - last_input_time_der + input_time_der; - for (auto& pair : inputs_gradient) { - if (pair.first != kInputTimeDerivativeKey) { - (*gradient)[pair.first] = pair.second * ratio_ * output_time_der; + (*output_time_gradients)[long_name()] = input_time_der; + + for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) { + auto* gradient = gtl::FindOrNull(*gradients, node->long_name()); + if (gradient) { + *gradient *= (ratio_ * output_time_der); } } + // Add derivative w.r.t. own parameter if it's tunable. if (parallelism_parameter && (*parallelism_parameter)->state->tunable) { - (*gradient)[long_name()] = + double inputs_time_der_sum = + OutputTimeGradientsForInputs(*output_time_gradients); + (*gradients)[long_name()] = -output_time_der * self_processing_time / Square(parallelism) + buffer_size_der - - output_time_der * inputs_gradient[kInputTimeDerivativeKey] * - self_processing_time / Square(parallelism); + output_time_der * inputs_time_der_sum * self_processing_time / + Square(parallelism); } else if (buffer_size_parameter && (*buffer_size_parameter)->state->tunable) { - (*gradient)[long_name()] = buffer_size_der; + (*gradients)[long_name()] = buffer_size_der; } - return result; + } else { + result = ComputeWaitTime(output_time, input_time, buffer_size, + /*output_time_derivative=*/nullptr, + /*input_time_derivative=*/nullptr, + /*buffer_size_derivative=*/nullptr); } - double output_time = - self_processing_time / parallelism + - ratio_ * OutputTimeForInputs(input_times, /*gradient=*/nullptr); - return ComputeWaitTime(output_time, old_input_time, buffer_size, - /*output_time_derivative=*/nullptr, - /*input_time_derivative=*/nullptr, - /*buffer_size_derivative=*/nullptr); + (*output_times)[long_name()] = result; } // The processing time is the sum of the self processing time and the product @@ -449,8 +537,8 @@ class AsyncKnownRatio : public Node { } double processing_time = ratio_ * TotalProcessingTimeForInputs(*total_processing_times); - total_processing_times->insert( - std::make_pair(long_name(), self_processing_time + processing_time)); + (*total_processing_times)[long_name()] = + self_processing_time + processing_time; } private: @@ -469,44 +557,64 @@ class UnknownRatio : public Node { return std::make_shared<UnknownRatio>(Args{id_, name_, std::move(output)}); } - // The output time is the sum of the self processing time and the product of - // the ratio estimate and the sum of output times of inputs. - double OutputTimeLocked(std::vector<double>* input_times, - absl::flat_hash_map<string, double>* gradient) + void InputTimeLocked(absl::flat_hash_map<string, double>* input_times) const override TF_SHARED_LOCKS_REQUIRED(mu_) { + double old_input_time; + if (output_) { + old_input_time = (*input_times)[output_->long_name()]; + } else { + old_input_time = gtl::FindWithDefault(*input_times, kInputTimeKey, 0.0L); + } + if (num_elements_ == 0 || inputs_.empty() || inputs_.front()->num_elements() == 0) { - return SelfProcessingTimeLocked(); + (*input_times)[long_name()] = old_input_time; + return; } - // TODO(jsimsa): The current implementation assumes that the number of input - // elements consumed per output is the same across all inputs. std::shared_ptr<Node> input = inputs_.front(); double ratio = static_cast<double>(input->num_elements()) / static_cast<double>(num_elements_); - double old_input_time = input_times->back(); - input_times->back() = (old_input_time + SelfProcessingTimeLocked()) / ratio; - auto cleanup = gtl::MakeCleanup([input_times, old_input_time]() { - input_times->back() = old_input_time; - }); - if (gradient) { - absl::flat_hash_map<string, double> inputs_gradient; - double result = - SelfProcessingTimeLocked() + - ratio * OutputTimeForInputs(input_times, &inputs_gradient); - auto last_input_time_der = - gtl::FindWithDefault(*gradient, kInputTimeDerivativeKey, 0.0L); - (*gradient)[kInputTimeDerivativeKey] = - last_input_time_der + - inputs_gradient[kInputTimeDerivativeKey] / ratio; - for (auto& pair : inputs_gradient) { - if (pair.first != kInputTimeDerivativeKey) { - (*gradient)[pair.first] = pair.second * ratio; + double new_input_time = + (old_input_time + SelfProcessingTimeLocked()) / ratio; + (*input_times)[long_name()] = new_input_time; + } + + // The output time is the sum of the self processing time and the product of + // the ratio estimate and the sum of output times of inputs. + void OutputTimeLocked( + const absl::flat_hash_map<string, double>& input_times, + absl::flat_hash_map<string, double>* gradients, + absl::flat_hash_map<string, double>* output_times, + absl::flat_hash_map<string, double>* output_time_gradients) const override + TF_SHARED_LOCKS_REQUIRED(mu_) { + double self_processing_time = SelfProcessingTimeLocked(); + if (num_elements_ == 0 || inputs_.empty() || + inputs_.front()->num_elements() == 0) { + (*output_times)[long_name()] = self_processing_time; + if (gradients) { + for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) { + gradients->erase(node->long_name()); } } - return result; + return; } - return SelfProcessingTimeLocked() + - ratio * OutputTimeForInputs(input_times, /*gradient=*/nullptr); + // TODO(jsimsa): The current implementation assumes that the number of input + // elements consumed per output is the same across all inputs. + double ratio = static_cast<double>(inputs_.front()->num_elements()) / + static_cast<double>(num_elements_); + double result = + self_processing_time + ratio * OutputTimeForInputs(*output_times); + if (gradients) { + for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) { + auto* gradient = gtl::FindOrNull(*gradients, node->long_name()); + if (gradient) { + *gradient *= ratio; + } + } + (*output_time_gradients)[long_name()] = + OutputTimeGradientsForInputs(*output_time_gradients); + } + (*output_times)[long_name()] = result; } // The processing time is the sum of the self processing time and the product @@ -520,8 +628,7 @@ class UnknownRatio : public Node { (*processing_times)[long_name()] = self_processing_time; } if (inputs_.empty() || num_elements_ == 0) { - total_processing_times->insert( - std::make_pair(long_name(), self_processing_time)); + (*total_processing_times)[long_name()] = self_processing_time; return; } // TODO(jsimsa): The current implementation assumes that the number of input @@ -531,8 +638,8 @@ class UnknownRatio : public Node { static_cast<double>(num_elements_); double processing_time = ratio * TotalProcessingTimeForInputs(*total_processing_times); - total_processing_times->insert( - std::make_pair(long_name(), self_processing_time + processing_time)); + (*total_processing_times)[long_name()] = + self_processing_time + processing_time; } }; @@ -548,11 +655,30 @@ class Unknown : public Node { return std::make_shared<Unknown>(Args{id_, name_, std::move(output)}); } - // The output time is the sum of output times of inputs. - double OutputTimeLocked(std::vector<double>* input_times, - absl::flat_hash_map<string, double>* gradient) + void InputTimeLocked(absl::flat_hash_map<string, double>* input_times) const override TF_SHARED_LOCKS_REQUIRED(mu_) { - return OutputTimeForInputs(input_times, gradient); + double input_time; + if (output_) { + input_time = (*input_times)[output_->long_name()]; + } else { + input_time = gtl::FindWithDefault(*input_times, kInputTimeKey, 0.0L); + } + (*input_times)[long_name()] = input_time; + } + + // The output time is the sum of output times of inputs. + void OutputTimeLocked( + const absl::flat_hash_map<string, double>& input_times, + absl::flat_hash_map<string, double>* gradients, + absl::flat_hash_map<string, double>* output_times, + absl::flat_hash_map<string, double>* output_time_gradients) const override + TF_SHARED_LOCKS_REQUIRED(mu_) { + double result = OutputTimeForInputs(*output_times); + (*output_times)[long_name()] = result; + if (gradients) { + (*output_time_gradients)[long_name()] = + OutputTimeGradientsForInputs(*output_time_gradients); + } } // The processing time is the sum of processing times of inputs. @@ -562,8 +688,7 @@ class Unknown : public Node { TF_SHARED_LOCKS_REQUIRED(mu_) { double processing_time = TotalProcessingTimeForInputs(*total_processing_times); - total_processing_times->insert( - std::make_pair(long_name(), processing_time)); + (*total_processing_times)[long_name()] = processing_time; } }; @@ -751,19 +876,21 @@ double Node::ComputeWaitTime(const double& output_time, void Node::CollectTunableParameters( absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) const { - CollectTunableParametersHelper(parameters); - + tf_shared_lock l(mu_); // Collect tunable parameters from the leaves of the nodes tree to the root. - for (const auto& node : CollectNodes()) { + for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) { + tf_shared_lock l(node->mu_); node->CollectTunableParametersHelper(parameters); } + CollectTunableParametersHelper(parameters); } string Node::DebugString() const { absl::flat_hash_map<string, string> debug_strings; - + tf_shared_lock l(mu_); // Build up the debug string from the leaves of the nodes tree to the root. - for (const auto& node : CollectNodes()) { + for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) { + tf_shared_lock l(node->mu_); node->DebugStringHelper(&debug_strings); } DebugStringHelper(&debug_strings); @@ -780,10 +907,35 @@ void Node::FlushMetrics() { metrics_.record_num_elements(num_elements_); } -double Node::OutputTime(std::vector<double>* input_times, - absl::flat_hash_map<string, double>* gradient) const { +double Node::OutputTime(absl::flat_hash_map<string, double>* input_times, + absl::flat_hash_map<string, double>* gradients) const { + // To store the output time gradient w.r.t. input time (if `gradients` is not + // `nullptr`) and the output time for each node. + absl::flat_hash_map<string, double> output_time_gradients, output_times; tf_shared_lock l(mu_); - return OutputTimeLocked(input_times, gradient); + auto nodes = CollectNodes(TraversalOrder::BFS); + + // Computes and stores input time for each node from the root to leaves of the + // nodes tree. + InputTimeLocked(input_times); + for (const auto& node : nodes) { + tf_shared_lock l(node->mu_); + node->InputTimeLocked(input_times); + } + + std::reverse(nodes.begin(), nodes.end()); + // Computes and stores the output time and output time gradient w.r.t. input + // time (if `gradients` is not `nullptr`) for each node from leaves of the + // nodes tree to the root. + for (const auto& node : nodes) { + tf_shared_lock l(node->mu_); + node->OutputTimeLocked(*input_times, gradients, &output_times, + &output_time_gradients); + } + OutputTimeLocked(*input_times, gradients, &output_times, + &output_time_gradients); + + return output_times[long_name()]; } std::shared_ptr<Node> Node::Snapshot(std::shared_ptr<Node> output) const { @@ -808,9 +960,10 @@ double Node::SelfProcessingTime() const { double Node::TotalBufferedBytes() const { absl::flat_hash_map<string, double> total_bytes; - + tf_shared_lock l(mu_); // Compute total buffered bytes from the leaves of the nodes tree to the root. - for (const auto& node : CollectNodes()) { + for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) { + tf_shared_lock l(node->mu_); node->TotalBufferedBytesHelper(&total_bytes); } TotalBufferedBytesHelper(&total_bytes); @@ -820,10 +973,11 @@ double Node::TotalBufferedBytes() const { double Node::TotalMaximumBufferedBytes() const { absl::flat_hash_map<string, double> total_bytes; - + tf_shared_lock l(mu_); // Compute total maximum buffered bytes from the leaves of the nodes tree // to the root. - for (const auto& node : CollectNodes()) { + for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) { + tf_shared_lock l(node->mu_); node->TotalMaximumBufferedBytesHelper(&total_bytes); } TotalMaximumBufferedBytesHelper(&total_bytes); @@ -836,17 +990,16 @@ double Node::TotalProcessingTime( // Create a hash map to store the per-element CPU time spent in the subtree // rooted in each node. absl::flat_hash_map<string, double> total_processing_times; + tf_shared_lock l(mu_); // Computes per-element CPU time spent in the subtree rooted in the node from // the leaves of the nodes tree to the root. - for (const auto& node : CollectNodes()) { + for (const auto& node : CollectNodes(TraversalOrder::REVERSE_BFS)) { tf_shared_lock l(node->mu_); node->TotalProcessingTimeLocked(processing_times, &total_processing_times); } - { - tf_shared_lock l(mu_); - TotalProcessingTimeLocked(processing_times, &total_processing_times); - } + TotalProcessingTimeLocked(processing_times, &total_processing_times); + return total_processing_times[long_name()]; } @@ -859,13 +1012,25 @@ double Node::AverageBufferedElementSize() const { } double Node::OutputTimeForInputs( - std::vector<double>* input_times, - absl::flat_hash_map<string, double>* gradient) const { + const absl::flat_hash_map<string, double>& output_times) const { double sum = 0; for (auto& input : inputs_) { // Inputs for which autotuning is disabled are excluded. if (input->autotune()) { - sum += input->OutputTime(input_times, gradient); + sum += output_times.at(input->long_name()); + } + } + return sum; +} + +double Node::OutputTimeGradientsForInputs( + const absl::flat_hash_map<string, double>& output_time_gradients) const { + double sum = 0; + for (auto& input : inputs_) { + // Inputs for which autotuning is disabled are excluded. + if (input->autotune()) { + sum += + gtl::FindWithDefault(output_time_gradients, input->long_name(), 0.0L); } } return sum; @@ -919,12 +1084,12 @@ double Node::SelfProcessingTimeLocked() const { static_cast<double>(num_elements_); } -Node::NodeVector Node::CollectNodes() const { +Node::NodeVector Node::CollectNodes(TraversalOrder order) const + TF_SHARED_LOCKS_REQUIRED(mu_) { NodeVector node_vector; std::list<std::shared_ptr<Node>> temp_list; { - tf_shared_lock l(mu_); for (auto& input : inputs_) { node_vector.push_back(input); temp_list.push_back(input); @@ -942,16 +1107,19 @@ Node::NodeVector Node::CollectNodes() const { } } } - std::reverse(node_vector.begin(), node_vector.end()); + + if (order == TraversalOrder::REVERSE_BFS) { + std::reverse(node_vector.begin(), node_vector.end()); + } return node_vector; } void Node::CollectTunableParametersHelper( - absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) const { + absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) const + TF_SHARED_LOCKS_REQUIRED(mu_) { if (!autotune_) { return; } - tf_shared_lock l(mu_); for (auto& pair : parameters_) { if (pair.second->state->tunable) { parameters->insert(std::make_pair(long_name(), pair.second)); @@ -959,9 +1127,8 @@ void Node::CollectTunableParametersHelper( } } -void Node::DebugStringHelper( - absl::flat_hash_map<string, string>* debug_strings) const { - tf_shared_lock l(mu_); +void Node::DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings) + const TF_SHARED_LOCKS_REQUIRED(mu_) { string result; strings::StrAppend(&result, long_name(), ":\n"); strings::StrAppend(&result, " autotune=", autotune_.load(), "\n"); @@ -1011,13 +1178,13 @@ std::shared_ptr<Node> Node::SnapshotHelper( } void Node::TotalBufferedBytesHelper( - absl::flat_hash_map<string, double>* total_bytes) const { + absl::flat_hash_map<string, double>* total_bytes) const + TF_SHARED_LOCKS_REQUIRED(mu_) { if (!autotune_) { total_bytes->insert(std::make_pair(long_name(), 0)); return; } - tf_shared_lock l(mu_); double result = 0; auto* parameter = gtl::FindOrNull(parameters_, kBufferSize); if (!parameter) { @@ -1033,13 +1200,13 @@ void Node::TotalBufferedBytesHelper( } void Node::TotalMaximumBufferedBytesHelper( - absl::flat_hash_map<string, double>* total_bytes) const { + absl::flat_hash_map<string, double>* total_bytes) const + TF_SHARED_LOCKS_REQUIRED(mu_) { if (!autotune_) { total_bytes->insert(std::make_pair(long_name(), 0)); return; } - tf_shared_lock l(mu_); double result = 0; auto* parameter = gtl::FindOrNull(parameters_, kBufferSize); if (!parameter) { @@ -1181,8 +1348,8 @@ void Model::OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget) { double new_output_time; double new_value; for (int i = 0; i < kMaxIterations; ++i) { - absl::flat_hash_map<string, double> gradient; - new_output_time = OutputTime(snapshot, &gradient); + absl::flat_hash_map<string, double> gradients; + new_output_time = OutputTime(snapshot, &gradients); int64 model_parallelism = 0; for (auto& pair : essential_parameters) { model_parallelism += std::round(pair.second->value); @@ -1199,12 +1366,12 @@ void Model::OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget) { for (auto& pair : parameters) { if (pair.second->value != pair.second->max) { max_abs_derivative = - std::max(max_abs_derivative, std::abs(gradient[pair.first])); + std::max(max_abs_derivative, std::abs(gradients[pair.first])); } } for (auto& pair : parameters) { new_value = pair.second->value - - kDescentStep * gradient[pair.first] / max_abs_derivative; + kDescentStep * gradients[pair.first] / max_abs_derivative; // Projection on a feasible interval. if (new_value > pair.second->max) { pair.second->value = pair.second->max; @@ -1248,7 +1415,7 @@ void Model::OptimizeHillClimb(int64 cpu_budget, int64 ram_budget) { pair.second->value = pair.second->min; } while (true) { - const double output_time = OutputTime(snapshot, /*gradient=*/nullptr); + const double output_time = OutputTime(snapshot, /*gradients=*/nullptr); bool all_max = true; for (auto& pair : parameters) { if (pair.second->value < pair.second->max) { @@ -1267,7 +1434,7 @@ void Model::OptimizeHillClimb(int64 cpu_budget, int64 ram_budget) { continue; } pair.second->value++; - double new_output_time = OutputTime(snapshot, /*gradient=*/nullptr); + double new_output_time = OutputTime(snapshot, /*gradients=*/nullptr); double delta = output_time - new_output_time; if (delta > best_delta && (delta > kBufferSizeMinDelta || pair.second->name != kBufferSize)) { @@ -1297,15 +1464,18 @@ void Model::OptimizeHillClimb(int64 cpu_budget, int64 ram_budget) { } double Model::OutputTime(std::shared_ptr<Node> node, - absl::flat_hash_map<string, double>* gradient) { - std::vector<double> input_times(1, 0); + absl::flat_hash_map<string, double>* gradients) { + // To store the input time for each node. + absl::flat_hash_map<string, double> input_times; + // TODO(jsimsa): Now that we are accounting for buffer size in wait time // computation, assuming that the input is infinitely fast will result in // inaccurate estimates of the output latency. // // We should compute the output latency as a fix-point of the following // equation: `output_time = node(OutputTime(input_times(1, output_time))`. - return node->OutputTime(&input_times, gradient); + + return node->OutputTime(&input_times, gradients); } double Model::TotalBufferedBytes(std::shared_ptr<Node> node) { diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h index a4af549fad2..e325056f0c4 100644 --- a/tensorflow/core/framework/model.h +++ b/tensorflow/core/framework/model.h @@ -42,11 +42,19 @@ constexpr int64 kAutotune = -1; constexpr char kParallelism[] = "parallelism"; constexpr char kBufferSize[] = "buffer_size"; +// A key used to identify input time gradient. +constexpr char kInputTimeKey[] = "input_time"; + enum class AutotuneAlgorithm { HILL_CLIMB = 0, GRADIENT_DESCENT = 1, }; +enum class TraversalOrder { + BFS = 0, + REVERSE_BFS = 1, +}; + // Represents thread-safe state that can be shared between an input pipeline and // the performance model. struct SharedState { @@ -316,11 +324,11 @@ class Node { // Flushes the metrics recorded by this node. void FlushMetrics() TF_LOCKS_EXCLUDED(mu_); - // Returns the per-element output time for this node and if `gradient` is not - // `nullptr`, collects the gradient of the output time w.r.t. tunable - // parameters of the subtree rooted in this node and the last input time. - double OutputTime(std::vector<double>* input_times, - absl::flat_hash_map<string, double>* gradient) const + // Returns the per-element output time for this node and if `gradients` is not + // `nullptr`, collects the output time gradient w.r.t. tunable parameters of + // the subtree rooted in this node. + double OutputTime(absl::flat_hash_map<string, double>* input_times, + absl::flat_hash_map<string, double>* gradients) const TF_LOCKS_EXCLUDED(mu_); // Returns a copy of this node, making a deep copy of its inputs and a @@ -414,20 +422,34 @@ class Node { // Returns the average size of an element buffered in this node. double AverageBufferedElementSize() const TF_SHARED_LOCKS_REQUIRED(mu_); - // Returns the sum of per-element output time for the inputs of this node and - // if `gradient` is not `nullptr`, collects gradients of output times w.r.t. - // tunable parameters and the last input time. - double OutputTimeForInputs(std::vector<double>* input_times, - absl::flat_hash_map<string, double>* gradient) - const TF_SHARED_LOCKS_REQUIRED(mu_); + // Returns the sum of per-element output time for the tunable inputs of this + // node. + double OutputTimeForInputs( + const absl::flat_hash_map<string, double>& output_times) const + TF_SHARED_LOCKS_REQUIRED(mu_); - // Returns the per-element output time for this node and if `gradient` is not - // `nullptr`, collects the gradient of the output time w.r.t. tunable - // parameters of the subtree rooted in this node and the last input time. - virtual double OutputTimeLocked(std::vector<double>* input_times, - absl::flat_hash_map<string, double>* gradient) + // Returns the sum of output time gradient w.r.t. input time for the tunable + // inputs of this node. + double OutputTimeGradientsForInputs( + const absl::flat_hash_map<string, double>& output_time_gradients) const + TF_SHARED_LOCKS_REQUIRED(mu_); + + // Computes the input time for this node and stores it in `input_times`. + virtual void InputTimeLocked(absl::flat_hash_map<string, double>* input_times) const TF_SHARED_LOCKS_REQUIRED(mu_) = 0; + // Computes the per-element output time for this node and stores it in + // `output_times`. If `gradients` is not `nullptr`, computes the output time + // gradient w.r.t. tunable parameters of the subtree rooted in this node and + // stores it in `gradients`, also computes the output time gradient w.r.t. + // input time and stores it in `output_time_gradients`. + virtual void OutputTimeLocked( + const absl::flat_hash_map<string, double>& input_times, + absl::flat_hash_map<string, double>* gradients, + absl::flat_hash_map<string, double>* output_times, + absl::flat_hash_map<string, double>* output_time_gradients) const + TF_SHARED_LOCKS_REQUIRED(mu_) = 0; + // Returns the sum of per-element processing time for the inputs of this node // by adding values for input nodes in `total_processing_times`. Processing // time for a given input is a weighted combination of a statistic based on @@ -452,18 +474,20 @@ class Node { absl::flat_hash_map<string, double>* total_processing_times) TF_SHARED_LOCKS_REQUIRED(mu_) = 0; - // Returns a vector of nodes of the subtree rooted in this node. - // The nodes are in the reverse breadth-first search order. - NodeVector CollectNodes() const; + // Returns a vector of nodes of the subtree rooted in this node. The nodes are + // either in breadth-first search or reverse breadth-first search order + // depending on the `order` argument. The root node itself is not collected. + NodeVector CollectNodes(TraversalOrder order) const + TF_SHARED_LOCKS_REQUIRED(mu_); // Collect tunable parameters for the node. void CollectTunableParametersHelper( - absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) - const; + absl::flat_hash_map<string, std::shared_ptr<Parameter>>* parameters) const + TF_SHARED_LOCKS_REQUIRED(mu_); // Build up debug string for the node and store in the debug strings map. - void DebugStringHelper( - absl::flat_hash_map<string, string>* debug_strings) const; + void DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings) + const TF_SHARED_LOCKS_REQUIRED(mu_); // Copy the node and add the (input, copy) pairs to the NodePairList. std::shared_ptr<Node> SnapshotHelper(std::shared_ptr<Node> clone_base, @@ -471,12 +495,14 @@ class Node { // Compute total buffered bytes for the node and store in the total bytes map. void TotalBufferedBytesHelper( - absl::flat_hash_map<string, double>* total_bytes) const; + absl::flat_hash_map<string, double>* total_bytes) const + TF_SHARED_LOCKS_REQUIRED(mu_); // Compute total maximum buffered bytes for the node and store in the total // bytes map. void TotalMaximumBufferedBytesHelper( - absl::flat_hash_map<string, double>* total_bytes) const; + absl::flat_hash_map<string, double>* total_bytes) const + TF_SHARED_LOCKS_REQUIRED(mu_); // Stores the time passed to the last call to `Node::record_start()` on the // current thread. @@ -619,11 +645,11 @@ class Model { // an element divided by CPU budget. void OptimizeGradientDescent(int64 cpu_budget, int64 ram_budget); - // Collects the output time and if `gradient` is not `nullptr`, the output + // Collects the output time and if `gradients` is not `nullptr`, the output // time gradient w.r.t. tunable parameters of the subtree rooted in the given - // node and the last input time. + // node. double OutputTime(std::shared_ptr<Node> node, - absl::flat_hash_map<string, double>* gradient); + absl::flat_hash_map<string, double>* gradients); // Collects the processing time for the given node. double TotalProcessingTime(std::shared_ptr<Node> node); diff --git a/tensorflow/core/framework/model_test.cc b/tensorflow/core/framework/model_test.cc index 898594b7c81..688dd0083e9 100644 --- a/tensorflow/core/framework/model_test.cc +++ b/tensorflow/core/framework/model_test.cc @@ -44,18 +44,19 @@ TEST_P(AsyncInterleaveManyTest, Model) { async_interleave_many->remove_input(meta_source); }); std::shared_ptr<Node> source1 = - model::MakeSourceNode({1, "source1", async_interleave_many}); + model::MakeSourceNode({2, "source1", async_interleave_many}); async_interleave_many->add_input(source1); auto cleanup1 = gtl::MakeCleanup([async_interleave_many, source1]() { async_interleave_many->remove_input(source1); }); std::shared_ptr<Node> source2 = - model::MakeSourceNode({2, "source2", async_interleave_many}); + model::MakeSourceNode({3, "source2", async_interleave_many}); async_interleave_many->add_input(source2); auto cleanup2 = gtl::MakeCleanup([async_interleave_many, source2]() { async_interleave_many->remove_input(source2); }); - std::vector<double> input_times(1, input_time); + absl::flat_hash_map<string, double> input_times; + input_times[kInputTimeKey] = input_time; EXPECT_EQ(async_interleave_many->TotalBufferedBytes(), 0); EXPECT_EQ(async_interleave_many->TotalMaximumBufferedBytes(), 0); async_interleave_many->record_buffer_event(110, 10); @@ -123,7 +124,8 @@ TEST_P(AsyncKnownRatioTest, Model) { std::shared_ptr<Node> source2 = model::MakeSourceNode({2, "source2", async_known_many}); async_known_many->add_input(source2); - std::vector<double> input_times(1, input_time); + absl::flat_hash_map<string, double> input_times; + input_times[kInputTimeKey] = input_time; EXPECT_EQ(async_known_many->TotalBufferedBytes(), 0); EXPECT_EQ(async_known_many->TotalMaximumBufferedBytes(), 0); async_known_many->record_buffer_event(110, 10); @@ -194,12 +196,12 @@ TEST(InterleaveManyTest, Model) { model::MakeSourceNode({1, "meta_source", interleave_many}); interleave_many->add_input(meta_source); std::shared_ptr<Node> source1 = - model::MakeSourceNode({1, "source1", interleave_many}); + model::MakeSourceNode({2, "source1", interleave_many}); interleave_many->add_input(source1); std::shared_ptr<Node> source2 = - model::MakeSourceNode({2, "source2", interleave_many}); + model::MakeSourceNode({3, "source2", interleave_many}); interleave_many->add_input(source2); - std::vector<double> input_times(1, 0); + absl::flat_hash_map<string, double> input_times; interleave_many->add_processing_time(100); EXPECT_EQ(interleave_many->processing_time(), 100); EXPECT_EQ(interleave_many->TotalProcessingTime(/*processing_times=*/nullptr), @@ -238,7 +240,7 @@ TEST_P(KnownRatioTest, Model) { std::shared_ptr<Node> source2 = model::MakeSourceNode({2, "source2", known_many}); known_many->add_input(source2); - std::vector<double> input_times(1, 0); + absl::flat_hash_map<string, double> input_times; source1->add_processing_time(100); EXPECT_EQ(known_many->TotalProcessingTime(/*processing_times=*/nullptr), 0); EXPECT_EQ(known_many->OutputTime(&input_times, nullptr), 0); @@ -286,7 +288,7 @@ INSTANTIATE_TEST_SUITE_P(Test, KnownRatioTest, ::testing::Values(0, 1, 2, 4)); TEST(SourceTest, Model) { std::shared_ptr<Node> source = model::MakeSourceNode({0, "source", nullptr}); - std::vector<double> input_times(1, 0); + absl::flat_hash_map<string, double> input_times; source->add_processing_time(100); EXPECT_EQ(source->processing_time(), 100); EXPECT_EQ(source->TotalProcessingTime(/*processing_times=*/nullptr), 0); @@ -310,7 +312,7 @@ TEST(UnknownRatioTest, Model) { std::shared_ptr<Node> source2 = model::MakeSourceNode({2, "source2", unknown_many}); unknown_many->add_input(source2); - std::vector<double> input_times(1, 0); + absl::flat_hash_map<string, double> input_times; unknown_many->add_processing_time(100); EXPECT_EQ(unknown_many->processing_time(), 100); EXPECT_EQ(unknown_many->TotalProcessingTime(/*processing_times=*/nullptr), 0); @@ -345,7 +347,7 @@ TEST(UnknownTest, Model) { std::shared_ptr<Node> source2 = model::MakeSourceNode({2, "source2", unknown}); unknown->add_input(source2); - std::vector<double> input_times(1, 0); + absl::flat_hash_map<string, double> input_times; source1->add_processing_time(100); EXPECT_EQ(unknown->TotalProcessingTime(/*processing_times=*/nullptr), 0); EXPECT_EQ(unknown->OutputTime(&input_times, nullptr), 0); @@ -390,17 +392,23 @@ class TestNode : public model::Node { return nullptr; } - double OutputTimeLocked(std::vector<double>* input_times, - absl::flat_hash_map<string, double>* gradient) - const override TF_SHARED_LOCKS_REQUIRED(mu_) { - return 0; + void InputTimeLocked(absl::flat_hash_map<string, double>* input_times) + const override TF_SHARED_LOCKS_REQUIRED(mu_) {} + + void OutputTimeLocked( + const absl::flat_hash_map<string, double>& input_times, + absl::flat_hash_map<string, double>* gradients, + absl::flat_hash_map<string, double>* output_times, + absl::flat_hash_map<string, double>* output_time_gradients) const override + TF_SHARED_LOCKS_REQUIRED(mu_) { + (*output_times)[long_name()] = 0; } void TotalProcessingTimeLocked( absl::flat_hash_map<string, double>* processing_times, absl::flat_hash_map<string, double>* total_processing_times) override TF_SHARED_LOCKS_REQUIRED(mu_) { - total_processing_times->insert(std::make_pair(long_name(), 0)); + (*total_processing_times)[long_name()] = 0; } }; @@ -504,7 +512,7 @@ TEST(AsyncInterleaveManyGradientTest, Model) { async_interleave_many->remove_input(meta_source); }); std::shared_ptr<Node> source1 = model::MakeAsyncInterleaveManyNode( - {0, "async_interleave_many", nullptr}, + {2, "async_interleave_many", async_interleave_many}, {model::MakeParameter( "parallelism", std::make_shared<SharedState>(parallelism, nullptr, nullptr), 1, @@ -514,12 +522,13 @@ TEST(AsyncInterleaveManyGradientTest, Model) { async_interleave_many->remove_input(source1); }); std::shared_ptr<Node> source2 = - model::MakeSourceNode({2, "source2", async_interleave_many}); + model::MakeSourceNode({3, "source2", async_interleave_many}); async_interleave_many->add_input(source2); auto cleanup2 = gtl::MakeCleanup([async_interleave_many, source2]() { async_interleave_many->remove_input(source2); }); - std::vector<double> input_times(1, input_time); + absl::flat_hash_map<string, double> input_times; + input_times[kInputTimeKey] = input_time; absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters; async_interleave_many->CollectTunableParameters(¶meters); async_interleave_many->record_element(); @@ -532,13 +541,13 @@ TEST(AsyncInterleaveManyGradientTest, Model) { parameters[source1->long_name()]->value = 1; // Test gradient of own parameters. - absl::flat_hash_map<string, double> gradient; + absl::flat_hash_map<string, double> gradients; double output_time = - async_interleave_many->OutputTime(&input_times, &gradient); + async_interleave_many->OutputTime(&input_times, &gradients); parameters[async_interleave_many->long_name()]->value += kParameterStep; double new_output_time = async_interleave_many->OutputTime(&input_times, nullptr); - EXPECT_NEAR(gradient[async_interleave_many->long_name()], + EXPECT_NEAR(gradients[async_interleave_many->long_name()], (new_output_time - output_time) / kParameterStep, kComparisonPrecision); @@ -546,7 +555,7 @@ TEST(AsyncInterleaveManyGradientTest, Model) { parameters[async_interleave_many->long_name()]->value -= kParameterStep; parameters[source1->long_name()]->value += kParameterStep; new_output_time = async_interleave_many->OutputTime(&input_times, nullptr); - EXPECT_NEAR(gradient[source1->long_name()], + EXPECT_NEAR(gradients[source1->long_name()], (new_output_time - output_time) / kParameterStep, kComparisonPrecision); } @@ -565,7 +574,7 @@ TEST_P(AsyncKnownRatioGradientTest, Model) { std::make_shared<SharedState>(parameter_value, nullptr, nullptr), 1, parameter_value)}); std::shared_ptr<Node> source1 = model::MakeAsyncKnownRatioNode( - {0, "source1", nullptr}, num_inputs_per_output, + {1, "source1", async_known_many}, num_inputs_per_output, {model::MakeParameter( parameter_name, std::make_shared<SharedState>(parameter_value, nullptr, nullptr), 1, @@ -573,7 +582,8 @@ TEST_P(AsyncKnownRatioGradientTest, Model) { async_known_many->add_input(source1); std::shared_ptr<Node> source2 = model::MakeSourceNode({2, "source2", async_known_many}); - std::vector<double> input_times(1, input_time); + absl::flat_hash_map<string, double> input_times; + input_times[kInputTimeKey] = input_time; async_known_many->add_input(source2); source1->record_element(); source1->add_processing_time(100); @@ -584,14 +594,14 @@ TEST_P(AsyncKnownRatioGradientTest, Model) { // Test gradient of own parameters. absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters; - absl::flat_hash_map<string, double> gradient; + absl::flat_hash_map<string, double> gradients; async_known_many->CollectTunableParameters(¶meters); parameters[async_known_many->long_name()]->value = 1; parameters[source1->long_name()]->value = 1; - double output_time = async_known_many->OutputTime(&input_times, &gradient); + double output_time = async_known_many->OutputTime(&input_times, &gradients); parameters[async_known_many->long_name()]->value += kParameterStep; double new_output_time = async_known_many->OutputTime(&input_times, nullptr); - EXPECT_NEAR(gradient[async_known_many->long_name()], + EXPECT_NEAR(gradients[async_known_many->long_name()], (new_output_time - output_time) / kParameterStep, kComparisonPrecision); @@ -599,7 +609,7 @@ TEST_P(AsyncKnownRatioGradientTest, Model) { parameters[async_known_many->long_name()]->value -= kParameterStep; parameters[source1->long_name()]->value += kParameterStep; new_output_time = async_known_many->OutputTime(&input_times, nullptr); - EXPECT_NEAR(gradient[source1->long_name()], + EXPECT_NEAR(gradients[source1->long_name()], (new_output_time - output_time) / kParameterStep, kComparisonPrecision); } @@ -614,28 +624,29 @@ TEST(InterleaveManyGradientTest, Model) { std::shared_ptr<Node> interleave_many = model::MakeInterleaveManyNode({0, "interleave_many", nullptr}); std::shared_ptr<Node> async_known_many = model::MakeAsyncKnownRatioNode( - {0, "async_known_many", nullptr}, num_inputs_per_output, + {1, "async_known_many", interleave_many}, num_inputs_per_output, {model::MakeParameter( "parallelism", std::make_shared<SharedState>(parallelism, nullptr, nullptr), 1, parallelism)}); std::shared_ptr<Node> source1 = - model::MakeSourceNode({2, "source1", async_known_many}); + model::MakeSourceNode({2, "source1", interleave_many}); interleave_many->record_element(); interleave_many->add_processing_time(100); interleave_many->add_input(source1); interleave_many->add_input(async_known_many); async_known_many->record_element(); async_known_many->add_processing_time(300); - std::vector<double> input_times(1, input_time); + absl::flat_hash_map<string, double> input_times; + input_times[kInputTimeKey] = input_time; absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters; - absl::flat_hash_map<string, double> gradient; + absl::flat_hash_map<string, double> gradients; interleave_many->CollectTunableParameters(¶meters); parameters[async_known_many->long_name()]->value = 1; - double output_time = interleave_many->OutputTime(&input_times, &gradient); + double output_time = interleave_many->OutputTime(&input_times, &gradients); parameters[async_known_many->long_name()]->value += kParameterStep; double new_output_time = interleave_many->OutputTime(&input_times, nullptr); - EXPECT_NEAR(gradient[async_known_many->long_name()], + EXPECT_NEAR(gradients[async_known_many->long_name()], (new_output_time - output_time) / kParameterStep, kComparisonPrecision); } @@ -647,7 +658,7 @@ TEST(KnownRatioGradientTest, Model) { std::shared_ptr<Node> known_many = model::MakeKnownRatioNode( {0, "known_many", nullptr}, num_inputs_per_output); std::shared_ptr<Node> async_known_many = model::MakeAsyncKnownRatioNode( - {0, "async_known_many", nullptr}, num_inputs_per_output, + {1, "async_known_many", known_many}, num_inputs_per_output, {model::MakeParameter( "parallelism", std::make_shared<SharedState>(parallelism, nullptr, nullptr), 1, @@ -657,15 +668,16 @@ TEST(KnownRatioGradientTest, Model) { known_many->add_input(async_known_many); async_known_many->record_element(); async_known_many->add_processing_time(300); - std::vector<double> input_times(1, input_time); + absl::flat_hash_map<string, double> input_times; + input_times[kInputTimeKey] = input_time; absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters; - absl::flat_hash_map<string, double> gradient; + absl::flat_hash_map<string, double> gradients; known_many->CollectTunableParameters(¶meters); parameters[async_known_many->long_name()]->value = 1; - double output_time = known_many->OutputTime(&input_times, &gradient); + double output_time = known_many->OutputTime(&input_times, &gradients); parameters[async_known_many->long_name()]->value += kParameterStep; double new_output_time = known_many->OutputTime(&input_times, nullptr); - EXPECT_NEAR(gradient[async_known_many->long_name()], + EXPECT_NEAR(gradients[async_known_many->long_name()], (new_output_time - output_time) / kParameterStep, kComparisonPrecision); } @@ -677,7 +689,7 @@ TEST(UnknownRatioGradientTest, Model) { std::shared_ptr<Node> unknown_many = model::MakeUnknownRatioNode({0, "unknown_many", nullptr}); std::shared_ptr<Node> async_known_many = model::MakeAsyncKnownRatioNode( - {0, "async_known_many", nullptr}, num_inputs_per_output, + {1, "async_known_many", unknown_many}, num_inputs_per_output, {model::MakeParameter( "parallelism", std::make_shared<SharedState>(parallelism, nullptr, nullptr), 1, @@ -687,15 +699,16 @@ TEST(UnknownRatioGradientTest, Model) { unknown_many->add_input(async_known_many); async_known_many->record_element(); async_known_many->add_processing_time(300); - std::vector<double> input_times(1, input_time); + absl::flat_hash_map<string, double> input_times; + input_times[kInputTimeKey] = input_time; absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters; - absl::flat_hash_map<string, double> gradient; + absl::flat_hash_map<string, double> gradients; unknown_many->CollectTunableParameters(¶meters); parameters[async_known_many->long_name()]->value = 1; - double output_time = unknown_many->OutputTime(&input_times, &gradient); + double output_time = unknown_many->OutputTime(&input_times, &gradients); parameters[async_known_many->long_name()]->value += kParameterStep; double new_output_time = unknown_many->OutputTime(&input_times, nullptr); - EXPECT_NEAR(gradient[async_known_many->long_name()], + EXPECT_NEAR(gradients[async_known_many->long_name()], (new_output_time - output_time) / kParameterStep, kComparisonPrecision); } @@ -707,7 +720,7 @@ TEST(UnknownGradientTest, Model) { std::shared_ptr<Node> unknown = model::MakeUnknownNode({0, "unknown", nullptr}); std::shared_ptr<Node> async_known_many = model::MakeAsyncKnownRatioNode( - {0, "async_known_many", nullptr}, num_inputs_per_output, + {1, "async_known_many", unknown}, num_inputs_per_output, {model::MakeParameter( "parallelism", std::make_shared<SharedState>(parallelism, nullptr, nullptr), 1, @@ -717,15 +730,16 @@ TEST(UnknownGradientTest, Model) { unknown->add_input(async_known_many); async_known_many->record_element(); async_known_many->add_processing_time(300); - std::vector<double> input_times(1, input_time); + absl::flat_hash_map<string, double> input_times; + input_times[kInputTimeKey] = input_time; absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters; - absl::flat_hash_map<string, double> gradient; + absl::flat_hash_map<string, double> gradients; unknown->CollectTunableParameters(¶meters); parameters[async_known_many->long_name()]->value = 1; - double output_time = unknown->OutputTime(&input_times, &gradient); + double output_time = unknown->OutputTime(&input_times, &gradients); parameters[async_known_many->long_name()]->value += kParameterStep; double new_output_time = unknown->OutputTime(&input_times, nullptr); - EXPECT_NEAR(gradient[async_known_many->long_name()], + EXPECT_NEAR(gradients[async_known_many->long_name()], (new_output_time - output_time) / kParameterStep, kComparisonPrecision); } diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 54541be0b4f..744a14e007e 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -18,6 +18,7 @@ limitations under the License. #include <cstdint> #include <type_traits> + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -239,6 +240,12 @@ class Tensor { /// are not valid. Tensor(Tensor&& other); + // Explicitly delete constructor that take a pointer (except char*) + // so that the pointer doesn't get implicitly cast to bool. + template <typename T, typename std::enable_if<!std::is_same<T, char>::value, + T>::type* = nullptr> + explicit Tensor(T* t) = delete; + ~Tensor(); /// Returns the data type. diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index b880055b47d..030064e49fb 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -1064,6 +1064,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/utils:graph_view", "@com_google_absl//absl/strings", diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc index c85d85e69ff..79bedf5f2e6 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc @@ -356,57 +356,35 @@ TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) { #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM) GTEST_SKIP() << "Neither CUDA nor ROCm is enabled"; #endif // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM) - Scope s = Scope::NewRootScope(); - auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false, - /*input_sizes_length=*/4); - Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); - GrapplerItem item; - TF_ASSERT_OK(s.ToGraphDef(&item.graph)); + for (const int input_sizes_length : {2, 4}) { + Scope s = Scope::NewRootScope(); + auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false, + input_sizes_length); + Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); + GrapplerItem item; + TF_ASSERT_OK(s.ToGraphDef(&item.graph)); - GenericLayoutOptimizer optimizer; - GraphDef output; - TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); + GenericLayoutOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); - Status status; - utils::GraphView graph_view(&output, &status); - TF_ASSERT_OK(status); - auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput"); - ASSERT_NE(conv2d_backprop_node, nullptr); - ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3); - VerifyRegularFaninMatch( - conv2d_backprop_node, 0, - "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer", - 0); - auto* input_sizes_node = graph_view.GetNode( - "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer"); - ASSERT_NE(input_sizes_node, nullptr); - EXPECT_EQ(input_sizes_node->GetOp(), "DataFormatVecPermute"); - ASSERT_EQ(input_sizes_node->NumRegularFanins(), 1); - VerifyRegularFaninMatch(input_sizes_node, 0, "InputSizesIdentity", 0); -} - -TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInput2DInputSizes) { -#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM) - GTEST_SKIP() << "Neither CUDA nor ROCm is enabled"; -#endif // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM) - Scope s = Scope::NewRootScope(); - auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false, - /*input_sizes_length=*/2); - Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv}); - GrapplerItem item; - TF_ASSERT_OK(s.ToGraphDef(&item.graph)); - - GenericLayoutOptimizer optimizer; - GraphDef output; - TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); - - Status status; - utils::GraphView graph_view(&output, &status); - TF_ASSERT_OK(status); - auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput"); - ASSERT_NE(conv2d_backprop_node, nullptr); - ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3); - VerifyRegularFaninMatch(conv2d_backprop_node, 0, "InputSizesIdentity", 0); + Status status; + utils::GraphView graph_view(&output, &status); + TF_ASSERT_OK(status); + auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput"); + ASSERT_NE(conv2d_backprop_node, nullptr); + ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3); + VerifyRegularFaninMatch( + conv2d_backprop_node, 0, + "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer", + 0); + auto* input_sizes_node = graph_view.GetNode( + "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer"); + ASSERT_NE(input_sizes_node, nullptr); + EXPECT_EQ(input_sizes_node->GetOp(), "DataFormatVecPermute"); + ASSERT_EQ(input_sizes_node->NumRegularFanins(), 1); + VerifyRegularFaninMatch(input_sizes_node, 0, "InputSizesIdentity", 0); + } } TEST_F(GenericLayoutOptimizerTest, Conv2DDataFormatVecPermuteCollapse) { diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc index a5a5f7ae64a..ab7d8fcd6cf 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc @@ -739,28 +739,13 @@ Status Conv2DBackpropInputTransposer::TransposeNode( VLOG(3) << fanin_node->GetName() << " is not a vector."; return Status::OK(); } - int vector_size = fanin_shape.dim(0).size(); - if (vector_size == -1) { - VLOG(3) << "The number of elements in " << fanin_node->GetName() - << " is unknown."; - return Status::OK(); - } - if (vector_size != 2 && vector_size != 4) { - return errors::InvalidArgument( - fanin_node->GetName(), " must be a vector of size 2 or 4, but found ", - vector_size); - } VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName() << "' with op '" << node->GetOp() << "' from data format '" << context->src_format << "' to '" << context->dst_format << "'"; TF_RETURN_IF_ERROR(UpdateNode(context, node)); - // Do not permute a input_sizes of size 2 because it represents HW regardless - // of whether NCHW or NHWC. - if (vector_size != 2) { - TF_RETURN_IF_ERROR( - UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute)); - } + TF_RETURN_IF_ERROR( + UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute)); TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose)); TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose)); return context->graph_view->GetMutationBuilder()->Apply(); diff --git a/tensorflow/core/grappler/optimizers/implementation_selector.cc b/tensorflow/core/grappler/optimizers/implementation_selector.cc index 37dda6ab6a3..9c4f74d7268 100644 --- a/tensorflow/core/grappler/optimizers/implementation_selector.cc +++ b/tensorflow/core/grappler/optimizers/implementation_selector.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" #include "tensorflow/core/grappler/optimizers/function_api_info.h" +#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/graph_view.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -159,6 +160,15 @@ Status UpdateNodeDef(utils::MutableNodeView* node_view, const string& funcName, } if (apiInfo.function_type() == FunctionApiInfo::BACKWARD) { + // Strip node control dependencies. We'll add them back after updating + // all the data inputs. + std::vector<std::string> control_deps; + for (int i = node_def->input_size() - 1; i >= 0; --i) { + if (!IsControlInput(node_def->input(i))) break; + control_deps.push_back(node_def->input(i)); + node_def->mutable_input()->RemoveLast(); + } + // For step 4 above. const int prev_input_size = node_def->input_size(); const int diff = prev_input_size - apiInfo.input_arg_dtypes().size(); @@ -194,6 +204,11 @@ Status UpdateNodeDef(utils::MutableNodeView* node_view, const string& funcName, for (int i = 1; i <= -diff; ++i) node_def->add_input(strings::StrCat(node_name, ":", i + last_index)); } + + // Add control dependencies back. + for (std::string& control : control_deps) + node_def->add_input(std::move(control)); + } else if (apiInfo.function_type() == FunctionApiInfo::FORWARD) { // For forward function, since the DTYPE of the intermediate state might // have been changed, we want to update the down stream Identity node if diff --git a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc index 7a6b4907bf4..cf1953fcdb2 100644 --- a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/remapper.h" #include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -294,6 +295,145 @@ TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBiasAndActivation) { } } +#ifdef ENABLE_MKLDNN_V1 +TEST_F(MklRemapperTest, FuseBatchNormWithRelu) { + using ::tensorflow::ops::Placeholder; + + for (bool is_training : {true, false}) { + for (bool has_side_input : {true, false}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + const int num_channels = 24; + + TensorShape channel_shape({num_channels}); + TensorShape empty_shape({0}); + + auto input = + Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape({2, 8, 8, num_channels})); + auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_FLOAT); + auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT); + auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT); + auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT); + auto var = Placeholder(s.WithOpName("var"), DT_FLOAT); + + float epsilon = 0.1f; + auto fbn = + ops::FusedBatchNormV3(s.WithOpName("fused_batch_norm"), input_cast, + scale, offset, mean, var, + ops::FusedBatchNormV3::IsTraining(is_training) + .Epsilon(epsilon) + .DataFormat("NHWC")); + + if (has_side_input) { + auto side_input = + Placeholder(s.WithOpName("side_input"), DT_FLOAT, + ops::Placeholder::Shape({2, 8, 8, num_channels})); + auto side_input_cast = + ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_FLOAT); + auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast); + auto relu = ops::Relu(s.WithOpName("relu"), add); + } else { + auto relu = ops::Relu(s.WithOpName("relu"), fbn.y); + } + + auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels}); + auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape); + auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape); + auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape + : channel_shape); + auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape + : channel_shape); + auto side_input_t = + GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels}); + + GrapplerItem item; + item.fetch = {"relu"}; + if (has_side_input) + item.feed = {{"input", input_t}, {"scale", scale_t}, + {"offset", offset_t}, {"mean", mean_t}, + {"var", var_t}, {"side_input", side_input_t}}; + else + item.feed = {{"input", input_t}, + {"scale", scale_t}, + {"offset", offset_t}, + {"mean", mean_t}, + {"var", var_t}}; + TF_ASSERT_OK(s.ToGraphDef(&item.graph)); + + // Place all nodes on CPU. + for (int i = 0; i < item.graph.node_size(); ++i) { + item.graph.mutable_node(i)->set_device("/device:CPU:0"); + } + + Remapper optimizer(RewriterConfig::AGGRESSIVE); + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + int found = 0; + if (has_side_input) { + for (const NodeDef& node : output.node()) { + if (node.name() == "add") { + EXPECT_EQ(node.op(), "Add"); + ASSERT_EQ(node.input_size(), 2); + EXPECT_EQ(node.input(0), "fused_batch_norm"); + EXPECT_EQ(node.input(1), "side_input_cast"); + found++; + } + if (node.name() == "relu") { + EXPECT_EQ(node.op(), "Relu"); + ASSERT_EQ(node.input_size(), 1); + EXPECT_EQ(node.input(0), "add"); + found++; + } + if (node.name() == "fused_batch_norm") { + EXPECT_EQ(node.op(), "FusedBatchNormV3"); + ASSERT_EQ(node.input_size(), 5); + EXPECT_EQ(node.input(0), "input_cast"); + EXPECT_EQ(node.input(1), "scale"); + EXPECT_EQ(node.input(2), "offset"); + EXPECT_EQ(node.input(3), "mean"); + EXPECT_EQ(node.input(4), "var"); + found++; + } + } + EXPECT_EQ(found, 3); + } else { + for (const NodeDef& node : output.node()) { + if (node.name() == "relu") { + EXPECT_EQ(node.op(), "Identity"); + ASSERT_EQ(node.input_size(), 1); + EXPECT_EQ(node.input(0), "fused_batch_norm"); + found++; + } + if (node.name() == "fused_batch_norm") { + EXPECT_EQ(node.op(), "_FusedBatchNormEx"); + ASSERT_EQ(node.input_size(), 5); + EXPECT_EQ(node.input(0), "input_cast"); + EXPECT_EQ(node.input(1), "scale"); + EXPECT_EQ(node.input(2), "offset"); + EXPECT_EQ(node.input(3), "mean"); + EXPECT_EQ(node.input(4), "var"); + + auto attr = node.attr(); + EXPECT_EQ(attr["num_side_inputs"].i(), 0); + EXPECT_EQ(attr["activation_mode"].s(), "Relu"); + found++; + } + } + EXPECT_EQ(found, 2); + } + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + ASSERT_EQ(tensors_expected.size(), 1); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + ASSERT_EQ(tensors.size(), 1); + test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6); + } + } +} +#endif // ENABLE_MKLDNN_V1 + } // namespace grappler } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc index 243ab7bd965..20db4360f73 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner.cc @@ -33,6 +33,7 @@ limitations under the License. namespace tensorflow { namespace grappler { +namespace { bool IsTrivialIdentity(const NodeDef& node, const GraphView& graph_view) { for (const auto input : @@ -103,7 +104,9 @@ bool IsOutputPortRefValue(const NodeDef& node, int port_id, bool CanRemoveNode(const NodeDef& node, const GraphView& graph_view, const absl::flat_hash_set<string>& function_names, const OpRegistryInterface& op_registry) { - if (IsNoOp(node) && node.input().empty()) { + if (IsNoOp(node) && + (node.input().empty() || + graph_view.NumFanouts(node, /*include_controlled_nodes=*/true) == 0)) { return true; } if (IsConstant(node) && node.input().empty() && @@ -412,6 +415,8 @@ Status SplitIdentityNInputs(GraphDef* graph, return Status::OK(); } +} // namespace + Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve(); @@ -453,13 +458,18 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, // Check if we can further prune the graph, by removing the trivial ops. absl::flat_hash_set<const NodeDef*> nodes_to_delete; - for (const auto& node : pruned_graph->node()) { - if (!IsTrivialOp(node, graph_view)) { + for (int i = 0; i < pruned_graph->node_size(); ++i) { + NodeDef* node = pruned_graph->mutable_node(i); + // Remove redundant control inputs, since they may prevent pruning below. + DedupControlInputs(node); + + if (!IsTrivialOp(*node, graph_view)) { + VLOG(3) << node->name() << " is not trivial."; continue; } // Don't remove nodes that must be preserved. - if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) { + if (nodes_to_preserve.find(node->name()) != nodes_to_preserve.end()) { continue; } @@ -477,8 +487,10 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, // converting references to non-references. It is important to preserve // these non-references since the partitioner will avoid sending // non-references across partitions more than once. - if (CanRemoveNode(node, graph_view, function_names, *op_registry)) { - nodes_to_delete.insert(&node); + if (CanRemoveNode(*node, graph_view, function_names, *op_registry)) { + nodes_to_delete.insert(node); + } else { + VLOG(3) << node->name() << " cannot be removed"; } } diff --git a/tensorflow/core/grappler/optimizers/model_pruner_test.cc b/tensorflow/core/grappler/optimizers/model_pruner_test.cc index d2624e3d842..9beadbb7c70 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner_test.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner_test.cc @@ -100,12 +100,13 @@ TEST_F(ModelPrunerTest, IdentityPruning) { Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); Output b = ops::Sqrt(s.WithOpName("b"), {a}); - Output c = ops::Identity(s.WithOpName("c"), b); + Output c = ops::Identity(s.WithOpName("c").WithControlDependencies(b), b); Output d = ops::Identity(s.WithOpName("d"), c); Output e = ops::Sqrt(s.WithOpName("e"), {d}); TF_ASSERT_OK(s.ToGraphDef(&item.graph)); } + item.fetch.push_back("e"); ModelPruner pruner; GraphDef output; @@ -117,8 +118,6 @@ TEST_F(ModelPrunerTest, IdentityPruning) { Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); Output b = ops::Sqrt(s.WithOpName("b"), {a}); - Output c = ops::Identity(s.WithOpName("c"), b); - Output d = ops::Identity(s.WithOpName("d"), b); Output e = ops::Sqrt(s.WithOpName("e"), {b}); TF_ASSERT_OK(s.ToGraphDef(&expected)); @@ -126,10 +125,9 @@ TEST_F(ModelPrunerTest, IdentityPruning) { CompareGraphs(expected, output); - std::vector<string> fetch = {"e"}; - auto actual_tensors = EvaluateNodes(output, fetch); + auto actual_tensors = EvaluateNodes(output, item.fetch); ASSERT_EQ(actual_tensors.size(), 1); - auto expected_tensors = EvaluateNodes(item.graph, fetch); + auto expected_tensors = EvaluateNodes(item.graph, item.fetch); ASSERT_EQ(expected_tensors.size(), 1); test::ExpectTensorEqual<float>(actual_tensors[0], expected_tensors[0]); } diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 9602ea44f0d..9a7d1953105 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -797,23 +797,27 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index, const auto* fused_batch_norm_node_def = fused_batch_norm.node(); if (!IsFusedBatchNorm(*fused_batch_norm_node_def)) return false; - // We fuse FusedBatchNorm only on GPU, because on CPU we fuse it with - // contraction (MatMul or Conv2D node). +#ifndef ENABLE_MKLDNN_V1 + // We fuse FusedBatchNorm on GPU or MKL CPU. if (!NodeIsOnGpu(fused_batch_norm_node_def)) return false; +#endif DataType t_dtype = GetDataTypeFromAttr(*fused_batch_norm_node_def, "T"); +#ifndef ENABLE_MKLDNN_V1 if (t_dtype != DT_FLOAT && t_dtype != DT_HALF) return false; +#else + if (t_dtype != DT_FLOAT && t_dtype != DT_BFLOAT16) return false; +#endif // Get the FusedBatchNorm training mode. bool is_training; if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training) .ok()) return false; - // In training mode we rely on cuDNN for computing FusedBatchNorm with side // inputs and activation, and it has its own limitations. In inference mode // we have a custom CUDA kernel that doesn't not have these constraints. - if (is_training) { + if (is_training && NodeIsOnGpu(fused_batch_norm_node_def)) { // cuDNN only supports NHWC data layout. string data_format; if (!GetNodeAttr(*fused_batch_norm_node_def, kDataFormat, &data_format) @@ -865,6 +869,12 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index, // Input to a Relu can be an Add node with FusedBatchNorm as one of the inputs if (IsAdd(*relu_fanin_0_node_def)) { + // Currently no CPU implementation for "FusedBatchNorm + SideInput + + // <Activation>"" +#ifdef ENABLE_MKLDNN_V1 + return false; +#endif + // Check that only Relu node consumes the output of an Add node. if (HasControlFaninOrFanout(*relu_fanin_0_node_view) || !HasAtMostOneFanoutAtPort0(*relu_fanin_0_node_view) || @@ -946,12 +956,17 @@ void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm, (*attr)["is_training"] = src_attr.at("is_training"); (*attr)["data_format"] = src_attr.at("data_format"); (*attr)["epsilon"] = src_attr.at("epsilon"); + (*attr)["exponential_avg_factor"] = src_attr.at("exponential_avg_factor"); // FusedBatchNormV2 and V3 have an extra type parameter. if (fused_batch_norm.op() != "FusedBatchNorm") { - (*attr)["U"] = src_attr.at("U"); + SetAttrValue(src_attr.at("U"), &(*attr)["U"]); } else { - (*attr)["U"] = src_attr.at("T"); +#ifndef ENABLE_MKLDNN_V1 + SetAttrValue(src_attr.at("T"), &(*attr)["U"]); +#else + SetAttrValue(DT_FLOAT, &(*attr)["U"]); +#endif } } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 788924e8b37..c1fc17079c8 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1338,7 +1338,10 @@ tf_kernel_library( "tile_functor_cpu_int64.cc", "tile_functor_cpu_int8.cc", "tile_functor_cpu_tstring.cc", + "tile_functor_cpu_uint32.cc", + "tile_functor_cpu_uint64.cc", "tile_functor_cpu_uint8.cc", + "tile_functor_cpu_variant.cc", "tile_functor_sycl.cc", ], hdrs = ["tile_functor.h"], @@ -6613,6 +6616,7 @@ filegroup( "avgpooling_op.h", "batch_matmul_op_impl.h", "batch_norm_op.h", + "broadcast_to_op.h", "control_flow_ops.h", "conv_2d.h", "conv_3d.h", @@ -6702,6 +6706,7 @@ filegroup( "conv_ops_fused_float.cc", "conv_ops_fused_half.cc", "conv_ops_fused_impl.h", + "conv_ops_fused_image_transform.cc", "conv_ops_using_gemm.cc", "crop_and_resize_op.cc", "crop_and_resize_op.h", @@ -6711,6 +6716,7 @@ filegroup( "cwise_op_bitwise_and.cc", "cwise_op_bitwise_or.cc", "cwise_op_bitwise_xor.cc", + "cwise_op_ceil.cc", "cwise_op_conj.cc", "cwise_op_cos.cc", "cwise_op_cosh.cc", @@ -6803,6 +6809,7 @@ filegroup( name = "android_extended_ops_group2", srcs = [ "batchtospace_op.cc", + "broadcast_to_op.cc", "ctc_decoder_ops.cc", "decode_bmp_op.cc", "depthtospace_op.cc", @@ -6906,7 +6913,10 @@ filegroup( "tile_functor_cpu_int64.cc", "tile_functor_cpu_int8.cc", "tile_functor_cpu_tstring.cc", + "tile_functor_cpu_uint32.cc", + "tile_functor_cpu_uint64.cc", "tile_functor_cpu_uint8.cc", + "tile_functor_cpu_variant.cc", "tile_ops.cc", "tile_ops_cpu_impl_1.cc", "tile_ops_cpu_impl_2.cc", @@ -8227,7 +8237,10 @@ tf_mkl_kernel_library( tf_mkl_kernel_library( name = "mkl_fused_batch_norm_op", srcs = ["mkl_fused_batch_norm_op.cc"], - deps = NN_DEPS + mkl_deps(), + deps = NN_DEPS + [ + ":fused_batch_norm_op", + ":no_op", + ] + mkl_deps(), ) tf_cc_test_mkl( diff --git a/tensorflow/core/kernels/attention_ops.cc b/tensorflow/core/kernels/attention_ops.cc index f555c0fd679..6e5e07a9fb1 100644 --- a/tensorflow/core/kernels/attention_ops.cc +++ b/tensorflow/core/kernels/attention_ops.cc @@ -32,6 +32,8 @@ namespace tensorflow { class ExtractGlimpseOp : public OpKernel { public: explicit ExtractGlimpseOp(OpKernelConstruction* context) : OpKernel(context) { + const string& op = context->def().op(); + version_ = (op == "ExtractGlimpse") ? 1 : 2; OP_REQUIRES_OK(context, context->GetAttr("normalized", &normalized_)); OP_REQUIRES_OK(context, context->GetAttr("centered", ¢ered_)); bool uniform_noise = false; @@ -117,21 +119,23 @@ class ExtractGlimpseOp : public OpKernel { // calling TensorFlow operates with (y,x) as indices. offset_vec.push_back(Eigen::IndexPair<float>(offset_x, offset_y)); } - output->tensor<float, 4>().swap_layout().device( context->eigen_cpu_device()) = Eigen::ExtractGlimpses(input.tensor<float, 4>().swap_layout(), output_width, output_height, offset_vec, - normalized_, centered_, noise_); + normalized_, centered_, noise_, version_); } private: bool normalized_; bool centered_; Eigen::ExtractGlimpsesNoiseMode noise_; + int32 version_; }; REGISTER_KERNEL_BUILDER(Name("ExtractGlimpse").Device(DEVICE_CPU), ExtractGlimpseOp); +REGISTER_KERNEL_BUILDER(Name("ExtractGlimpseV2").Device(DEVICE_CPU), + ExtractGlimpseOp); } // end namespace tensorflow diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h index 5690c3a6014..a22af7ab71e 100644 --- a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h @@ -16,6 +16,7 @@ #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ #include <cstring> +#include <list> #include <vector> #include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h" @@ -250,10 +251,37 @@ class WeightedQuantilesSummary { float compression_eps = ApproximationError() + (1.0 / num_boundaries); compressed_summary.Compress(num_boundaries, compression_eps); + // Remove the least important boundaries by the gap removing them would + // create. + std::list<int64> boundaries_to_keep; + for (int64 i = 0; i != compressed_summary.entries_.size(); ++i) { + boundaries_to_keep.push_back(i); + } + while (boundaries_to_keep.size() > num_boundaries) { + std::list<int64>::iterator min_element = boundaries_to_keep.end(); + auto prev = boundaries_to_keep.begin(); + auto curr = prev; + ++curr; + auto next = curr; + ++next; + WeightType min_weight = TotalWeight(); + for (; next != boundaries_to_keep.end(); ++prev, ++curr, ++next) { + WeightType new_weight = + compressed_summary.entries_[*next].PrevMaxRank() - + compressed_summary.entries_[*prev].NextMinRank(); + if (new_weight < min_weight) { + min_element = curr; + min_weight = new_weight; + } + } + boundaries_to_keep.erase(min_element); + } + // Return boundaries. - output.reserve(compressed_summary.entries_.size()); - for (const auto& entry : compressed_summary.entries_) { - output.push_back(entry.value); + output.reserve(boundaries_to_keep.size()); + for (auto itr = boundaries_to_keep.begin(); itr != boundaries_to_keep.end(); + ++itr) { + output.push_back(compressed_summary.entries_[*itr].value); } return output; } diff --git a/tensorflow/core/kernels/conv_2d_gpu.h b/tensorflow/core/kernels/conv_2d_gpu.h index 31abe9dfead..85ca2b5722a 100644 --- a/tensorflow/core/kernels/conv_2d_gpu.h +++ b/tensorflow/core/kernels/conv_2d_gpu.h @@ -210,6 +210,57 @@ __global__ void ShuffleInTensor3Simple(int nthreads, } } +static constexpr int kUnroll = 4; + +template <typename T, int sp0, int sp1, int sp2, bool conjugate = false> +__global__ void ShuffleInTensor3SimpleVector(int nthreads, + const T* __restrict__ input, + Dimension<3> input_dims, + T* __restrict__ output) { + Dimension<3> output_dims; + output_dims[sp0] = input_dims[0]; + output_dims[sp1] = input_dims[1]; + output_dims[sp2] = input_dims[2]; + + const int stride = blockDim.x * gridDim.x * kUnroll; + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + T buf[kUnroll]; + + int output_index; + for (output_index = tid * kUnroll; output_index + kUnroll - 1 < nthreads; + output_index += stride) { +#pragma unroll + for (int i = 0; i < kUnroll; i++) { + int output_index_i = output_index + i; + Index<3> output_tensor_index = + FlatToTensorIndex(output_index_i, output_dims); + Index<3> input_tensor_index; + input_tensor_index[0] = output_tensor_index[sp0]; + input_tensor_index[1] = output_tensor_index[sp1]; + input_tensor_index[2] = output_tensor_index[sp2]; + + int input_index_i = TensorIndexToFlat(input_tensor_index, input_dims); + buf[i] = maybe_conj<T, conjugate>::run(ldg(input + input_index_i)); + } + float2* out = reinterpret_cast<float2*>(output + output_index); + *out = *reinterpret_cast<float2*>(buf); + } + + for (; output_index < nthreads; ++output_index) { + Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims); + + Index<3> input_tensor_index; + input_tensor_index[0] = output_tensor_index[sp0]; + input_tensor_index[1] = output_tensor_index[sp1]; + input_tensor_index[2] = output_tensor_index[sp2]; + + int input_index = TensorIndexToFlat(input_tensor_index, input_dims); + + output[output_index] = + maybe_conj<T, conjugate>::run(ldg(input + input_index)); + } +} + // Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor, // where dimensions are zero-based: output[i][j][k] = input[i][k][j]. // @@ -1008,10 +1059,40 @@ struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> { static_cast<int>(combined_dims[2])}; size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2]; GpuLaunchConfig config = GetGpuLaunchConfig(total_size, d); - TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>, - config.block_count, config.thread_per_block, 0, - d.stream(), config.virtual_thread_count, in, - input_dims, out)); + + auto out_ptr = reinterpret_cast<uintptr_t>(out); + bool aligned = out_ptr % 16 == 0; + + bool use_vector = false; + bool use_custom_config = false; + if ((input_dims[0] <= 128 && input_dims[2] <= 128) || + input_dims[0] * input_dims[1] <= 128 || + input_dims[1] * input_dims[2] <= 8) { + use_vector = true; + use_custom_config = true; + } else if (input_dims[1] * input_dims[2] <= 16384) { + use_vector = true; + } + + if (sizeof(T) == 2 && aligned && use_vector) { + int block_count; + if (use_custom_config) { + block_count = (total_size + config.thread_per_block - 1) / + config.thread_per_block; + } else { + block_count = config.block_count; + } + + TF_CHECK_OK( + GpuLaunchKernel(ShuffleInTensor3SimpleVector<T, 2, 1, 0, conjugate>, + block_count, config.thread_per_block / kUnroll, 0, + d.stream(), total_size, in, input_dims, out)); + } else { + TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>, + config.block_count, config.thread_per_block, + 0, d.stream(), config.virtual_thread_count, + in, input_dims, out)); + } } }; diff --git a/tensorflow/core/kernels/cubin_headers/BUILD b/tensorflow/core/kernels/cubin_headers/BUILD index bb7995dd221..509ac008355 100644 --- a/tensorflow/core/kernels/cubin_headers/BUILD +++ b/tensorflow/core/kernels/cubin_headers/BUILD @@ -45,3 +45,23 @@ func @relu(%arg0: tensor<?xf99>) -> tensor<?xf99> { ("f64", "DT_DOUBLE"), ] ] + +tanh_kernel = """ +func @tanh(%arg0: tensor<?xf99>) -> tensor<?xf99> { + %0 = "tf.Tanh"(%arg0) { T = "tfdtype$DT_TYPE" } + : (tensor<?xf99>) -> tensor<?xf99> + return %0 : tensor<?xf99> +} +""" + +[ + gen_kernel_image_hdr( + name = "tanh_{type}_kernel".format(type = type), + op = tanh_kernel.replace("f99", type).replace("DT_TYPE", dtype), + tile_size = "256", + ) + for (type, dtype) in [ + ("f32", "DT_FLOAT"), + ("f64", "DT_DOUBLE"), + ] +] diff --git a/tensorflow/core/kernels/cubin_headers/build_defs.bzl b/tensorflow/core/kernels/cubin_headers/build_defs.bzl index 14f47601f06..f9dac50591a 100644 --- a/tensorflow/core/kernels/cubin_headers/build_defs.bzl +++ b/tensorflow/core/kernels/cubin_headers/build_defs.bzl @@ -22,6 +22,8 @@ def _gen_kernel_image_hdr_impl(ctx): cubins = [] images = [] for arch in ctx.attr.gpu_archs: + # TODO(b/152737872): 'compute_' should generate both SASS and PTX. + arch = arch.replace("compute_", "sm_") filename = "%s.%s.cubin" % (name, arch) cubin = ctx.actions.declare_file(filename) ctx.actions.run( diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index 28738e3e2fe..dd64475d7d6 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -466,17 +466,15 @@ Status FunctionMetadata::Create( auto attr = fdef->attr().find(FunctionLibraryDefinition::kIntsOnDeviceAttr); if (attr != fdef->attr().end() && attr->second.b()) { - LOG(WARNING) - << "Disabling multi-device execution for a function that uses the " - << FunctionLibraryDefinition::kIntsOnDeviceAttr << " attribute."; + VLOG(1) << "Disabling multi-device execution for a function that uses the " + << FunctionLibraryDefinition::kIntsOnDeviceAttr << " attribute."; (*out_metadata)->use_multi_device_function_ = false; return Status::OK(); } auto validate_arg = [](const OpDef::ArgDef& arg) { if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) { - LOG(WARNING) << "Disabling multi-device execution for a function with " - "a vector argument " - << arg.name() << "."; + VLOG(1) << "Disabling multi-device execution for a function with " + << "a vector argument " << arg.name() << "."; return false; } return true; @@ -562,8 +560,7 @@ Status CapturedFunction::Instantiate( if (!metadata_->use_inter_op_parallelism()) { inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR"; } - bool is_multi_device = false; - TF_RETURN_IF_ERROR(IsMultiDevice(ctx, &is_multi_device)); + bool is_multi_device = metadata_->use_multi_device_function(); inst_opts.is_multi_device_function = is_multi_device; // We infer the target device from the function library runtime. @@ -866,77 +863,5 @@ CapturedFunction::CapturedFunction( : metadata_(std::move(metadata)), captured_inputs_(std::move(captured_inputs)) {} -Status CapturedFunction::IsMultiDevice(IteratorContext* ctx, - bool* is_multi_device) { - if (!metadata_->use_multi_device_function()) { - *is_multi_device = false; - return Status::OK(); - } - - const FunctionDef* fdef; - TF_RETURN_IF_ERROR( - LookupFunction(*metadata_->lib_def(), metadata_->func().name(), &fdef)); - - Device* current_device = ctx->flr()->device(); - DeviceType current_device_type(current_device->device_type()); - DeviceNameUtils::ParsedName current_device_name; - if (!DeviceNameUtils::ParseFullName(current_device->name(), - ¤t_device_name)) { - return errors::InvalidArgument("Failed to parse device name: ", - current_device->name()); - } - - // Check if any of the captured inputs are placed on a device not compatible - // with the current device. For non-captured inputs, we assume they are placed - // on the current device. - for (const auto& input : captured_inputs_) { - DataType dtype = input.dtype(); - if (dtype == DT_RESOURCE) { - const ResourceHandle& handle = input.flat<ResourceHandle>()(0); - DeviceNameUtils::ParsedName resource_device_name; - if (!DeviceNameUtils::ParseFullName(handle.device(), - &resource_device_name)) { - return errors::InvalidArgument("Failed to parse device name: ", - handle.device()); - } - if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name, - resource_device_name)) { - *is_multi_device = true; - return Status::OK(); - } - } - } - - // Check if all ops could be placed on the current device. - for (const auto& name : metadata_->lib_def()->ListFunctionNames()) { - const FunctionDef* fdef; - TF_RETURN_IF_ERROR(LookupFunction(*metadata_->lib_def(), name, &fdef)); - for (const auto& node : fdef->node_def()) { - // Check if the op has a kernel available for the current device. - if (!KernelDefAvailable(current_device_type, node)) { - *is_multi_device = true; - return Status::OK(); - } - // If the op has a requested device, check if the requested device is - // compatible with the current device. - if (!node.device().empty()) { - DeviceNameUtils::ParsedName node_device_name; - if (!DeviceNameUtils::ParseFullName(node.device(), &node_device_name)) { - return errors::InvalidArgument("Failed to parse device name: ", - node.device()); - } - if (!DeviceNameUtils::AreCompatibleDevNames(current_device_name, - node_device_name)) { - *is_multi_device = true; - return Status::OK(); - } - } - } - } - - *is_multi_device = false; - return Status::OK(); -} - } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index 284a02091dd..de424fc547c 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -256,10 +256,6 @@ class CapturedFunction { CapturedFunction(std::shared_ptr<const FunctionMetadata> metadata, std::vector<Tensor> captured_inputs); - // Determines whether the captured function requires the use of the - // multi-device function backend. - Status IsMultiDevice(IteratorContext* ctx, bool* is_multi_device); - const std::shared_ptr<const FunctionMetadata> metadata_; const std::vector<Tensor> captured_inputs_; diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD index 4ddfd99951c..a9790fd99a4 100644 --- a/tensorflow/core/kernels/data/experimental/BUILD +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -109,6 +109,20 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "compression_ops", + srcs = ["compression_ops.cc"], + hdrs = ["compression_ops.h"], + deps = [ + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/data:compression_utils", + "//tensorflow/core/data:dataset_proto_cc", + ], +) + tf_kernel_library( name = "csv_dataset_op", srcs = ["csv_dataset_op.cc"], @@ -131,8 +145,8 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/data/service:common_proto_cc", - "//tensorflow/core/data/service:compression_utils", + "//tensorflow/core/data:compression_utils", + "//tensorflow/core/data:dataset_proto_cc", "//tensorflow/core/data/service:data_service", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/kernels/data:dataset_utils", @@ -523,6 +537,7 @@ cc_library( "//tensorflow/core/platform:random", "//tensorflow/core/profiler/lib:traceme", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -681,6 +696,7 @@ tf_kernel_library( ":auto_shard_dataset_op", ":choose_fastest_branch_dataset_op", ":choose_fastest_dataset_op", + ":compression_ops", ":csv_dataset_op", ":dense_to_sparse_batch_dataset_op", ":directed_interleave_dataset_op", diff --git a/tensorflow/core/kernels/data/experimental/compression_ops.cc b/tensorflow/core/kernels/data/experimental/compression_ops.cc new file mode 100644 index 00000000000..efa7018acb6 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/compression_ops.cc @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/experimental/compression_ops.h" + +#include "tensorflow/core/data/compression_utils.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +CompressElementOp::CompressElementOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + +void CompressElementOp::Compute(OpKernelContext* ctx) { + std::vector<Tensor> components; + for (size_t i = 0; i < ctx->num_inputs(); ++i) { + components.push_back(ctx->input(i)); + } + CompressedElement compressed; + OP_REQUIRES_OK(ctx, CompressElement(components, &compressed)); + + Tensor* output; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); + output->scalar<Variant>()() = std::move(compressed); +} + +UncompressElementOp::UncompressElementOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_)); +} + +void UncompressElementOp::Compute(OpKernelContext* ctx) { + Tensor tensor = ctx->input(0); + const Variant& variant = tensor.scalar<Variant>()(); + const CompressedElement* compressed = variant.get<CompressedElement>(); + + std::vector<Tensor> components; + OP_REQUIRES_OK(ctx, UncompressElement(*compressed, &components)); + OP_REQUIRES(ctx, components.size() == output_types_.size(), + errors::FailedPrecondition("Expected ", output_types_.size(), + " outputs from uncompress, but got ", + components.size())); + for (int i = 0; i < components.size(); ++i) { + OP_REQUIRES( + ctx, components[i].dtype() == output_types_[i], + errors::FailedPrecondition("Expected a tensor of type ", + DataTypeString(output_types_[i]), + " but got a tensor of type ", + DataTypeString(components[i].dtype()))); + ctx->set_output(i, components[i]); + } +} + +REGISTER_KERNEL_BUILDER(Name("CompressElement").Device(DEVICE_CPU), + CompressElementOp); +REGISTER_KERNEL_BUILDER(Name("UncompressElement").Device(DEVICE_CPU), + UncompressElementOp); + +} // namespace experimental +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/compression_ops.h b/tensorflow/core/kernels/data/experimental/compression_ops.h new file mode 100644 index 00000000000..6dd89ea4e5d --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/compression_ops.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_COMPRESSION_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_COMPRESSION_OPS_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +class CompressElementOp : public OpKernel { + public: + explicit CompressElementOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; +}; + +class UncompressElementOp : public OpKernel { + public: + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit UncompressElementOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_COMPRESSION_OPS_H_ diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc index 56077a671fb..a106bcb0a7c 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -21,14 +21,14 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" -#include "tensorflow/core/data/service/common.pb.h" -#include "tensorflow/core/data/service/compression_utils.h" +#include "tensorflow/core/data/dataset.pb.h" #include "tensorflow/core/data/service/data_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/model.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/name_utils.h" #include "tensorflow/core/kernels/data/serialization_utils.h" @@ -496,7 +496,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { std::vector<Tensor> element; if (!end_of_sequence) { - TF_RETURN_IF_ERROR(service_util::Uncompress(compressed, &element)); + Tensor tensor(DT_VARIANT, TensorShape{}); + tensor.scalar<Variant>()() = std::move(compressed); + element.push_back(tensor); } mutex_lock l(mu_); if (end_of_sequence) { diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.cc b/tensorflow/core/kernels/data/experimental/snapshot_util.cc index 6c4d6424146..31d1a87087e 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_util.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include <queue> #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/graph.pb.h" @@ -31,6 +32,7 @@ limitations under the License. #include "tensorflow/core/lib/io/zlib_inputstream.h" #include "tensorflow/core/lib/io/zlib_outputbuffer.h" #include "tensorflow/core/platform/coding.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/random.h" @@ -44,6 +46,12 @@ namespace snapshot_util { /* static */ constexpr const int64 Reader::kSnappyReaderInputBufferSizeBytes; /* static */ constexpr const int64 Reader::kSnappyReaderOutputBufferSizeBytes; +std::string GetCurrentCheckpointFile(const std::string& shard_directory, + const uint64 current_checkpoint_id) { + return io::JoinPath(shard_directory, + absl::StrFormat("%08d.snapshot", current_checkpoint_id)); +} + Writer::Writer(const std::string& filename, const std::string& compression_type, int version, const DataTypeVector& dtypes) : filename_(filename), @@ -62,7 +70,7 @@ Status Writer::Create(Env* env, const std::string& filename, } Status Writer::Initialize(tensorflow::Env* env) { - TF_RETURN_IF_ERROR(env->NewWritableFile(filename_, &dest_)); + TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_)); #if defined(IS_SLIM_BUILD) if (compression_type_ != io::compression::kNone) { LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning " @@ -225,16 +233,17 @@ Status Reader::Create(Env* env, const std::string& filename, class Reader::Dataset : public DatasetBase { public: - explicit Dataset(const std::string& filename, const std::string& compression, + explicit Dataset(const std::string& shard_dir, const std::string& compression, const int64 version, const DataTypeVector& dtypes, const std::vector<PartialTensorShape>& shapes, - DatasetContext::Params params) + const int64 start_index, DatasetContext::Params params) : DatasetBase(DatasetContext(std::move(params))), - filename_(filename), + shard_dir_(shard_dir), compression_(compression), version_(version), dtypes_(dtypes), - shapes_(shapes) {} + shapes_(shapes), + start_index_(start_index) {} const DataTypeVector& output_dtypes() const override { return dtypes_; } @@ -252,7 +261,8 @@ class Reader::Dataset : public DatasetBase { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** node) const override { - // TODO(frankchn): Implement for serialization and checkpointing. + // Not necessary perform any serialization as this dataset is only + // constructed at runtime in C++ and will be reconstructed every time. return Status::OK(); } @@ -263,21 +273,29 @@ class Reader::Dataset : public DatasetBase { } private: - std::string filename_; - std::string compression_; - int64 version_; - DataTypeVector dtypes_; - std::vector<PartialTensorShape> shapes_; + const std::string shard_dir_; + const std::string compression_; + const int64 version_; + const DataTypeVector dtypes_; + const std::vector<PartialTensorShape> shapes_; + const int64 start_index_; class Iterator : public DatasetIterator<Dataset> { public: explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} + : DatasetIterator<Dataset>(params), current_checkpoint_id_(0) {} Status Initialize(IteratorContext* ctx) override { - return Reader::Create(ctx->env(), dataset()->filename_, - dataset()->compression_, dataset()->version_, - dataset()->dtypes_, &reader_); + TF_RETURN_IF_ERROR(Reader::Create( + ctx->env(), GetCurrentFilename(), dataset()->compression_, + dataset()->version_, dataset()->dtypes_, &reader_)); + bool end_of_sequence; + for (int64 i = 0; i < dataset()->start_index_; ++i) { + // TODO(frankchn): Optimize this to not parse every single element. + std::vector<Tensor> unused; + TF_RETURN_IF_ERROR(GetNextInternal(ctx, &unused, &end_of_sequence)); + } + return Status::OK(); } protected: @@ -286,27 +304,53 @@ class Reader::Dataset : public DatasetBase { bool* end_of_sequence) override { *end_of_sequence = false; Status s = reader_->ReadTensors(out_tensors); - if (errors::IsOutOfRange(s)) { + if (!errors::IsOutOfRange(s)) { + return s; + } + Status status = AdvanceToNextFile(ctx->env()); + if (errors::IsNotFound(status)) { *end_of_sequence = true; return Status::OK(); + } else { + return status; } - return s; } Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { - // TODO(frankchn): Implement for serialization and checkpointing. + // Not necessary to save any state as this iterator will be reconstructed + // from scratch when the parent snapshot dataset is restored from + // checkpoint. return Status::OK(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - // TODO(frankchn): Implement for serialization and checkpointing. + // Not necessary to restore any state as this iterator will be + // reconstructed from scratch when the parent snapshot dataset is restored + // from checkpoint. return Status::OK(); } private: std::unique_ptr<Reader> reader_; + + // Stores the id current checkpoint file that we are in the process of + // reading (e.g. if the file is currently 00000001.snapshot, then this will + // be 1). + uint64 current_checkpoint_id_; + + std::string GetCurrentFilename() { + return GetCurrentCheckpointFile(dataset()->shard_dir_, + current_checkpoint_id_); + } + + Status AdvanceToNextFile(Env* env) { + current_checkpoint_id_++; + TF_RETURN_IF_ERROR(env->FileExists(GetCurrentFilename())); + return Reader::Create(env, GetCurrentFilename(), dataset()->compression_, + dataset()->version_, dataset()->dtypes_, &reader_); + } }; }; @@ -337,7 +381,8 @@ class Reader::NestedDataset : public DatasetBase { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** node) const override { - // TODO(frankchn): Implement for serialization and checkpointing. + // Not necessary perform any serialization as this dataset is only + // constructed at runtime in C++ and will be reconstructed every time. return Status::OK(); } @@ -377,13 +422,17 @@ class Reader::NestedDataset : public DatasetBase { Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { - // TODO(frankchn): Implement for serialization and checkpointing. + // Not necessary to save any state as this iterator will be reconstructed + // from scratch when the parent snapshot dataset is restored from + // checkpoint. return Status::OK(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - // TODO(frankchn): Implement for serialization and checkpointing. + // Not necessary to restore any state as this iterator will be + // reconstructed from scratch when the parent snapshot dataset is restored + // from checkpoint. return Status::OK(); } @@ -393,21 +442,36 @@ class Reader::NestedDataset : public DatasetBase { }; Status Reader::MakeNestedDataset(Env* env, - const std::vector<std::string>& filenames, + const std::vector<std::string>& shard_dirs, const string& compression_type, int version, const DataTypeVector& dtypes, const std::vector<PartialTensorShape>& shapes, + const int64 start_index, DatasetBase** output) { std::vector<DatasetBase*> datasets; - datasets.reserve(filenames.size()); - for (const auto& filename : filenames) { + datasets.reserve(shard_dirs.size()); + for (const auto& shard_dir : shard_dirs) { + // TODO(frankchn): The reading pattern could be controlled in a non-round + // robin fashion, so we cannot assume a round-robin manner when restoring. + int64 dataset_start_index = start_index / shard_dirs.size(); + if (start_index % shard_dirs.size() > datasets.size()) { + dataset_start_index++; + } + datasets.push_back( - new Dataset(filename, compression_type, version, dtypes, shapes, + new Dataset(shard_dir, compression_type, version, dtypes, shapes, + dataset_start_index, DatasetContext::Params({"snapshot_util::Reader::Dataset", "snapshot_util_reader_Dataset"}))); } + // Rotate the vector such that the first dataset contains the next element + // to be produced. + std::rotate(datasets.begin(), + datasets.begin() + (start_index % shard_dirs.size()), + datasets.end()); + *output = new NestedDataset( datasets, DatasetContext::Params({"snapshot_util::Reader::NestedDataset", "snapshot_util_reader_NestedDataset"})); @@ -463,6 +527,15 @@ Status Reader::Initialize(Env* env) { return Status::OK(); } +Status Reader::SkipRecords(int64 num_records) { + // TODO(frankchn): Optimize to not parse the entire Tensor and actually skip. + for (int i = 0; i < num_records; ++i) { + std::vector<Tensor> unused_tensors; + TF_RETURN_IF_ERROR(ReadTensors(&unused_tensors)); + } + return Status::OK(); +} + Status Reader::ReadTensors(std::vector<Tensor>* read_tensors) { profiler::TraceMe activity( [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); }, diff --git a/tensorflow/core/kernels/data/experimental/snapshot_util.h b/tensorflow/core/kernels/data/experimental/snapshot_util.h index dd15c591a22..a6455a85393 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_util.h +++ b/tensorflow/core/kernels/data/experimental/snapshot_util.h @@ -49,6 +49,9 @@ constexpr char kModePassthrough[] = "passthrough"; enum Mode { READER = 0, WRITER = 1, PASSTHROUGH = 2 }; +std::string GetCurrentCheckpointFile(const std::string& shard_directory, + const uint64 current_checkpoint_id); + class Writer { public: static constexpr const size_t kHeaderSize = sizeof(uint64); @@ -126,14 +129,17 @@ class Reader { // dataset. Each element within the nested dataset is itself a dataset, and // contains all the elements written out to each individual snapshot file. static Status MakeNestedDataset(Env* env, - const std::vector<std::string>& filenames, + const std::vector<std::string>& shard_dirs, const string& compression_type, int version, const DataTypeVector& dtypes, const std::vector<PartialTensorShape>& shapes, + const int64 start_index, DatasetBase** output); Status ReadTensors(std::vector<Tensor>* read_tensors); + Status SkipRecords(int64 num_records); + private: explicit Reader(const std::string& filename, const string& compression_type, int version, const DataTypeVector& dtypes); diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc index 0b4241dbb93..181aa1b8a2c 100644 --- a/tensorflow/core/kernels/data_format_ops.cc +++ b/tensorflow/core/kernels/data_format_ops.cc @@ -90,16 +90,15 @@ class DataFormatVecPermuteOp : public OpKernel { "input must be a vector or 2D tensor, but got shape ", input.shape().DebugString())); if (input.dims() == 1) { - OP_REQUIRES( - context, input.NumElements() == 4, - errors::InvalidArgument("1D input must be of size 4, but got shape ", - input.shape().DebugString())); + OP_REQUIRES(context, input.NumElements() == 2 || input.NumElements() == 4, + errors::InvalidArgument( + "1D input must be of size 2 or 4, but got shape ", + input.shape().DebugString())); } else if (input.dims() == 2) { - OP_REQUIRES( - context, input.dim_size(0) == 4, - errors::InvalidArgument( - "First dimension of 2D input must be of size 4, but got shape ", - input.shape().DebugString())); + OP_REQUIRES(context, input.dim_size(0) == 2 || input.dim_size(0) == 4, + errors::InvalidArgument("First dimension of 2D input must be " + "of size 2 or 4, but got shape ", + input.shape().DebugString())); OP_REQUIRES( context, input.dim_size(1) == 2, errors::InvalidArgument( @@ -112,7 +111,21 @@ class DataFormatVecPermuteOp : public OpKernel { context->allocate_output(0, input.shape(), &output)); // Support 1D and 2D cases. Eigen::DSizes<Eigen::DenseIndex, 8> dst_idx; - ComputeDstIndex(input.dims(), &dst_idx); + string src_format_str = src_format_; + string dst_format_str = dst_format_; + if (input.dim_size(0) == 2) { + // If the input is a vector of size 2, treat the two elements as spatial + // dimensions. + auto keep_only_spatial_dimensions = [](string* format_str) -> void { + auto new_end = std::remove_if( + format_str->begin(), format_str->end(), + [](const char dim) { return dim != 'H' && dim != 'W'; }); + format_str->erase(new_end, format_str->end()); + }; + keep_only_spatial_dimensions(&src_format_str); + keep_only_spatial_dimensions(&dst_format_str); + } + ComputeDstIndex(src_format_str, dst_format_str, input.dims(), &dst_idx); functor::DataFormatVecPermute<Device, T>()(context->eigen_device<Device>(), input.flat<T>(), @@ -124,10 +137,12 @@ class DataFormatVecPermuteOp : public OpKernel { // Example: HWNC --> NHWC // 1D: dst = [1, 2, 0, 3], // 2D: dst = [2, 3, 4, 5, 0, 1, 6, 7] - void ComputeDstIndex(int num_dim, Eigen::DSizes<Eigen::DenseIndex, 8>* dst) { - for (int i = 0; i < src_format_.size(); ++i) { - for (int j = 0; j < dst_format_.size(); ++j) { - if (dst_format_[j] != src_format_[i]) continue; + static void ComputeDstIndex(const string& src_format_str, + const string& dst_format_str, int num_dim, + Eigen::DSizes<Eigen::DenseIndex, 8>* dst) { + for (int i = 0; i < src_format_str.size(); ++i) { + for (int j = 0; j < dst_format_str.size(); ++j) { + if (dst_format_str[j] != src_format_str[i]) continue; // Found the dst index. Set output based on the number of dims. for (int k = 0; k < num_dim; ++k) { (*dst)[i * num_dim + k] = j * num_dim + k; diff --git a/tensorflow/core/kernels/eigen_attention.h b/tensorflow/core/kernels/eigen_attention.h index c5158e65d8a..4e03a787c2b 100644 --- a/tensorflow/core/kernels/eigen_attention.h +++ b/tensorflow/core/kernels/eigen_attention.h @@ -56,13 +56,14 @@ struct GlimpseExtractionOp { GlimpseExtractionOp(const Index width, const Index height, const std::vector<IndexPair<float> >& offsets, const bool normalized, const bool centered, - const ExtractGlimpsesNoiseMode noise) + const ExtractGlimpsesNoiseMode noise, const int version) : width_(width), height_(height), offsets_(offsets), normalized_(normalized), centered_(centered), - noise_(noise) {} + noise_(noise), + version_(version) {} template <typename Input> DSizes<Index, 4> dimensions(const Input& input) const { @@ -101,21 +102,44 @@ struct GlimpseExtractionOp { for (Index i = 0; i < batch_size; ++i) { float x = offsets_[i].first, y = offsets_[i].second; - // Un-normalize coordinates back to pixel space if normalized. - if (normalized_) { - x *= input_width; - y *= input_height; + if (version_ == 1) { + // Un-normalize coordinates back to pixel space if normalized. + if (normalized_) { + x *= input_width; + y *= input_height; + } + // Un-center if coordinates are centered on the image center. + if (centered_) { + x /= 2.0f; + y /= 2.0f; + x += input_width / 2.0f; + y += input_height / 2.0f; + } + // Remove half of the glimpse window. + x -= width_ / 2.0f; + y -= height_ / 2.0f; + } else { + if (normalized_) { + // Un-normalize coordinates back to pixel space if normalized. + x *= input_width; + y *= input_height; + if (centered_) { + // Un-center if coordinates are centered on the image center. + x /= 2.0f; + y /= 2.0f; + x += input_width / 2.0f; + y += input_height / 2.0f; + // Remove half of the glimpse window. + x -= width_ / 2.0f; + y -= height_ / 2.0f; + } + } else { + if (centered_) { + x += input_width / 2.0f; + y += input_height / 2.0f; + } + } } - // Un-center if coordinates are centered on the image center. - if (centered_) { - x /= 2.0f; - y /= 2.0f; - x += input_width / 2.0f; - y += input_height / 2.0f; - } - // Remove half of the glimpse window. - x -= width_ / 2.0f; - y -= height_ / 2.0f; const Index offset_x = (Index)x; const Index offset_y = (Index)y; @@ -243,6 +267,7 @@ struct GlimpseExtractionOp { const bool normalized_; const bool centered_; const ExtractGlimpsesNoiseMode noise_; + const int version_; }; } // namespace @@ -255,7 +280,8 @@ ExtractGlimpses( const typename internal::traits<Input>::Index height, const std::vector<IndexPair<float> >& offsets, const bool normalized = true, const bool centered = true, - const ExtractGlimpsesNoiseMode noise = ExtractGlimpsesNoiseMode::UNIFORM) { + const ExtractGlimpsesNoiseMode noise = ExtractGlimpsesNoiseMode::UNIFORM, + const int version = 2) { EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == ColMajor, YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 4, @@ -263,7 +289,7 @@ ExtractGlimpses( typedef typename internal::traits<Input>::Index Index; const GlimpseExtractionOp<Index> op(width, height, offsets, normalized, - centered, noise); + centered, noise, version); return input.customOp(op); } diff --git a/tensorflow/core/kernels/gather_functor_gpu.cu.h b/tensorflow/core/kernels/gather_functor_gpu.cu.h index 1cadee41a88..b2dd43885d0 100644 --- a/tensorflow/core/kernels/gather_functor_gpu.cu.h +++ b/tensorflow/core/kernels/gather_functor_gpu.cu.h @@ -92,13 +92,18 @@ struct GatherFunctor<GPUDevice, T, Index> { const int64 indices_size = indices.size(); const int64 slice_size = params.dimension(2); - GpuLaunchConfig config = GetGpuLaunchConfig(out_size, d); if (is_axis_zero) { + GpuLaunchConfig config = GetGpuLaunchConfig( + out_size, d, &GatherOpKernel<T, Index, true>, + /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0); TF_CHECK_OK(GpuLaunchKernel( GatherOpKernel<T, Index, true>, config.block_count, config.thread_per_block, 0, d.stream(), params.data(), indices.data(), out.data(), gather_dim_size, indices_size, slice_size, out_size)); } else { + GpuLaunchConfig config = GetGpuLaunchConfig( + out_size, d, &GatherOpKernel<T, Index, false>, + /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0); TF_CHECK_OK(GpuLaunchKernel( GatherOpKernel<T, Index, false>, config.block_count, config.thread_per_block, 0, d.stream(), params.data(), indices.data(), diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc index 849a2b4389f..5e6bd1de9d6 100644 --- a/tensorflow/core/kernels/gather_op.cc +++ b/tensorflow/core/kernels/gather_op.cc @@ -88,18 +88,18 @@ class GatherOp : public OpKernel { } if (batch_dims_ != 0) { - if (batch_dims_ < 0) { - batch_dims_ = indices.dims() + batch_dims_; - } - - if (!axis_is_set) axis = batch_dims_; - OP_REQUIRES( c, batch_dims_ >= -indices.dims() && batch_dims_ <= indices.dims(), errors::InvalidArgument("Expected batch_dims in the range [", -indices.dims(), ", ", indices.dims(), "], but got ", batch_dims_)); + if (batch_dims_ < 0) { + batch_dims_ = indices.dims() + batch_dims_; + } + + if (!axis_is_set) axis = batch_dims_; + OP_REQUIRES(c, batch_dims_ < params.dims(), errors::InvalidArgument("batch_dims (", batch_dims_, ") must be less than rank(params) (", @@ -154,6 +154,7 @@ class GatherOp : public OpKernel { Tensor* out = nullptr; OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out)); if (N == 0) return; + if (inner_size == 0) return; int64 bad_i = -1; auto indices_flat = indices.flat<Index>(); diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc index ecac2274ae8..e4c77881ea8 100644 --- a/tensorflow/core/kernels/gather_op_test.cc +++ b/tensorflow/core/kernels/gather_op_test.cc @@ -40,11 +40,12 @@ namespace { class GatherOpTest : public OpsTestBase { protected: - void MakeOp(DataType data_type, DataType index_type) { + void MakeOp(DataType data_type, DataType index_type, int batch_dims = 0) { TF_ASSERT_OK(NodeDefBuilder("myop", "GatherV2") .Input(FakeInput(data_type)) .Input(FakeInput(index_type)) .Input(FakeInput(index_type)) + .Attr("batch_dims", batch_dims) .Finalize(node_def())); TF_ASSERT_OK(InitOp()); } @@ -176,6 +177,20 @@ TEST_F(GatherOpTest, Error_IndexOutOfRange) { << s; } +TEST_F(GatherOpTest, Error_BatchDimsOutOfRange) { + MakeOp(DT_FLOAT, DT_INT32, 10); + + // Feed and run + AddInputFromArray<float>(TensorShape({5, 3}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); + AddInputFromArray<int32>(TensorShape({4}), {0, 4, 99, 2}); + AddInputFromArray<int32>(TensorShape({}), {0}); + Status s = RunOpKernel(); + EXPECT_TRUE(absl::StrContains( + s.ToString(), "Expected batch_dims in the range [-1, 1], but got 10")) + << s; +} + constexpr int kLookups = 2000; template <typename Index> diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc index 115b3597964..6ef806d94c7 100644 --- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc @@ -19,6 +19,8 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/fused_batch_norm_op.h" +#include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/util/mkl_types.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/tensor_format.h" @@ -37,11 +39,14 @@ using BatchNormBwdPd = mkldnn::batch_normalization_backward::primitive_desc; namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; +using FusedBNActivationMode = functor::FusedBatchNormActivationMode; + struct MklBatchNormFwdParams { memory::dims src_dims; int depth; float eps; bool training; + FusedBNActivationMode activation_mode; #ifndef ENABLE_MKLDNN_V1 MEMORY_FORMAT src_format; #else @@ -50,14 +55,17 @@ struct MklBatchNormFwdParams { MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps, #ifndef ENABLE_MKLDNN_V1 - bool training, MEMORY_FORMAT src_format) + bool training, MEMORY_FORMAT src_format, + FusedBNActivationMode activation_mode) #else - bool training, memory::desc src_md) + bool training, memory::desc src_md, + FusedBNActivationMode activation_mode) #endif // !ENABLE_MKLDNN_V1 : src_dims(src_dims), depth(depth), eps(eps), training(training), + activation_mode(activation_mode), #ifndef ENABLE_MKLDNN_V1 src_format(src_format) { } @@ -90,7 +98,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { // mean_data: output data buffer of means // variance_data: output data buffer of variances void Execute(const T* src_data, const U* weights_data, T* dst_data, - U* mean_data, U* variance_data) { + U* mean_data, U* variance_data, U* workspace_data) { context_.src_mem->set_data_handle( static_cast<void*>(const_cast<T*>(src_data))); context_.dst_mem->set_data_handle(static_cast<void*>(dst_data)); @@ -104,6 +112,9 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { context_.mean_mem->set_data_handle(static_cast<void*>(mean_data)); context_.variance_mem->set_data_handle(static_cast<void*>(variance_data)); } + if (workspace_data != nullptr) { + context_.ws_mem->set_data_handle(workspace_data); + } #ifdef ENABLE_MKLDNN_V1 // Execute batch-normalization forward primitives. execute_primitives(context_.fwd_primitives, context_.fwd_stream, @@ -123,6 +134,10 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { context_.mean_mem->set_data_handle(DummyData); context_.variance_mem->set_data_handle(DummyData); } + + if (workspace_data != nullptr) { + context_.ws_mem->set_data_handle(DummyData); + } } MEMORY_PRIMITIVE_DESC GetDstPd() const { return context_.dst_mem->GET_DESC; } @@ -158,6 +173,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { std::shared_ptr<mkldnn::memory> dst_mem; std::shared_ptr<mkldnn::memory> mean_mem; std::shared_ptr<mkldnn::memory> variance_mem; + std::shared_ptr<mkldnn::memory> ws_mem; // Forward BatchNorm primitive descriptor. std::shared_ptr<BatchNormFwdPd> fwd_pd; @@ -179,6 +195,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { dst_mem(nullptr), mean_mem(nullptr), variance_mem(nullptr), + ws_mem(nullptr), bn_fwd(nullptr), fwd_stream(nullptr) {} }; @@ -192,6 +209,9 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { : prop_kind::forward_scoring; #ifdef ENABLE_MKLDNN_V1 + if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) { + context_.flags |= GET_FLAG(fuse_norm_relu); + } // Memory descriptor auto src_md = fwdParams.src_md; // Create forward BatchNorm descriptor and primitive descriptor. @@ -229,6 +249,13 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { m_dims, U, MEMORY_FORMAT::nc, cpu_engine_, DummyData)); } +#ifdef ENABLE_MKLDNN_V1 + if (IS_SET(fuse_norm_relu)) { + context_.ws_mem.reset(new MEMORY_CONSTRUCTOR( + context_.fwd_pd->workspace_desc(), cpu_engine_, DummyData)); + } +#endif // ENABLE_MKLDNN_V1 + // BatchNorm forward primitive. // TODO(intel-tf): Merge all the #ifdefs and simplify code if (!fwdParams.training && !(IS_SET(use_global_stats))) { @@ -258,20 +285,41 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { } else if (IS_SET(use_global_stats)) { #ifdef ENABLE_MKLDNN_V1 if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) { - context_.net_args.push_back( - {{MKLDNN_ARG_SRC, *context_.src_mem}, - {MKLDNN_ARG_MEAN, *context_.mean_mem}, - {MKLDNN_ARG_VARIANCE, *context_.variance_mem}, - {MKLDNN_ARG_WEIGHTS, *context_.weights_mem}, - { MKLDNN_ARG_DST, - *context_.dst_mem }}); + if (IS_SET(fuse_norm_relu)) { + context_.net_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_MEAN, *context_.mean_mem}, + {MKLDNN_ARG_VARIANCE, *context_.variance_mem}, + {MKLDNN_ARG_WEIGHTS, *context_.weights_mem}, + {MKLDNN_ARG_DST, *context_.dst_mem}, + { MKLDNN_ARG_WORKSPACE, + *context_.ws_mem }}); + } else { + context_.net_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_MEAN, *context_.mean_mem}, + {MKLDNN_ARG_VARIANCE, *context_.variance_mem}, + {MKLDNN_ARG_WEIGHTS, *context_.weights_mem}, + { MKLDNN_ARG_DST, + *context_.dst_mem }}); + } } else { - context_.net_args.push_back( - {{MKLDNN_ARG_SRC, *context_.src_mem}, - {MKLDNN_ARG_MEAN, *context_.mean_mem}, - {MKLDNN_ARG_VARIANCE, *context_.variance_mem}, - { MKLDNN_ARG_DST, - *context_.dst_mem }}); + if (IS_SET(fuse_norm_relu)) { + context_.net_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_MEAN, *context_.mean_mem}, + {MKLDNN_ARG_VARIANCE, *context_.variance_mem}, + {MKLDNN_ARG_DST, *context_.dst_mem}, + { MKLDNN_ARG_WORKSPACE, + *context_.ws_mem }}); + } else { + context_.net_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_MEAN, *context_.mean_mem}, + {MKLDNN_ARG_VARIANCE, *context_.variance_mem}, + { MKLDNN_ARG_DST, + *context_.dst_mem }}); + } } context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd)); #else @@ -291,19 +339,40 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { } else { #ifdef ENABLE_MKLDNN_V1 if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) { - context_.net_args.push_back( - {{MKLDNN_ARG_SRC, *context_.src_mem}, - {MKLDNN_ARG_WEIGHTS, *context_.weights_mem}, - {MKLDNN_ARG_DST, *context_.dst_mem}, - {MKLDNN_ARG_MEAN, *context_.mean_mem}, - { MKLDNN_ARG_VARIANCE, - *context_.variance_mem }}); + if (IS_SET(fuse_norm_relu)) { + context_.net_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_WEIGHTS, *context_.weights_mem}, + {MKLDNN_ARG_DST, *context_.dst_mem}, + {MKLDNN_ARG_MEAN, *context_.mean_mem}, + {MKLDNN_ARG_VARIANCE, *context_.variance_mem}, + { MKLDNN_ARG_WORKSPACE, + *context_.ws_mem }}); + } else { + context_.net_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_WEIGHTS, *context_.weights_mem}, + {MKLDNN_ARG_DST, *context_.dst_mem}, + {MKLDNN_ARG_MEAN, *context_.mean_mem}, + { MKLDNN_ARG_VARIANCE, + *context_.variance_mem }}); + } } else { - context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem}, - {MKLDNN_ARG_DST, *context_.dst_mem}, - {MKLDNN_ARG_MEAN, *context_.mean_mem}, - { MKLDNN_ARG_VARIANCE, - *context_.variance_mem }}); + if (IS_SET(fuse_norm_relu)) { + context_.net_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_DST, *context_.dst_mem}, + {MKLDNN_ARG_MEAN, *context_.mean_mem}, + {MKLDNN_ARG_VARIANCE, *context_.variance_mem}, + { MKLDNN_ARG_WORKSPACE, + *context_.ws_mem }}); + } else { + context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_DST, *context_.dst_mem}, + {MKLDNN_ARG_MEAN, *context_.mean_mem}, + { MKLDNN_ARG_VARIANCE, + *context_.variance_mem }}); + } } context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd)); #else @@ -360,6 +429,7 @@ class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> { key_creator.AddAsKey<int>(fwdParams.depth); key_creator.AddAsKey<float>(fwdParams.eps); key_creator.AddAsKey<bool>(fwdParams.training); + key_creator.AddAsKey<FusedBNActivationMode>(fwdParams.activation_mode); key_creator.AddAsKey(typeid(T).name()); key_creator.AddAsKey(typeid(U).name()); return key_creator.GetKey(); @@ -676,7 +746,8 @@ class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> { // Adding a third parameter to the template to support FusedBatchNormV3 // with MKL. This is different from default where the classes are // derived. Moves enabling to compile-time rather than runtime. -template <typename Device, typename T, typename U, bool reserved_space> +template <typename Device, typename T, typename U, bool reserved_space, + bool is_batch_norm_ex = false> class MklFusedBatchNormOp : public OpKernel { public: explicit MklFusedBatchNormOp(OpKernelConstruction* context) @@ -696,6 +767,28 @@ class MklFusedBatchNormOp : public OpKernel { depth_ = 0; mean_values_ = nullptr; variance_values_ = nullptr; + +#ifndef ENABLE_MKLDNN_V1 + OP_REQUIRES(context, !is_batch_norm_ex, + errors::InvalidArgument( + "_MklFusedBatchNormEx is not supported in DNNL 0.x .")); +#endif + if (!is_batch_norm_ex) { + activation_mode_ = FusedBNActivationMode::kIdentity; + } else { + int num_side_inputs; + OP_REQUIRES_OK(context, + context->GetAttr("num_side_inputs", &num_side_inputs)); + // Currently _MKLFusedBatchNormEx do not support "SideInput" + OP_REQUIRES(context, num_side_inputs == 0, + errors::InvalidArgument( + "_MKLFusedBatchNorm do not support side input now.")); + + OP_REQUIRES_OK(context, ParseActivationMode(context, &activation_mode_)); + OP_REQUIRES(context, activation_mode_ == FusedBNActivationMode::kRelu, + errors::InvalidArgument( + "_MKLFusedBatchNorm only support Relu activation")); + } } void Compute(OpKernelContext* context) override { @@ -744,9 +837,12 @@ class MklFusedBatchNormOp : public OpKernel { // Handle the special case: input with 0 element and 0 batch size. Tensor* dst_tensor = nullptr; + TensorShape workspace_tf_shape; if (tf_shape_src.num_elements() == 0) { - HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(), - &dst_tensor); + size_t workspace_bytes = 0; + workspace_tf_shape.AddDim(workspace_bytes); + HandleEmptyInput(context, tf_shape_src, workspace_tf_shape, + scale_tensor.shape(), &dst_tensor); return; } @@ -758,23 +854,16 @@ class MklFusedBatchNormOp : public OpKernel { // Index of output tensor(diff_src). const size_t kDstIndex = 0; - // Allocate 4 output TF tensors. + // Allocate 5 output TF tensors. Tensor* batch_mean_tensor = nullptr; Tensor* batch_variance_tensor = nullptr; Tensor* saved_mean_tensor = nullptr; Tensor* saved_variance_tensor = nullptr; Tensor* reserved_space_tensor = nullptr; - AllocateTFOutputs(context, scale_tensor.shape(), &batch_mean_tensor, - &batch_variance_tensor, &saved_mean_tensor, - &saved_variance_tensor, &reserved_space_tensor); - - if (is_training_) - SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor); - else - SetMeanVariance(est_mean_tensor, est_variance_tensor); MklDnnData<T> src(&cpu_engine_); MklDnnData<U> weights(&cpu_engine_); + MklDnnData<U> wksp(&cpu_engine_); MEMORY_FORMAT dnn_fmt; MKL_TENSOR_FORMAT mkl_tensor_fmt; @@ -801,6 +890,51 @@ class MklFusedBatchNormOp : public OpKernel { ? dnn_shape_src.GetMklLayout() : memory::desc(src_dims, MklDnnType<T>(), dnn_fmt); +#ifdef ENABLE_MKLDNN_V1 + MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_, + src_md, activation_mode_); +#else + MklBatchNormFwdParams fwdParams( + src_dims, depth_, epsilon_, is_training_, + static_cast<MEMORY_FORMAT>(src_md.data.format), activation_mode_); +#endif // ENABLE_MKLDNN_V1 + // Get forward batch-normalization op from the primitive caching pool. + MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd = + MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams); + + // Allocate workspace tensor + U* ws_data = nullptr; + if (fwdParams.activation_mode == FusedBNActivationMode::kRelu) { +#ifdef ENABLE_MKLDNN_V1 + MEMORY_PRIMITIVE_DESC workspace_pd = + bn_fwd->GetBatchNormFwdPd()->workspace_desc(); + size_t workspace_bytes = workspace_pd.get_size(); + workspace_tf_shape.AddDim(workspace_bytes); + + AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape, + &batch_mean_tensor, &batch_variance_tensor, + &saved_mean_tensor, &saved_variance_tensor, + &reserved_space_tensor); + if (reserved_space) { + wksp.SetUsrMem(workspace_pd, reserved_space_tensor); + ws_data = static_cast<U*>(wksp.GetOpMem().get_data_handle()); + } +#endif // ENABLE_MKLDNN_V1 + } else { + // There is actually no workspace tensor out, so we make a dummy one. + size_t workspace_bytes = 0; + workspace_tf_shape.AddDim(workspace_bytes); + AllocateTFOutputs(context, scale_tensor.shape(), workspace_tf_shape, + &batch_mean_tensor, &batch_variance_tensor, + &saved_mean_tensor, &saved_variance_tensor, + &reserved_space_tensor); + } + + if (is_training_) + SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor); + else + SetMeanVariance(est_mean_tensor, est_variance_tensor); + // MKL-DNN packs scale & shift as "weights": // <scale>...<scale><shift>...<shift> weights.AllocateBuffer(2 * depth_ * sizeof(U)); @@ -821,18 +955,6 @@ class MklFusedBatchNormOp : public OpKernel { reinterpret_cast<char*>(variance_values_), depth_ * sizeof(U)); -#ifdef ENABLE_MKLDNN_V1 - MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_, - src_md); -#else - MklBatchNormFwdParams fwdParams( - src_dims, depth_, epsilon_, is_training_, - static_cast<MEMORY_FORMAT>(src_md.data.format)); -#endif // ENABLE_MKLDNN_V1 - // Get forward batch-normalization op from the primitive caching pool. - MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd = - MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams); - // Check if reorder is needed for src. const T* src_data = nullptr; std::shared_ptr<BatchNormFwdPd> bn_fwd_pd = bn_fwd->GetBatchNormFwdPd(); @@ -866,7 +988,7 @@ class MklFusedBatchNormOp : public OpKernel { // Execute bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data, - variance_op_data); + variance_op_data, ws_data); float adjust_factor = 1.0; if (is_training_) { @@ -924,6 +1046,7 @@ class MklFusedBatchNormOp : public OpKernel { U* mean_values_; U* variance_values_; size_t depth_; // Batch normalization is performed for per channel. + FusedBNActivationMode activation_mode_; engine cpu_engine_ = engine(ENGINE_CPU, 0); void ExtractParams(OpKernelContext* context) { @@ -938,6 +1061,7 @@ class MklFusedBatchNormOp : public OpKernel { } void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src, + TensorShape workspace_tf_shape, TensorShape tf_shape_scale, Tensor** dst_tensor) { DCHECK(dst_tensor); @@ -955,12 +1079,14 @@ class MklFusedBatchNormOp : public OpKernel { Tensor* saved_mean_tensor = nullptr; Tensor* saved_variance_tensor = nullptr; Tensor* reserved_space_tensor = nullptr; - AllocateTFOutputs(context, tf_shape_scale, &batch_mean_tensor, - &batch_variance_tensor, &saved_mean_tensor, - &saved_variance_tensor, &reserved_space_tensor); + AllocateTFOutputs(context, tf_shape_scale, workspace_tf_shape, + &batch_mean_tensor, &batch_variance_tensor, + &saved_mean_tensor, &saved_variance_tensor, + &reserved_space_tensor); } void AllocateTFOutputs(OpKernelContext* context, TensorShape tf_shape_scale, + TensorShape workspace_tf_shape, Tensor** batch_mean_tensor, Tensor** batch_variance_tensor, Tensor** saved_mean_tensor, @@ -1024,21 +1150,15 @@ class MklFusedBatchNormOp : public OpKernel { std::fill_n(saved_variance_data, num_elements, static_cast<U>(0)); // Changes to support reserved_space_3 parameter in FusedBatchNormV3. - // TODO: This parameter functionality is not implemented on CPU. - // It is used to hold intermediate results. So the allocated - // memory is filled with 0s. if (reserved_space) { DCHECK(reserved_space_tensor != nullptr); MklDnnShape mkl_shape_reserved_space; mkl_shape_reserved_space.SetMklTensor(false); AllocateOutputSetMklShape(context, kReservedSpaceIndex, - reserved_space_tensor, tf_shape_scale, + reserved_space_tensor, workspace_tf_shape, mkl_shape_reserved_space); DCHECK((*reserved_space_tensor) != nullptr); - auto saved_reserved_space_data = - (*reserved_space_tensor)->flat<U>().data(); - std::fill_n(saved_reserved_space_data, num_elements, static_cast<U>(0)); } } }; @@ -1363,7 +1483,7 @@ class MklFusedBatchNormGradOp : public OpKernel { .Device(DEVICE_CPU) \ .TypeConstraint<T>("T") \ .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklFusedBatchNormOp<CPUDevice, T, T, false>); + MklFusedBatchNormOp<CPUDevice, T, T, false, false>); TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_CPU); TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU); @@ -1376,7 +1496,7 @@ TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU); .TypeConstraint<T>("T") \ .TypeConstraint<U>("U") \ .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklFusedBatchNormOp<CPUDevice, T, U, false>); + MklFusedBatchNormOp<CPUDevice, T, U, false, false>); REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(float, float); REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(bfloat16, float); @@ -1417,12 +1537,30 @@ REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(bfloat16, float); .TypeConstraint<T>("T") \ .TypeConstraint<U>("U") \ .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklFusedBatchNormOp<CPUDevice, T, U, true>); + MklFusedBatchNormOp<CPUDevice, T, U, true, false>); \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklFusedBatchNormEx") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .TypeConstraint<U>("U") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklFusedBatchNormOp<CPUDevice, T, U, true, true>); REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(float, float); REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(bfloat16, float); #undef REGISTER_MKL_FUSED_BATCHNORM_V3_CPU +REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx") + .Device(DEVICE_CPU) + .TypeConstraint<float>("T") + .TypeConstraint<float>("U"), + NoOp); +REGISTER_KERNEL_BUILDER(Name("_FusedBatchNormEx") + .Device(DEVICE_CPU) + .TypeConstraint<bfloat16>("T") + .TypeConstraint<float>("U"), + NoOp); + #define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(T, U) \ REGISTER_KERNEL_BUILDER( \ Name("_MklFusedBatchNormGradV3") \ diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc index 3a7c864d10e..3eccf97f53c 100644 --- a/tensorflow/core/kernels/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_matmul_op.cc @@ -25,6 +25,7 @@ limitations under the License. #if defined(INTEL_MKL) +#include "mkldnn.hpp" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -32,13 +33,6 @@ limitations under the License. #include "tensorflow/core/kernels/mkl_matmul_ops_common.h" #include "tensorflow/core/util/mkl_util.h" -// This header file is part of MKL ML, need equivalent file in MKL DNN -#ifndef INTEL_MKL_DNN_ONLY -#include "mkl_cblas.h" -#endif - -#include "mkldnn.h" - namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; @@ -157,21 +151,11 @@ class MklMatMulOp : public OpKernel { // 1.0 and 0.0 respectively. const float alpha = 1.0f; const float beta = 0.0f; -#if defined(INTEL_MKL_DNN_ONLY) - const char* const ftrans[] = {"N", "T", "C"}; - int index_transa = transa ? 1 : 0; - int index_transb = transb ? 1 : 0; - VLOG(2) << "MKL DNN SGEMM called"; - // MKL DNN only supports the Fortran api and requires column major while - // Tensorflow uses row major so we reverse the order A and B - mkldnn_sgemm(ftrans[index_transb], ftrans[index_transa], &n, &m, &k, &alpha, - b, &ldb, a, &lda, &beta, c, &ldc); -#else - // MKL ML binary uses CBLAS API - cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, - transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc); -#endif + char char_transa = transa ? 'T' : 'N'; + char char_transb = transb ? 'T' : 'N'; + VLOG(2) << "MKL DNN SGEMM CALLED"; + dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta, + c, ldc); } #ifdef ENABLE_INTEL_MKL_BFLOAT16 @@ -205,53 +189,6 @@ class MklMatMulOp : public OpKernel { FloatToBFloat16(c_float.flat<float>().data(), c, c_float.NumElements()); } #endif // ENABLE_INTEL_MKL_BFLOAT16 - -// MKL-DNN only supports SGEMM and bfloat16-GEMM. -#ifndef INTEL_MKL_DNN_ONLY - - // Matrix-Matrix Multiplication with FP64 tensors. For detailed info about - // parameters, look at FP32 function description. - void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m, - const int n, const int k, const double* a, const int lda, - const double* b, const int ldb, double* c, const int ldc) { - const double alpha = 1.0; - const double beta = 0.0; - cblas_dgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, - transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc); - } - - // Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors. - // For detailed info about parameters, look at FP32 function description. - void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m, - const int n, const int k, const complex64* a, const int lda, - const complex64* b, const int ldb, complex64* c, - int const ldc) { - const MKL_Complex8 alpha = {1.0f, 0.0f}; - const MKL_Complex8 beta = {0.0f, 0.0f}; - cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, - transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha, - reinterpret_cast<const MKL_Complex8*>(a), lda, - reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta, - reinterpret_cast<MKL_Complex8*>(c), ldc); - } - - // Matrix-Matrix Multiplication with Complex128 (std::complex<double>) - // tensors. For detailed info about parameters, look at FP32 function - // description. - void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m, - const int n, const int k, const complex128* a, const int lda, - const complex128* b, const int ldb, complex128* c, - const int ldc) { - const MKL_Complex16 alpha = {1.0, 0.0}; - const MKL_Complex16 beta = {0.0, 0.0}; - cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, - transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha, - reinterpret_cast<const MKL_Complex16*>(a), lda, - reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta, - reinterpret_cast<MKL_Complex16*>(c), ldc); - } -#endif // !INTEL_MKL_DNN_ONLY }; #define REGISTER_CPU(T) \ @@ -269,13 +206,6 @@ TF_CALL_float(REGISTER_CPU); #ifdef ENABLE_INTEL_MKL_BFLOAT16 TF_CALL_bfloat16(REGISTER_CPU); #endif // ENABLE_INTEL_MKL_BFLOAT16 - -#ifndef INTEL_MKL_DNN_ONLY -TF_CALL_double(REGISTER_CPU); -TF_CALL_complex64(REGISTER_CPU); -TF_CALL_complex128(REGISTER_CPU); -#endif // !INTEL_MKL_DNN_ONLY #endif // ENABLE_MKL - } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index 3045fd050d5..a85f3f449fd 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -245,7 +245,6 @@ void PartitionedCallOp::RunFunction(FunctionLibraryRuntime::Handle handle, run_opts.source_device = lib->device() == nullptr ? "" : lib->device()->name(); run_opts.allow_dead_tensors = true; - run_opts.rendezvous = ctx->rendezvous(); std::vector<Tensor>* rets = new std::vector<Tensor>; const string& func_name = func_->name(); diff --git a/tensorflow/core/kernels/pooling_ops_3d.cc b/tensorflow/core/kernels/pooling_ops_3d.cc index 31ead11dd34..532d861e615 100644 --- a/tensorflow/core/kernels/pooling_ops_3d.cc +++ b/tensorflow/core/kernels/pooling_ops_3d.cc @@ -192,6 +192,7 @@ class Pooling3DOp : public UnaryOp<T> { {{out[2], out[1], out[0]}}, depth); Tensor* output; OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + if (out_shape.num_elements() == 0) return; LaunchPoolingOp<Device, T, Type>::launch(context, tensor_in, window, stride, padding, data_format_, padding_, output); diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index ccd1e3c835d..b606d411a3d 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -282,6 +282,7 @@ REGISTER_KERNEL_BUILDER( TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); TF_CALL_int64(REGISTER_GPU_KERNELS); TF_CALL_variant(REGISTER_GPU_KERNELS); +TF_CALL_uint32(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp") @@ -511,6 +512,7 @@ class AssignVariableOp<Device, Variant> : public OpKernel { TF_CALL_ALL_TYPES(REGISTER_KERNELS); TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); +TF_CALL_uint32(REGISTER_KERNELS); #undef REGISTER_KERNELS #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -524,6 +526,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); TF_CALL_int64(REGISTER_GPU_KERNELS); TF_CALL_variant(REGISTER_GPU_KERNELS); +TF_CALL_uint32(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc index 0e112133915..b5b62bc76ca 100644 --- a/tensorflow/core/kernels/reverse_sequence_op.cc +++ b/tensorflow/core/kernels/reverse_sequence_op.cc @@ -43,9 +43,9 @@ typedef Eigen::GpuDevice GPUDevice; template <typename Device, typename Tlen> void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { const Tensor& input = context->input(0); - const Tensor& seq_lens = context->input(1); + const Tensor& seq_lengths = context->input(1); - auto seq_lens_t = seq_lens.vec<Tlen>(); + auto seq_lens_t = seq_lengths.vec<Tlen>(); std::vector<Tlen> seq_lens_vec(seq_lens_t.size()); @@ -56,15 +56,16 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { OP_REQUIRES(context, batch_dim != seq_dim, errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim)); OP_REQUIRES(context, seq_dim < input.dims(), - errors::InvalidArgument("seq_dim must be < input.dims()", "( ", + errors::InvalidArgument("seq_dim must be < input rank", " ( ", seq_dim, " vs. ", input.dims(), ")")); OP_REQUIRES(context, batch_dim < input.dims(), - errors::InvalidArgument("batch_dim must be < input.dims()", "( ", + errors::InvalidArgument("batch_dim must be < input rank", " ( ", batch_dim, " vs. ", input.dims(), ")")); - OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim), - errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim, - "), ", "(", seq_lens.NumElements(), - " vs. ", input.dim_size(batch_dim), ")")); + OP_REQUIRES( + context, seq_lengths.NumElements() == input.dim_size(batch_dim), + errors::InvalidArgument("Length of seq_lengths != input.dims(", batch_dim, + "), ", "(", seq_lengths.NumElements(), " vs. ", + input.dim_size(batch_dim), ")")); for (size_t d = 0; d < seq_lens_vec.size(); ++d) { OP_REQUIRES(context, seq_lens_vec[d] >= 0, @@ -77,21 +78,22 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) { const Tensor& input = context->input(0); - const Tensor& seq_lens = context->input(1); + const Tensor& seq_lengths = context->input(1); OP_REQUIRES(context, batch_dim != seq_dim, errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim)); OP_REQUIRES(context, seq_dim < input.dims(), - errors::InvalidArgument("seq_dim must be < input.dims()", "( ", + errors::InvalidArgument("seq_dim must be < input rank", " ( ", seq_dim, " vs. ", input.dims(), ")")); OP_REQUIRES(context, batch_dim < input.dims(), - errors::InvalidArgument("batch_dim must be < input.dims()", "( ", + errors::InvalidArgument("batch_dim must be < input rank", " ( ", batch_dim, " vs. ", input.dims(), ")")); - OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim), - errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim, - "), ", "(", seq_lens.NumElements(), - " vs. ", input.dim_size(batch_dim), ")")); + OP_REQUIRES( + context, seq_lengths.NumElements() == input.dim_size(batch_dim), + errors::InvalidArgument("Length of seq_lengths != input.dims(", batch_dim, + "), ", "(", seq_lengths.NumElements(), " vs. ", + input.dim_size(batch_dim), ")")); } template <> @@ -117,14 +119,14 @@ class ReverseSequenceOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); - const Tensor& seq_lens = context->input(1); + const Tensor& seq_lengths = context->input(1); // Preliminary validation of sizes. - OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens.shape()), - errors::InvalidArgument("seq_lens input must be 1-dim, not ", - seq_lens.dims())); + OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lengths.shape()), + errors::InvalidArgument("seq_lengths must be 1-dim, not ", + seq_lengths.dims())); - auto seq_lens_t = seq_lens.vec<Tlen>(); + auto seq_lens_t = seq_lengths.vec<Tlen>(); CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_); if (!context->status().ok()) return; @@ -186,7 +188,7 @@ namespace functor { void ReverseSequence<GPUDevice, T, Tlen, Dims>::Compute( \ const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \ int32 batch_dim, int32 seq_dim, \ - typename TTypes<Tlen>::ConstVec seq_lens, \ + typename TTypes<Tlen>::ConstVec seq_lengths, \ typename TTypes<T, Dims>::Tensor output); \ extern template struct ReverseSequence<GPUDevice, T, Tlen, Dims>; diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc index c7c538a945f..9a80aad5d04 100644 --- a/tensorflow/core/kernels/sparse_cross_op.cc +++ b/tensorflow/core/kernels/sparse_cross_op.cc @@ -15,6 +15,7 @@ limitations under the License. // Contains OP to generate sparse crosses. #include <assert.h> + #include <limits> #include <string> #include <vector> @@ -29,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/strong_hash.h" #include "tensorflow/core/util/work_sharder.h" namespace tensorflow { @@ -42,7 +44,8 @@ class ColumnInterface { virtual int64 FeatureCount(int64 batch) const = 0; // Returns the fingerprint of nth feature from the specified batch. - virtual InternalType Feature(int64 batch, int64 n) const = 0; + virtual InternalType Feature(int64 batch, int64 n, + bool strong_hash) const = 0; virtual ~ColumnInterface() {} }; @@ -63,7 +66,7 @@ class SparseTensorColumn : public ColumnInterface<InternalType> { return feature_counts_[batch]; } - InternalType Feature(int64 batch, int64 n) const override; + InternalType Feature(int64 batch, int64 n, bool strong_hash) const override; ~SparseTensorColumn() override {} @@ -73,18 +76,69 @@ class SparseTensorColumn : public ColumnInterface<InternalType> { std::vector<int64> feature_start_indices_; }; +// A column that is backed by a sparse tensor. +template <typename InternalType> +class KeyedSparseTensorColumn : public ColumnInterface<InternalType> { + public: + KeyedSparseTensorColumn(const Tensor& values, + std::vector<int64> feature_counts, + std::vector<int64> feature_start_indices, + std::vector<int64> key) + : values_(values), + feature_counts_(std::move(feature_counts)), + feature_start_indices_(std::move(feature_start_indices)) { + DCHECK_EQ(feature_counts_.size(), feature_start_indices_.size()); + std::memcpy(key_, key.data(), sizeof(key_)); + } + + int64 FeatureCount(int64 batch) const override { + return feature_counts_[batch]; + } + + InternalType Feature(int64 batch, int64 n, bool strong_hash) const override; + + ~KeyedSparseTensorColumn() override {} + + private: + const Tensor& values_; + uint64 key_[2]; + std::vector<int64> feature_counts_; + std::vector<int64> feature_start_indices_; +}; + // InternalType is int64 only when using HashCrosser. template <> -int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n) const { +int64 SparseTensorColumn<int64>::Feature(int64 batch, int64 n, + bool strong_hash) const { const int64 start = feature_start_indices_[batch]; if (DT_STRING == values_.dtype()) return Fingerprint64(values_.vec<tstring>().data()[start + n]); return values_.vec<int64>().data()[start + n]; } +template <> +int64 KeyedSparseTensorColumn<int64>::Feature(int64 batch, int64 n, + bool strong_hash) const { + const int64 start = feature_start_indices_[batch]; + if (strong_hash) { + if (DT_STRING == values_.dtype()) { + return StrongKeyedHash(key_, values_.vec<tstring>()(start + n)); + } + return StrongKeyedHash( + key_, {reinterpret_cast<const char*>(&values_.vec<int64>()(start + n)), + sizeof(values_.dtype())}); + } + if (DT_STRING == values_.dtype()) + return Fingerprint64(values_.vec<tstring>()(start + n)); + return Fingerprint64( + {reinterpret_cast<const char*>(&values_.vec<int64>()(start + n)), + sizeof(values_.dtype())}); +} + // InternalType is string or StringPiece when using StringCrosser. template <> -tstring SparseTensorColumn<tstring>::Feature(int64 batch, int64 n) const { +tstring SparseTensorColumn<tstring>::Feature(int64 batch, int64 n, + bool strong_hash) const { const int64 start = feature_start_indices_[batch]; if (DT_STRING == values_.dtype()) return values_.vec<tstring>().data()[start + n]; @@ -92,8 +146,24 @@ tstring SparseTensorColumn<tstring>::Feature(int64 batch, int64 n) const { } template <> -StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch, - int64 n) const { +tstring KeyedSparseTensorColumn<tstring>::Feature(int64 batch, int64 n, + bool strong_hash) const { + const int64 start = feature_start_indices_[batch]; + if (DT_STRING == values_.dtype()) + return values_.vec<tstring>().data()[start + n]; + return std::to_string(values_.vec<int64>().data()[start + n]); +} + +template <> +StringPiece SparseTensorColumn<StringPiece>::Feature(int64 batch, int64 n, + bool strong_hash) const { + const int64 start = feature_start_indices_[batch]; + return values_.vec<tstring>().data()[start + n]; +} + +template <> +StringPiece KeyedSparseTensorColumn<StringPiece>::Feature( + int64 batch, int64 n, bool strong_hash) const { const int64 start = feature_start_indices_[batch]; return values_.vec<tstring>().data()[start + n]; } @@ -106,7 +176,7 @@ class DenseTensorColumn : public ColumnInterface<InternalType> { int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); } - InternalType Feature(int64 batch, int64 n) const override; + InternalType Feature(int64 batch, int64 n, bool strong_hash) const override; ~DenseTensorColumn() override {} @@ -114,9 +184,46 @@ class DenseTensorColumn : public ColumnInterface<InternalType> { const Tensor& tensor_; }; +// A column that is backed by a dense tensor. +template <typename InternalType> +class KeyedDenseTensorColumn : public ColumnInterface<InternalType> { + public: + explicit KeyedDenseTensorColumn(const Tensor& tensor, std::vector<int64> key) + : tensor_(tensor) { + std::memcpy(key_, key.data(), sizeof(key_)); + } + + int64 FeatureCount(int64 batch) const override { return tensor_.dim_size(1); } + + InternalType Feature(int64 batch, int64 n, bool strong_hash) const override; + + ~KeyedDenseTensorColumn() override {} + + private: + const Tensor& tensor_; + uint64 key_[2]; +}; + // InternalType is int64 only when using HashCrosser. template <> -int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const { +int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n, + bool strong_hash) const { + if (DT_STRING == tensor_.dtype()) + return Fingerprint64(tensor_.matrix<tstring>()(batch, n)); + return tensor_.matrix<int64>()(batch, n); +} + +template <> +int64 KeyedDenseTensorColumn<int64>::Feature(int64 batch, int64 n, + bool strong_hash) const { + if (strong_hash) { + if (DT_STRING == tensor_.dtype()) { + return StrongKeyedHash(key_, tensor_.matrix<tstring>()(batch, n)); + } + return StrongKeyedHash( + key_, {reinterpret_cast<const char*>(tensor_.matrix<int64>()(batch, n)), + sizeof(tensor_.dtype())}); + } if (DT_STRING == tensor_.dtype()) return Fingerprint64(tensor_.matrix<tstring>()(batch, n)); return tensor_.matrix<int64>()(batch, n); @@ -124,14 +231,28 @@ int64 DenseTensorColumn<int64>::Feature(int64 batch, int64 n) const { // Internal type is string or StringPiece when using StringCrosser. template <> -tstring DenseTensorColumn<tstring>::Feature(int64 batch, int64 n) const { +tstring DenseTensorColumn<tstring>::Feature(int64 batch, int64 n, + bool strong_hash) const { if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n); return std::to_string(tensor_.matrix<int64>()(batch, n)); } template <> -StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch, - int64 n) const { +tstring KeyedDenseTensorColumn<tstring>::Feature(int64 batch, int64 n, + bool strong_hash) const { + if (DT_STRING == tensor_.dtype()) return tensor_.matrix<tstring>()(batch, n); + return std::to_string(tensor_.matrix<int64>()(batch, n)); +} + +template <> +StringPiece DenseTensorColumn<StringPiece>::Feature(int64 batch, int64 n, + bool strong_hash) const { + return tensor_.matrix<tstring>()(batch, n); +} + +template <> +StringPiece KeyedDenseTensorColumn<StringPiece>::Feature( + int64 batch, int64 n, bool strong_hash) const { return tensor_.matrix<tstring>()(batch, n); } @@ -169,24 +290,24 @@ class StringCrosser { public: StringCrosser(const std::vector< std::unique_ptr<ColumnInterface<InternalType>>>& columns, - const int64 num_buckets_unused, const uint64 hash_key_unused) - : columns_(columns) {} - - string Generate(const int64 batch_index, - const std::vector<int>& permutation) const { - static const auto k_feature_separator = "_X_"; + const int64 num_buckets_unused, const uint64 hash_key_unused, + const tstring k_feature_separator) + : columns_(columns), k_feature_separator_(k_feature_separator) {} + string Generate(const int64 batch_index, const std::vector<int>& permutation, + bool unused_strong_hash) const { gtl::InlinedVector<InternalType, 6> cross_vec(columns_.size()); for (int i = 0; i < permutation.size(); i++) { - cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i]); + cross_vec[i] = columns_[i]->Feature(batch_index, permutation[i], false); } // TODO(zakaria): this will copy the string twice, might effect // performance. - return absl::StrJoin(cross_vec, k_feature_separator); + return absl::StrJoin(cross_vec, k_feature_separator_); } private: const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns_; + const tstring k_feature_separator_; }; // Generates the sparse crosses as nested hash to avoid string manipulations. @@ -194,15 +315,16 @@ class HashCrosser { public: HashCrosser( const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns, - const int64 num_buckets, const uint64 hash_key) + const int64 num_buckets, const uint64 hash_key, + const tstring k_feature_separator_unused) : columns_(columns), num_buckets_(num_buckets), hash_key_(hash_key) {} - int64 Generate(const int64 batch_index, - const std::vector<int>& permutation) const { + int64 Generate(const int64 batch_index, const std::vector<int>& permutation, + bool unused_strong_hash) const { // Do the fingerprint concatenation on uint64. uint64 hashed_output = hash_key_; for (size_t i = 0; i < permutation.size(); ++i) { - uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i]); + uint64 hash_i = columns_[i]->Feature(batch_index, permutation[i], false); hashed_output = FingerprintCat64(hashed_output, hash_i); } // The return value is int64 based on the number of buckets. @@ -220,6 +342,39 @@ class HashCrosser { const uint64 hash_key_; }; +// Generates the sparse crosses as nested hash to avoid string manipulations. +class HashCrosserV2 { + public: + HashCrosserV2( + const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns, + const int64 num_buckets, const uint64 hash_key_unused, + const tstring k_feature_separator_unused) + : columns_(columns), num_buckets_(num_buckets) {} + + int64 Generate(const int64 batch_index, const std::vector<int>& permutation, + bool strong_hash) const { + // Do the fingerprint concatenation on uint64. + uint64 hashed_output = + columns_[0]->Feature(batch_index, permutation[0], strong_hash); + for (size_t i = 1; i < permutation.size(); ++i) { + uint64 hash_i = + columns_[i]->Feature(batch_index, permutation[i], strong_hash); + hashed_output = FingerprintCat64(hashed_output, hash_i); + } + // The return value is int64 based on the number of buckets. + if (num_buckets_ > 0) { + return hashed_output % num_buckets_; + } else { + // To prevent negative output we take modulo to max int64. + return hashed_output % std::numeric_limits<int64>::max(); + } + } + + private: + const std::vector<std::unique_ptr<ColumnInterface<int64>>>& columns_; + const int64 num_buckets_; +}; + // ProductIterator generates cartesian products based on indices. template <typename InternalType> class ProductIterator { @@ -275,16 +430,264 @@ struct CrossTraits; template <typename InternalType> struct CrossTraits<false, InternalType> { typedef StringCrosser<InternalType> Crosser; + typedef StringCrosser<InternalType> CrosserV2; typedef OutputUpdater<tstring> Updater; }; template <> struct CrossTraits<true, int64> { typedef HashCrosser Crosser; + typedef HashCrosserV2 CrosserV2; typedef OutputUpdater<int64> Updater; }; } // namespace +// Calculate the batch size from either the shapes input or the dense input. +int64 CalculateBatchSize(const OpInputList& shapes_list_in, + const OpInputList& dense_list_in) { + if (shapes_list_in.size() > 0) { + return shapes_list_in[0].vec<int64>()(0); + } + + if (dense_list_in.size() > 0) { + return dense_list_in[0].dim_size(0); + } + + return 0; +} + +// Validates input tensors. +Status ValidateInput(const OpInputList& indices_list_in, + const OpInputList& values_list_in, + const OpInputList& shapes_list_in, + const OpInputList& dense_list_in) { + const auto size = indices_list_in.size(); + // Validates indices_list_in OpInputList. + for (int i = 0; i < size; i++) { + if (!TensorShapeUtils::IsMatrix(indices_list_in[i].shape())) { + return errors::InvalidArgument( + "Input indices should be a matrix but received shape ", + indices_list_in[i].shape().DebugString(), " at position ", i); + } + if (indices_list_in[i].shape().dim_size(1) != 2) { + return errors::InvalidArgument("Expected D2 of index to be 2 got ", + indices_list_in[i].shape().dim_size(1), + " at position ", i); + } + } + + // Validates values_list_in OpInputList. + if (values_list_in.size() != size) { + return errors::InvalidArgument("Expected ", size, " input values, got ", + values_list_in.size()); + } + for (int i = 0; i < size; i++) { + if (!TensorShapeUtils::IsVector(values_list_in[i].shape())) { + return errors::InvalidArgument( + "Input values should be a vector but received shape ", + values_list_in[i].shape().DebugString(), " at position ", i); + } + if (indices_list_in[i].shape().dim_size(0) != + values_list_in[i].shape().dim_size(0)) { + return errors::InvalidArgument( + "Expected size of values to be ", + indices_list_in[i].shape().dim_size(0), " got ", + values_list_in[i].shape().dim_size(0), " at position ", i); + } + } + + // Validates shapes_list_in OpInputList + if (shapes_list_in.size() != size) { + return errors::InvalidArgument("Expected ", size, " input shapes, got ", + shapes_list_in.size()); + } + for (int i = 0; i < size; i++) { + if (!TensorShapeUtils::IsVector(shapes_list_in[i].shape())) { + return errors::InvalidArgument( + "Input shapes should be a vector but received shape ", + shapes_list_in[i].shape().DebugString(), " at position ", i); + } + + if (shapes_list_in[i].vec<int64>().size() != 2) { + return errors::InvalidArgument("shape should imply a 2D tensor, but got ", + shapes_list_in[i].shape().DebugString(), + " at position ", i); + } + } + + // Validates dense_list_in OpInputList + for (int i = 0; i < dense_list_in.size(); ++i) { + if (!TensorShapeUtils::IsMatrix(dense_list_in[i].shape())) { + return errors::InvalidArgument( + "Dense inputs should be a matrix but received shape ", + dense_list_in[i].shape().DebugString(), " at position ", i); + } + } + + // Validates batch sizes. (Note: we do this after validating the input + // shapes, because CalculateBatchSize() depends on inputs having valid + // shapes). + const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in); + for (int i = 0; i < size; i++) { + if (shapes_list_in[i].vec<int64>()(0) != batch_size) { + return errors::InvalidArgument("Expected batch size ", batch_size, + " got ", shapes_list_in[i].vec<int64>()(0), + " at position ", i); + } + } + for (int i = 0; i < dense_list_in.size(); ++i) { + if (dense_list_in[i].dim_size(0) != batch_size) { + return errors::InvalidArgument("Expected batch size ", batch_size, + " got ", dense_list_in[i].dim_size(0), + " at dense tensor ", i); + } + } + + return Status::OK(); +} + +// Extracts data about the features and populates feature data. +void ExtractFeatureData( + const OpInputList& indices_list_in, int64 batch_size, + std::vector<std::vector<int64>>* feature_counts, + std::vector<std::vector<int64>>* feature_start_indices) { + gtl::InlinedVector<int64, 8> current_row(indices_list_in.size(), 0); + for (int b = 0; b < batch_size; b++) { + for (int i = 0; i < indices_list_in.size(); i++) { + const auto indices = indices_list_in[i].matrix<int64>(); + int64 feature_count = 0; + int64 start_index = current_row[i]; + // Loops until we reach next batch index for current feature column. + while (current_row[i] < indices_list_in[i].dim_size(0) && + indices(current_row[i], 0) == b) { + feature_count++; + current_row[i]++; + } + (*feature_counts)[i].push_back(feature_count); + (*feature_start_indices)[i].push_back(start_index); + } + } +} + +// Returns number of crosses for a given batch_index +template <typename InternalType> +int64 CrossCountByBatchIndex( + const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns, + int batch_index) { + int64 cross_count = 1; + for (int i = 0; i < columns.size(); i++) { + const auto feature_count = columns[i]->FeatureCount(batch_index); + // If one column is missing any feature, there won't be any cross. + if (feature_count == 0) { + return 0; + } + cross_count *= feature_count; + } + return cross_count; +} + +// Generate the columns given the sparse and dense inputs. +template <typename InternalType> +std::vector<std::unique_ptr<ColumnInterface<InternalType>>> +GenerateColumnsFromInput(const OpInputList& indices_list_in, + const OpInputList& values_list_in, + const OpInputList& shapes_list_in, + const OpInputList& dense_list_in) { + std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns; + const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in); + const int64 number_of_columns = shapes_list_in.size(); + + std::vector<std::vector<int64>> feature_counts(number_of_columns, + std::vector<int64>()); + std::vector<std::vector<int64>> feature_start_indices(number_of_columns, + std::vector<int64>()); + + ExtractFeatureData(indices_list_in, batch_size, &feature_counts, + &feature_start_indices); + + columns.reserve(values_list_in.size()); + for (int i = 0; i < values_list_in.size(); ++i) { + columns.emplace_back(new SparseTensorColumn<InternalType>( + values_list_in[i], std::move(feature_counts[i]), + std::move(feature_start_indices[i]))); + } + for (int i = 0; i < dense_list_in.size(); ++i) { + columns.emplace_back(new DenseTensorColumn<InternalType>(dense_list_in[i])); + } + + return columns; +} + +// Generate the columns given the sparse and dense inputs. +template <typename InternalType> +std::vector<std::unique_ptr<ColumnInterface<InternalType>>> +GenerateKeyedColumnsFromInput(const OpInputList& indices_list_in, + const OpInputList& values_list_in, + const OpInputList& shapes_list_in, + const OpInputList& dense_list_in, + std::vector<int64> keys) { + std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns; + const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in); + const int64 number_of_columns = shapes_list_in.size(); + + std::vector<std::vector<int64>> feature_counts(number_of_columns, + std::vector<int64>()); + std::vector<std::vector<int64>> feature_start_indices(number_of_columns, + std::vector<int64>()); + + ExtractFeatureData(indices_list_in, batch_size, &feature_counts, + &feature_start_indices); + + columns.reserve(values_list_in.size()); + for (int i = 0; i < values_list_in.size(); ++i) { + columns.emplace_back(new KeyedSparseTensorColumn<InternalType>( + values_list_in[i], std::move(feature_counts[i]), + std::move(feature_start_indices[i]), keys)); + } + for (int i = 0; i < dense_list_in.size(); ++i) { + columns.emplace_back( + new KeyedDenseTensorColumn<InternalType>(dense_list_in[i], keys)); + } + + return columns; +} + +// Allocates output tensors with proper size and sets the shape tensor of +// the output SparseTensor. +// It also output_start_indices which contains the start indices for each +// input in the output SparseTensor. +template <typename InternalType> +Status CreateOutputTensors( + const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& columns, + int64 batch_size, OpKernelContext* context, Tensor** indices_out, + Tensor** values_out, Tensor** shape_out, + std::vector<int64>* output_start_indices) { + // Calculates dimensions for output tensors. + int64 cross_count_total = 0; + int64 max_cross_count = 0; + for (int64 b = 0; b < batch_size; b++) { + // For each input, sets starting indices in output SparseTensor + (*output_start_indices)[b] = cross_count_total; + const auto cross_count = CrossCountByBatchIndex(columns, b); + max_cross_count = std::max(max_cross_count, cross_count); + cross_count_total += cross_count; + } + + // Allocates tensors. + TF_RETURN_IF_ERROR(context->allocate_output( + 0, TensorShape({cross_count_total, 2}), indices_out)); + TF_RETURN_IF_ERROR(context->allocate_output( + 1, TensorShape({cross_count_total}), values_out)); + TF_RETURN_IF_ERROR(context->allocate_output(2, TensorShape({2}), shape_out)); + + // Sets shape. + auto shape_vec = (*shape_out)->vec<int64>(); + shape_vec(0) = batch_size; + shape_vec(1) = max_cross_count; + + return Status::OK(); +} + template <bool HASHED_OUTPUT, typename InternalType> class SparseCrossOp : public OpKernel { public: @@ -312,11 +715,12 @@ class SparseCrossOp : public OpKernel { shapes_list_in, dense_list_in)); std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns = - GenerateColumnsFromInput(indices_list_in, values_list_in, - shapes_list_in, dense_list_in); + GenerateColumnsFromInput<InternalType>(indices_list_in, values_list_in, + shapes_list_in, dense_list_in); + const tstring k_feature_separator = "_X_"; typename CrossTraits<HASHED_OUTPUT, InternalType>::Crosser crosser( - columns, num_buckets_, hash_key_); + columns, num_buckets_, hash_key_, k_feature_separator); Tensor* indices_out; Tensor* values_out; Tensor* shape_out; @@ -335,7 +739,8 @@ class SparseCrossOp : public OpKernel { int64 cross_count = 0; while (product_iterator.HasNext()) { const auto permutation = product_iterator.Next(); - updater.Update(b, cross_count, crosser.Generate(b, permutation)); + updater.Update(b, cross_count, + crosser.Generate(b, permutation, false)); cross_count++; } } @@ -349,222 +754,138 @@ class SparseCrossOp : public OpKernel { } private: - // Validates input tensors. - Status ValidateInput(const OpInputList& indices_list_in, - const OpInputList& values_list_in, - const OpInputList& shapes_list_in, - const OpInputList& dense_list_in) { - const auto size = indices_list_in.size(); - // Validates indices_list_in OpInputList. - for (int i = 0; i < size; i++) { - if (!TensorShapeUtils::IsMatrix(indices_list_in[i].shape())) { - return errors::InvalidArgument( - "Input indices should be a matrix but received shape ", - indices_list_in[i].shape().DebugString(), " at position ", i); - } - if (indices_list_in[i].shape().dim_size(1) != 2) { - return errors::InvalidArgument("Expected D2 of index to be 2 got ", - indices_list_in[i].shape().dim_size(1), - " at position ", i); - } - } - - // Validates values_list_in OpInputList. - if (values_list_in.size() != size) { - return errors::InvalidArgument("Expected ", size, " input values, got ", - values_list_in.size()); - } - for (int i = 0; i < size; i++) { - if (!TensorShapeUtils::IsVector(values_list_in[i].shape())) { - return errors::InvalidArgument( - "Input values should be a vector but received shape ", - values_list_in[i].shape().DebugString(), " at position ", i); - } - if (indices_list_in[i].shape().dim_size(0) != - values_list_in[i].shape().dim_size(0)) { - return errors::InvalidArgument( - "Expected size of values to be ", - indices_list_in[i].shape().dim_size(0), " got ", - values_list_in[i].shape().dim_size(0), " at position ", i); - } - } - - // Validates shapes_list_in OpInputList - if (shapes_list_in.size() != size) { - return errors::InvalidArgument("Expected ", size, " input shapes, got ", - shapes_list_in.size()); - } - for (int i = 0; i < size; i++) { - if (!TensorShapeUtils::IsVector(shapes_list_in[i].shape())) { - return errors::InvalidArgument( - "Input shapes should be a vector but received shape ", - shapes_list_in[i].shape().DebugString(), " at position ", i); - } - - if (shapes_list_in[i].vec<int64>().size() != 2) { - return errors::InvalidArgument( - "shape should imply a 2D tensor, but got ", - shapes_list_in[i].shape().DebugString(), " at position ", i); - } - } - - // Validates dense_list_in OpInputList - for (int i = 0; i < dense_list_in.size(); ++i) { - if (!TensorShapeUtils::IsMatrix(dense_list_in[i].shape())) { - return errors::InvalidArgument( - "Dense inputs should be a matrix but received shape ", - dense_list_in[i].shape().DebugString(), " at position ", i); - } - } - - // Validates batch sizes. (Note: we do this after validating the input - // shapes, because CalculateBatchSize() depends on inputs having valid - // shapes). - const auto batch_size = CalculateBatchSize(shapes_list_in, dense_list_in); - for (int i = 0; i < size; i++) { - if (shapes_list_in[i].vec<int64>()(0) != batch_size) { - return errors::InvalidArgument( - "Expected batch size ", batch_size, " got ", - shapes_list_in[i].vec<int64>()(0), " at position ", i); - } - } - for (int i = 0; i < dense_list_in.size(); ++i) { - if (dense_list_in[i].dim_size(0) != batch_size) { - return errors::InvalidArgument("Expected batch size ", batch_size, - " got ", dense_list_in[i].dim_size(0), - " at dense tensor ", i); - } - } - - return Status::OK(); - } - - // Calculate the batch size from either the shapes input or the dense input. - int64 CalculateBatchSize(const OpInputList& shapes_list_in, - const OpInputList& dense_list_in) { - if (shapes_list_in.size() > 0) { - return shapes_list_in[0].vec<int64>()(0); - } - - if (dense_list_in.size() > 0) { - return dense_list_in[0].dim_size(0); - } - - return 0; - } - - // Generate the columns given the sparse and dense inputs. - std::vector<std::unique_ptr<ColumnInterface<InternalType>>> - GenerateColumnsFromInput(const OpInputList& indices_list_in, - const OpInputList& values_list_in, - const OpInputList& shapes_list_in, - const OpInputList& dense_list_in) { - std::vector<std::unique_ptr<ColumnInterface<InternalType>>> columns; - const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in); - const int64 number_of_columns = shapes_list_in.size(); - - std::vector<std::vector<int64>> feature_counts(number_of_columns, - std::vector<int64>()); - std::vector<std::vector<int64>> feature_start_indices(number_of_columns, - std::vector<int64>()); - - ExtractFeatureData(indices_list_in, batch_size, &feature_counts, - &feature_start_indices); - - columns.reserve(values_list_in.size()); - for (int i = 0; i < values_list_in.size(); ++i) { - columns.emplace_back(new SparseTensorColumn<InternalType>( - values_list_in[i], std::move(feature_counts[i]), - std::move(feature_start_indices[i]))); - } - for (int i = 0; i < dense_list_in.size(); ++i) { - columns.emplace_back( - new DenseTensorColumn<InternalType>(dense_list_in[i])); - } - - return columns; - } - - // Extracts data about the features and populates feature data. - void ExtractFeatureData( - const OpInputList& indices_list_in, int64 batch_size, - std::vector<std::vector<int64>>* feature_counts, - std::vector<std::vector<int64>>* feature_start_indices) { - gtl::InlinedVector<int64, 8> current_row(indices_list_in.size(), 0); - for (int b = 0; b < batch_size; b++) { - for (int i = 0; i < indices_list_in.size(); i++) { - const auto indices = indices_list_in[i].matrix<int64>(); - int64 feature_count = 0; - int64 start_index = current_row[i]; - // Loops until we reach next batch index for current feature column. - while (current_row[i] < indices_list_in[i].dim_size(0) && - indices(current_row[i], 0) == b) { - feature_count++; - current_row[i]++; - } - (*feature_counts)[i].push_back(feature_count); - (*feature_start_indices)[i].push_back(start_index); - } - } - } - - // Allocates output tensors with proper size and sets the shape tensor of - // the output SparseTensor. - // It also output_start_indices which contains the start indices for each - // input in the output SparseTensor. - Status CreateOutputTensors( - const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& - columns, - int64 batch_size, OpKernelContext* context, Tensor** indices_out, - Tensor** values_out, Tensor** shape_out, - std::vector<int64>* output_start_indices) { - // Calculates dimensions for output tensors. - int64 cross_count_total = 0; - int64 max_cross_count = 0; - for (int64 b = 0; b < batch_size; b++) { - // For each input, sets starting indices in output SparseTensor - (*output_start_indices)[b] = cross_count_total; - const auto cross_count = CrossCountByBatchIndex(columns, b); - max_cross_count = std::max(max_cross_count, cross_count); - cross_count_total += cross_count; - } - - // Allocates tensors. - TF_RETURN_IF_ERROR(context->allocate_output( - 0, TensorShape({cross_count_total, 2}), indices_out)); - TF_RETURN_IF_ERROR(context->allocate_output( - 1, TensorShape({cross_count_total}), values_out)); - TF_RETURN_IF_ERROR( - context->allocate_output(2, TensorShape({2}), shape_out)); - - // Sets shape. - auto shape_vec = (*shape_out)->vec<int64>(); - shape_vec(0) = batch_size; - shape_vec(1) = max_cross_count; - - return Status::OK(); - } - - // Returns number of crosses for a given batch_index - int64 CrossCountByBatchIndex( - const std::vector<std::unique_ptr<ColumnInterface<InternalType>>>& - columns, - int batch_index) { - int64 cross_count = 1; - for (int i = 0; i < columns.size(); i++) { - const auto feature_count = columns[i]->FeatureCount(batch_index); - // If one column is missing any feature, there won't be any cross. - if (feature_count == 0) { - return 0; - } - cross_count *= feature_count; - } - return cross_count; - } int64 num_buckets_; uint64 hash_key_; }; +class SparseCrossV2Op : public OpKernel { + public: + explicit SparseCrossV2Op(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + OpInputList indices_list_in; + OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in)); + OpInputList values_list_in; + OP_REQUIRES_OK(context, context->input_list("values", &values_list_in)); + OpInputList shapes_list_in; + OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in)); + OpInputList dense_list_in; + OP_REQUIRES_OK(context, + context->input_list("dense_inputs", &dense_list_in)); + + OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in, + shapes_list_in, dense_list_in)); + + const Tensor* sep_t; + OP_REQUIRES_OK(context, context->input("sep", &sep_t)); + const tstring separator = sep_t->scalar<tstring>()(); + + std::vector<std::unique_ptr<ColumnInterface<tstring>>> columns = + GenerateColumnsFromInput<tstring>(indices_list_in, values_list_in, + shapes_list_in, dense_list_in); + Tensor* indices_out; + Tensor* values_out; + Tensor* shape_out; + const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in); + std::vector<int64> output_start_indices(batch_size); + OP_REQUIRES_OK( + context, + CreateOutputTensors(columns, batch_size, context, &indices_out, + &values_out, &shape_out, &output_start_indices)); + StringCrosser<tstring> crosser(columns, 0, 0, separator); + OutputUpdater<tstring> updater(output_start_indices, indices_out, + values_out); + auto do_work = [&columns, crosser, updater](int64 begin, int64 end) { + for (int b = begin; b < end; b++) { + ProductIterator<tstring> product_iterator(columns, b); + int64 cross_count = 0; + while (product_iterator.HasNext()) { + const auto permutation = product_iterator.Next(); + updater.Update(b, cross_count, + crosser.Generate(b, permutation, false)); + cross_count++; + } + } + }; + + auto* worker_threads = context->device()->tensorflow_cpu_worker_threads(); + // TODO(zakaria): optimize kCostPerUnit + const int kCostPerUnit = 5000 * indices_list_in.size(); + Shard(worker_threads->num_threads, worker_threads->workers, batch_size, + kCostPerUnit, do_work); + } +}; + +class SparseCrossHashedOp : public OpKernel { + public: + explicit SparseCrossHashedOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + OpInputList indices_list_in; + OP_REQUIRES_OK(context, context->input_list("indices", &indices_list_in)); + OpInputList values_list_in; + OP_REQUIRES_OK(context, context->input_list("values", &values_list_in)); + OpInputList shapes_list_in; + OP_REQUIRES_OK(context, context->input_list("shapes", &shapes_list_in)); + OpInputList dense_list_in; + OP_REQUIRES_OK(context, + context->input_list("dense_inputs", &dense_list_in)); + + OP_REQUIRES_OK(context, ValidateInput(indices_list_in, values_list_in, + shapes_list_in, dense_list_in)); + + const Tensor* num_buckets_t; + OP_REQUIRES_OK(context, context->input("num_buckets", &num_buckets_t)); + const int64 num_buckets = num_buckets_t->scalar<int64>()(); + + const Tensor* strong_hash_t; + OP_REQUIRES_OK(context, context->input("strong_hash", &strong_hash_t)); + const bool strong_hash = strong_hash_t->scalar<bool>()(); + + const Tensor* salt_t; + OP_REQUIRES_OK(context, context->input("salt", &salt_t)); + const auto salt = salt_t->flat<int64>(); + std::vector<int64> key_{salt(0), salt(1)}; + + std::vector<std::unique_ptr<ColumnInterface<int64>>> columns = + GenerateKeyedColumnsFromInput<int64>(indices_list_in, values_list_in, + shapes_list_in, dense_list_in, + key_); + Tensor* indices_out; + Tensor* values_out; + Tensor* shape_out; + const int64 batch_size = CalculateBatchSize(shapes_list_in, dense_list_in); + std::vector<int64> output_start_indices(batch_size); + OP_REQUIRES_OK( + context, + CreateOutputTensors(columns, batch_size, context, &indices_out, + &values_out, &shape_out, &output_start_indices)); + const tstring unused_sep; + HashCrosserV2 crosser(columns, num_buckets, 0, unused_sep); + OutputUpdater<int64> updater(output_start_indices, indices_out, values_out); + auto do_work = [&columns, crosser, updater, strong_hash](int64 begin, + int64 end) { + for (int b = begin; b < end; b++) { + ProductIterator<int64> product_iterator(columns, b); + int64 cross_count = 0; + while (product_iterator.HasNext()) { + const auto permutation = product_iterator.Next(); + updater.Update(b, cross_count, + crosser.Generate(b, permutation, strong_hash)); + cross_count++; + } + } + }; + + auto* worker_threads = context->device()->tensorflow_cpu_worker_threads(); + // TODO(zakaria): optimize kCostPerUnit + const int kCostPerUnit = 5000 * indices_list_in.size(); + Shard(worker_threads->num_threads, worker_threads->workers, batch_size, + kCostPerUnit, do_work); + } +}; + REGISTER_KERNEL_BUILDER(Name("SparseCross") .Device(DEVICE_CPU) .TypeConstraint<tstring>("out_type") @@ -589,4 +910,10 @@ REGISTER_KERNEL_BUILDER(Name("SparseCross") .TypeConstraint<int64>("internal_type"), SparseCrossOp<true, int64>); +REGISTER_KERNEL_BUILDER(Name("SparseCrossV2").Device(DEVICE_CPU), + SparseCrossV2Op); + +REGISTER_KERNEL_BUILDER(Name("SparseCrossHashed").Device(DEVICE_CPU), + SparseCrossHashedOp); + } // namespace tensorflow diff --git a/tensorflow/core/kernels/tile_functor_cpu_uint32.cc b/tensorflow/core/kernels/tile_functor_cpu_uint32.cc new file mode 100644 index 00000000000..4dd44eeea0f --- /dev/null +++ b/tensorflow/core/kernels/tile_functor_cpu_uint32.cc @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/tile_functor_cpu.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template struct Tile<CPUDevice, uint32, int32>; +template struct Tile<CPUDevice, uint32, int64>; + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_functor_cpu_uint64.cc b/tensorflow/core/kernels/tile_functor_cpu_uint64.cc new file mode 100644 index 00000000000..ec1eb7b0946 --- /dev/null +++ b/tensorflow/core/kernels/tile_functor_cpu_uint64.cc @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/tile_functor_cpu.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template struct Tile<CPUDevice, uint64, int32>; +template struct Tile<CPUDevice, uint64, int64>; + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_functor_cpu_variant.cc b/tensorflow/core/kernels/tile_functor_cpu_variant.cc new file mode 100644 index 00000000000..9ecfb4e9fe1 --- /dev/null +++ b/tensorflow/core/kernels/tile_functor_cpu_variant.cc @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/kernels/tile_functor_cpu.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template struct Tile<CPUDevice, Variant, int32>; +template struct Tile<CPUDevice, Variant, int64>; + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc index cd047ed9d4a..e626d430864 100644 --- a/tensorflow/core/kernels/tile_ops.cc +++ b/tensorflow/core/kernels/tile_ops.cc @@ -139,10 +139,13 @@ TF_CALL_uint8(DECLARE_TYPE); TF_CALL_int32(DECLARE_TYPE); TF_CALL_int16(DECLARE_TYPE); TF_CALL_int64(DECLARE_TYPE); +TF_CALL_uint32(DECLARE_TYPE); +TF_CALL_uint64(DECLARE_TYPE); TF_CALL_half(DECLARE_TYPE); TF_CALL_complex64(DECLARE_TYPE); TF_CALL_complex128(DECLARE_TYPE); TF_CALL_tstring(DECLARE_TYPE); +TF_CALL_variant(DECLARE_TYPE); #undef DECLARE_TYPE #define DECLARE_DIM(T, NDIM) \ @@ -240,10 +243,13 @@ class TileOp : public OpKernel { TF_CALL_int32(HANDLE_TYPE_NAME); TF_CALL_int16(HANDLE_TYPE_NAME); TF_CALL_int64(HANDLE_TYPE_NAME); + TF_CALL_uint32(HANDLE_TYPE_NAME); + TF_CALL_uint64(HANDLE_TYPE_NAME); TF_CALL_half(HANDLE_TYPE_NAME); TF_CALL_tstring(HANDLE_TYPE_NAME); // when DEVICE=CPUDevice. TF_CALL_complex64(HANDLE_TYPE_NAME); TF_CALL_complex128(HANDLE_TYPE_NAME); + TF_CALL_variant(HANDLE_TYPE_NAME); // when DEVICE=CPUDevice #undef HANDLE_TYPE_NAME #undef HANDLE_TYPE @@ -319,10 +325,13 @@ TF_CALL_int8(HANDLE_TYPE_NAME_CPU); TF_CALL_int32(HANDLE_TYPE_NAME_CPU); TF_CALL_int16(HANDLE_TYPE_NAME_CPU); TF_CALL_int64(HANDLE_TYPE_NAME_CPU); +TF_CALL_uint32(HANDLE_TYPE_NAME_CPU); +TF_CALL_uint64(HANDLE_TYPE_NAME_CPU); TF_CALL_half(HANDLE_TYPE_NAME_CPU); TF_CALL_complex64(HANDLE_TYPE_NAME_CPU); TF_CALL_complex128(HANDLE_TYPE_NAME_CPU); TF_CALL_tstring(HANDLE_TYPE_NAME_CPU); +TF_CALL_variant(HANDLE_TYPE_NAME_CPU); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM TF_CALL_bool(HANDLE_TYPE_NAME_GPU); diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h index 4c38738593f..54d78480066 100644 --- a/tensorflow/core/lib/bfloat16/bfloat16.h +++ b/tensorflow/core/lib/bfloat16/bfloat16.h @@ -194,171 +194,170 @@ struct bfloat16 { input = f.u; bfloat16 output; + // Fast rounding algorithm that rounds a half value to nearest even. This + // reduces expected error when we convert a large number of floats. Here + // is how it works: + // + // Definitions: + // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits + // with the following tags: + // + // Sign | Exp (8 bits) | Frac (23 bits) + // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT + // + // S: Sign bit. + // E: Exponent bits. + // F: First 6 bits of fraction. + // L: Least significant bit of resulting bfloat16 if we truncate away the + // rest of the float32. This is also the 7th bit of fraction + // R: Rounding bit, 8th bit of fraction. + // T: Sticky bits, rest of fraction, 15 bits. + // + // To round half to nearest even, there are 3 cases where we want to round + // down (simply truncate the result of the bits away, which consists of + // rounding bit and sticky bits) and two cases where we want to round up + // (truncate then add one to the result). + // + // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of + // 1s) as the rounding bias, adds the rounding bias to the input, then + // truncates the last 16 bits away. + // + // To understand how it works, we can analyze this algorithm case by case: + // + // 1. L = 0, R = 0: + // Expect: round down, this is less than half value. + // + // Algorithm: + // - Rounding bias: 0x7fff + 0 = 0x7fff + // - Adding rounding bias to input may create any carry, depending on + // whether there is any value set to 1 in T bits. + // - R may be set to 1 if there is a carry. + // - L remains 0. + // - Note that this case also handles Inf and -Inf, where all fraction + // bits, including L, R and Ts are all 0. The output remains Inf after + // this algorithm. + // + // 2. L = 1, R = 0: + // Expect: round down, this is less than half value. + // + // Algorithm: + // - Rounding bias: 0x7fff + 1 = 0x8000 + // - Adding rounding bias to input doesn't change sticky bits but + // adds 1 to rounding bit. + // - L remains 1. + // + // 3. L = 0, R = 1, all of T are 0: + // Expect: round down, this is exactly at half, the result is already + // even (L=0). + // + // Algorithm: + // - Rounding bias: 0x7fff + 0 = 0x7fff + // - Adding rounding bias to input sets all sticky bits to 1, but + // doesn't create a carry. + // - R remains 1. + // - L remains 0. + // + // 4. L = 1, R = 1: + // Expect: round up, this is exactly at half, the result needs to be + // round to the next even number. + // + // Algorithm: + // - Rounding bias: 0x7fff + 1 = 0x8000 + // - Adding rounding bias to input doesn't change sticky bits, but + // creates a carry from rounding bit. + // - The carry sets L to 0, creates another carry bit and propagate + // forward to F bits. + // - If all the F bits are 1, a carry then propagates to the exponent + // bits, which then creates the minimum value with the next exponent + // value. Note that we won't have the case where exponents are all 1, + // since that's either a NaN (handled in the other if condition) or inf + // (handled in case 1). + // + // 5. L = 0, R = 1, any of T is 1: + // Expect: round up, this is greater than half. + // + // Algorithm: + // - Rounding bias: 0x7fff + 0 = 0x7fff + // - Adding rounding bias to input creates a carry from sticky bits, + // sets rounding bit to 0, then create another carry. + // - The second carry sets L to 1. + // + // Examples: + // + // Exact half value that is already even: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000 + // + // This falls into case 3. We truncate the rest of 16 bits and no + // carry is created into F and L: + // + // Output: + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 + // + // Exact half value, round to next even number: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000 + // + // This falls into case 4. We create a carry from R and T, + // which then propagates into L and F: + // + // Output: + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 + // + // + // Max denormal value round to min normal value: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111 + // + // This falls into case 4. We create a carry from R and T, + // propagate into L and F, which then propagates into exponent + // bits: + // + // Output: + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 + // + // Max normal value round to Inf: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111 + // + // This falls into case 4. We create a carry from R and T, + // propagate into L and F, which then propagates into exponent + // bits: + // + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 + // + // + // Least significant bit of resulting bfloat. + uint32_t lsb = (input >> 16) & 1; + uint32_t rounding_bias = 0x7fff + lsb; + input += rounding_bias; + output.value = static_cast<uint16_t>(input >> 16); + if ((f.u & 0xff800000u) == 0) { + // Flush positive denormal to 0 + output.value = 0x0; + } + if ((f.u & 0xff800000u) == 0x80000000u) { + // Flush negative denormal to -0 + output.value = 0x8000; + } if (float_isnan(v)) { - // If the value is a NaN, squash it to a qNaN with msb of fraction set, - // this makes sure after truncation we don't end up with an inf. - // - // qNaN magic: All exponent bits set + most significant bit of fraction - // set. - output.value = 0x7fc0; - } else if (std::fabs(v) < std::numeric_limits<float>::min()) { - // Flush denormal to +/- 0.0 - output.value = std::signbit(v) ? 0x8000 : 0; - } else { - // Fast rounding algorithm that rounds a half value to nearest even. This - // reduces expected error when we convert a large number of floats. Here - // is how it works: - // - // Definitions: - // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits - // with the following tags: - // - // Sign | Exp (8 bits) | Frac (23 bits) - // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT - // - // S: Sign bit. - // E: Exponent bits. - // F: First 6 bits of fraction. - // L: Least significant bit of resulting bfloat16 if we truncate away the - // rest of the float32. This is also the 7th bit of fraction - // R: Rounding bit, 8th bit of fraction. - // T: Sticky bits, rest of fraction, 15 bits. - // - // To round half to nearest even, there are 3 cases where we want to round - // down (simply truncate the result of the bits away, which consists of - // rounding bit and sticky bits) and two cases where we want to round up - // (truncate then add one to the result). - // - // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of - // 1s) as the rounding bias, adds the rounding bias to the input, then - // truncates the last 16 bits away. - // - // To understand how it works, we can analyze this algorithm case by case: - // - // 1. L = 0, R = 0: - // Expect: round down, this is less than half value. - // - // Algorithm: - // - Rounding bias: 0x7fff + 0 = 0x7fff - // - Adding rounding bias to input may create any carry, depending on - // whether there is any value set to 1 in T bits. - // - R may be set to 1 if there is a carry. - // - L remains 0. - // - Note that this case also handles Inf and -Inf, where all fraction - // bits, including L, R and Ts are all 0. The output remains Inf after - // this algorithm. - // - // 2. L = 1, R = 0: - // Expect: round down, this is less than half value. - // - // Algorithm: - // - Rounding bias: 0x7fff + 1 = 0x8000 - // - Adding rounding bias to input doesn't change sticky bits but - // adds 1 to rounding bit. - // - L remains 1. - // - // 3. L = 0, R = 1, all of T are 0: - // Expect: round down, this is exactly at half, the result is already - // even (L=0). - // - // Algorithm: - // - Rounding bias: 0x7fff + 0 = 0x7fff - // - Adding rounding bias to input sets all sticky bits to 1, but - // doesn't create a carry. - // - R remains 1. - // - L remains 0. - // - // 4. L = 1, R = 1: - // Expect: round up, this is exactly at half, the result needs to be - // round to the next even number. - // - // Algorithm: - // - Rounding bias: 0x7fff + 1 = 0x8000 - // - Adding rounding bias to input doesn't change sticky bits, but - // creates a carry from rounding bit. - // - The carry sets L to 0, creates another carry bit and propagate - // forward to F bits. - // - If all the F bits are 1, a carry then propagates to the exponent - // bits, which then creates the minimum value with the next exponent - // value. Note that we won't have the case where exponents are all 1, - // since that's either a NaN (handled in the other if condition) or inf - // (handled in case 1). - // - // 5. L = 0, R = 1, any of T is 1: - // Expect: round up, this is greater than half. - // - // Algorithm: - // - Rounding bias: 0x7fff + 0 = 0x7fff - // - Adding rounding bias to input creates a carry from sticky bits, - // sets rounding bit to 0, then create another carry. - // - The second carry sets L to 1. - // - // Examples: - // - // Exact half value that is already even: - // Input: - // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) - // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT - // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000 - // - // This falls into case 3. We truncate the rest of 16 bits and no - // carry is created into F and L: - // - // Output: - // Sign | Exp (8 bit) | Frac (first 7 bit) - // S E E E E E E E E F F F F F F L - // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 - // - // Exact half value, round to next even number: - // Input: - // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) - // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT - // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000 - // - // This falls into case 4. We create a carry from R and T, - // which then propagates into L and F: - // - // Output: - // Sign | Exp (8 bit) | Frac (first 7 bit) - // S E E E E E E E E F F F F F F L - // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 - // - // - // Max denormal value round to min normal value: - // Input: - // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) - // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT - // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111 - // - // This falls into case 4. We create a carry from R and T, - // propagate into L and F, which then propagates into exponent - // bits: - // - // Output: - // Sign | Exp (8 bit) | Frac (first 7 bit) - // S E E E E E E E E F F F F F F L - // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 - // - // Max normal value round to Inf: - // Input: - // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) - // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT - // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111 - // - // This falls into case 4. We create a carry from R and T, - // propagate into L and F, which then propagates into exponent - // bits: - // - // Sign | Exp (8 bit) | Frac (first 7 bit) - // S E E E E E E E E F F F F F F L - // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 - // - // - // Least significant bit of resulting bfloat. - uint32_t lsb = (input >> 16) & 1; - uint32_t rounding_bias = 0x7fff + lsb; - input += rounding_bias; - output.value = static_cast<uint16_t>(input >> 16); + output.value = NAN_VALUE; } return output; } diff --git a/tensorflow/core/ops/compat/ops_history_v2/CompressElement.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/CompressElement.pbtxt new file mode 100644 index 00000000000..07d8cb461af --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/CompressElement.pbtxt @@ -0,0 +1,17 @@ +op { + name: "CompressElement" + input_arg { + name: "components" + type_list_attr: "input_types" + } + output_arg { + name: "compressed" + type: DT_VARIANT + } + attr { + name: "input_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/DenseBincount.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/DenseBincount.pbtxt index e26e1639e82..9bab6854e40 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/DenseBincount.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/DenseBincount.pbtxt @@ -39,7 +39,7 @@ op { } } attr { - name: "binary_count" + name: "binary_output" type: "bool" default_value { b: false diff --git a/tensorflow/core/ops/compat/ops_history_v2/DenseCountSparseOutput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/DenseCountSparseOutput.pbtxt index c5b845fd0fb..be566eab9f4 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/DenseCountSparseOutput.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/DenseCountSparseOutput.pbtxt @@ -6,7 +6,7 @@ op { } input_arg { name: "weights" - type: DT_FLOAT + type_attr: "output_type" } output_arg { name: "output_indices" @@ -49,7 +49,7 @@ op { minimum: -1 } attr { - name: "binary_count" + name: "binary_output" type: "bool" } attr { @@ -57,8 +57,10 @@ op { type: "type" allowed_values { list { + type: DT_INT32 type: DT_INT64 type: DT_FLOAT + type: DT_DOUBLE } } } diff --git a/tensorflow/core/ops/compat/ops_history_v2/ExtractGlimpseV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/ExtractGlimpseV2.pbtxt new file mode 100644 index 00000000000..08725f4504c --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/ExtractGlimpseV2.pbtxt @@ -0,0 +1,47 @@ +op { + name: "ExtractGlimpseV2" + input_arg { + name: "input" + type: DT_FLOAT + } + input_arg { + name: "size" + type: DT_INT32 + } + input_arg { + name: "offsets" + type: DT_FLOAT + } + output_arg { + name: "glimpse" + type: DT_FLOAT + } + attr { + name: "centered" + type: "bool" + default_value { + b: true + } + } + attr { + name: "normalized" + type: "bool" + default_value { + b: true + } + } + attr { + name: "uniform_noise" + type: "bool" + default_value { + b: true + } + } + attr { + name: "noise" + type: "string" + default_value { + s: "uniform" + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/RaggedBincount.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RaggedBincount.pbtxt index 9d94149cc09..4f5fb24109c 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/RaggedBincount.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/RaggedBincount.pbtxt @@ -43,7 +43,7 @@ op { } } attr { - name: "binary_count" + name: "binary_output" type: "bool" default_value { b: false diff --git a/tensorflow/core/ops/compat/ops_history_v2/RaggedCountSparseOutput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/RaggedCountSparseOutput.pbtxt index 7f492418b48..aa1a4e07aaf 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/RaggedCountSparseOutput.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/RaggedCountSparseOutput.pbtxt @@ -10,7 +10,7 @@ op { } input_arg { name: "weights" - type: DT_FLOAT + type_attr: "output_type" } output_arg { name: "output_indices" @@ -53,7 +53,7 @@ op { minimum: -1 } attr { - name: "binary_count" + name: "binary_output" type: "bool" } attr { @@ -61,8 +61,10 @@ op { type: "type" allowed_values { list { + type: DT_INT32 type: DT_INT64 type: DT_FLOAT + type: DT_DOUBLE } } } diff --git a/tensorflow/core/ops/compat/ops_history_v2/SparseBincount.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/SparseBincount.pbtxt index 333b71a5e1c..9bbc5132845 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/SparseBincount.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/SparseBincount.pbtxt @@ -47,7 +47,7 @@ op { } } attr { - name: "binary_count" + name: "binary_output" type: "bool" default_value { b: false diff --git a/tensorflow/core/ops/compat/ops_history_v2/SparseCountSparseOutput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/SparseCountSparseOutput.pbtxt index b701e5fc0db..ed79733f97f 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/SparseCountSparseOutput.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/SparseCountSparseOutput.pbtxt @@ -14,7 +14,7 @@ op { } input_arg { name: "weights" - type: DT_FLOAT + type_attr: "output_type" } output_arg { name: "output_indices" @@ -57,7 +57,7 @@ op { minimum: -1 } attr { - name: "binary_count" + name: "binary_output" type: "bool" } attr { @@ -65,8 +65,10 @@ op { type: "type" allowed_values { list { + type: DT_INT32 type: DT_INT64 type: DT_FLOAT + type: DT_DOUBLE } } } diff --git a/tensorflow/core/ops/compat/ops_history_v2/SparseCrossHashed.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/SparseCrossHashed.pbtxt new file mode 100644 index 00000000000..73002a92f24 --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/SparseCrossHashed.pbtxt @@ -0,0 +1,72 @@ +op { + name: "SparseCrossHashed" + input_arg { + name: "indices" + type: DT_INT64 + number_attr: "N" + } + input_arg { + name: "values" + type_list_attr: "sparse_types" + } + input_arg { + name: "shapes" + type: DT_INT64 + number_attr: "N" + } + input_arg { + name: "dense_inputs" + type_list_attr: "dense_types" + } + input_arg { + name: "num_buckets" + type: DT_INT64 + } + input_arg { + name: "strong_hash" + type: DT_BOOL + } + input_arg { + name: "salt" + type: DT_INT64 + } + output_arg { + name: "output_indices" + type: DT_INT64 + } + output_arg { + name: "output_values" + type: DT_INT64 + } + output_arg { + name: "output_shape" + type: DT_INT64 + } + attr { + name: "N" + type: "int" + has_minimum: true + } + attr { + name: "sparse_types" + type: "list(type)" + has_minimum: true + allowed_values { + list { + type: DT_INT64 + type: DT_STRING + } + } + } + attr { + name: "dense_types" + type: "list(type)" + has_minimum: true + allowed_values { + list { + type: DT_INT64 + type: DT_STRING + } + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/SparseCrossV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/SparseCrossV2.pbtxt new file mode 100644 index 00000000000..206542e4713 --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/SparseCrossV2.pbtxt @@ -0,0 +1,64 @@ +op { + name: "SparseCrossV2" + input_arg { + name: "indices" + type: DT_INT64 + number_attr: "N" + } + input_arg { + name: "values" + type_list_attr: "sparse_types" + } + input_arg { + name: "shapes" + type: DT_INT64 + number_attr: "N" + } + input_arg { + name: "dense_inputs" + type_list_attr: "dense_types" + } + input_arg { + name: "sep" + type: DT_STRING + } + output_arg { + name: "output_indices" + type: DT_INT64 + } + output_arg { + name: "output_values" + type: DT_STRING + } + output_arg { + name: "output_shape" + type: DT_INT64 + } + attr { + name: "N" + type: "int" + has_minimum: true + } + attr { + name: "sparse_types" + type: "list(type)" + has_minimum: true + allowed_values { + list { + type: DT_INT64 + type: DT_STRING + } + } + } + attr { + name: "dense_types" + type: "list(type)" + has_minimum: true + allowed_values { + list { + type: DT_INT64 + type: DT_STRING + } + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentMean.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentMean.pbtxt index a3fde8699b1..526c2c25c04 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentMean.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentMean.pbtxt @@ -1,45 +1,3 @@ -op { - name: "SparseSegmentMean" - input_arg { - name: "data" - type_attr: "T" - } - input_arg { - name: "indices" - type_attr: "Tidx" - } - input_arg { - name: "segment_ids" - type: DT_INT32 - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - } - } - } - attr { - name: "Tidx" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } -} op { name: "SparseSegmentMean" input_arg { diff --git a/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentMeanWithNumSegments.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentMeanWithNumSegments.pbtxt index 2d1d816200a..b9984f8df25 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentMeanWithNumSegments.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentMeanWithNumSegments.pbtxt @@ -1,62 +1,3 @@ -op { - name: "SparseSegmentMeanWithNumSegments" - input_arg { - name: "data" - type_attr: "T" - } - input_arg { - name: "indices" - type_attr: "Tidx" - } - input_arg { - name: "segment_ids" - type: DT_INT32 - } - input_arg { - name: "num_segments" - type_attr: "Tnumsegments" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - } - } - } - attr { - name: "Tidx" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - attr { - name: "Tnumsegments" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } -} op { name: "SparseSegmentMeanWithNumSegments" input_arg { diff --git a/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentSqrtN.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentSqrtN.pbtxt index 6ab44de93ec..17562d4f333 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentSqrtN.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentSqrtN.pbtxt @@ -1,45 +1,3 @@ -op { - name: "SparseSegmentSqrtN" - input_arg { - name: "data" - type_attr: "T" - } - input_arg { - name: "indices" - type_attr: "Tidx" - } - input_arg { - name: "segment_ids" - type: DT_INT32 - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - } - } - } - attr { - name: "Tidx" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } -} op { name: "SparseSegmentSqrtN" input_arg { diff --git a/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentSqrtNWithNumSegments.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentSqrtNWithNumSegments.pbtxt index 038a5a2bd28..1f24446a587 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentSqrtNWithNumSegments.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/SparseSegmentSqrtNWithNumSegments.pbtxt @@ -1,62 +1,3 @@ -op { - name: "SparseSegmentSqrtNWithNumSegments" - input_arg { - name: "data" - type_attr: "T" - } - input_arg { - name: "indices" - type_attr: "Tidx" - } - input_arg { - name: "segment_ids" - type: DT_INT32 - } - input_arg { - name: "num_segments" - type_attr: "Tnumsegments" - } - output_arg { - name: "output" - type_attr: "T" - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - } - } - } - attr { - name: "Tidx" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - attr { - name: "Tnumsegments" - type: "type" - default_value { - type: DT_INT32 - } - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } -} op { name: "SparseSegmentSqrtNWithNumSegments" input_arg { diff --git a/tensorflow/core/ops/compat/ops_history_v2/TPUReplicatedInput.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/TPUReplicatedInput.pbtxt index a293537e36d..b549b570c13 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/TPUReplicatedInput.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/TPUReplicatedInput.pbtxt @@ -56,3 +56,46 @@ op { } } } +op { + name: "TPUReplicatedInput" + input_arg { + name: "inputs" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + } + attr { + name: "is_mirrored_variable" + type: "bool" + default_value { + b: false + } + } + attr { + name: "index" + type: "int" + default_value { + i: -1 + } + } + attr { + name: "is_packed" + type: "bool" + default_value { + b: false + } + } +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/UncompressElement.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/UncompressElement.pbtxt new file mode 100644 index 00000000000..68406e0e4bc --- /dev/null +++ b/tensorflow/core/ops/compat/ops_history_v2/UncompressElement.pbtxt @@ -0,0 +1,23 @@ +op { + name: "UncompressElement" + input_arg { + name: "compressed" + type: DT_VARIANT + } + output_arg { + name: "components" + type_list_attr: "output_types" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 0122cbed087..6a633fb679d 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -731,42 +731,19 @@ REGISTER_OP("OneShotIterator") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); -namespace { - -Status IteratorGetNextShapeFn(shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); - std::vector<PartialTensorShape> output_shapes; - TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); - if (output_shapes.size() != c->num_outputs()) { - return errors::InvalidArgument( - "`output_shapes` must be the same length as `output_types` (", - output_shapes.size(), " vs. ", c->num_outputs()); - } - for (size_t i = 0; i < output_shapes.size(); ++i) { - shape_inference::ShapeHandle output_shape_handle; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( - output_shapes[i], &output_shape_handle)); - c->set_output(static_cast<int>(i), output_shape_handle); - } - return Status::OK(); -} - -} // namespace - REGISTER_OP("IteratorGetNext") .Input("iterator: resource") .Output("components: output_types") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(IteratorGetNextShapeFn); + .SetShapeFn(shape_inference::DatasetIteratorShape); REGISTER_OP("IteratorGetNextSync") .Input("iterator: resource") .Output("components: output_types") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(IteratorGetNextShapeFn); + .SetShapeFn(shape_inference::DatasetIteratorShape); // TODO(b/124308596): Instead of conservatively marking this op as stateful, // implement a mechanism to determine whether `dataset` has a side-effect @@ -778,7 +755,7 @@ REGISTER_OP("DatasetToSingleElement") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .SetIsStateful() - .SetShapeFn(IteratorGetNextShapeFn); + .SetShapeFn(shape_inference::DatasetIteratorShape); // TODO(b/124308596): Instead of conservatively marking this op as stateful, // implement a mechanism to determine whether `dataset` has a side-effect @@ -796,7 +773,7 @@ REGISTER_OP("ReduceDataset") .Attr("output_shapes: list(shape) >= 1") .Attr("use_inter_op_parallelism: bool = true") .SetIsStateful() - .SetShapeFn(IteratorGetNextShapeFn); + .SetShapeFn(shape_inference::DatasetIteratorShape); REGISTER_OP("IteratorToStringHandle") .Input("resource_handle: resource") @@ -875,7 +852,7 @@ REGISTER_OP("OptionalGetValue") .Output("components: output_types") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(IteratorGetNextShapeFn); + .SetShapeFn(shape_inference::DatasetIteratorShape); REGISTER_OP("IteratorGetNextAsOptional") .Input("iterator: resource") @@ -992,7 +969,7 @@ REGISTER_OP("MultiDeviceIteratorGetNextFromShard") .Output("components: output_types") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(IteratorGetNextShapeFn); + .SetShapeFn(shape_inference::DatasetIteratorShape); REGISTER_OP("MultiDeviceIteratorToStringHandle") .Input("multi_device_iterator: resource") diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index 2c9cbe2f416..aa4bd64270a 100644 --- a/tensorflow/core/ops/experimental_dataset_ops.cc +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -132,6 +132,19 @@ REGISTER_OP("ExperimentalChooseFastestDataset") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("CompressElement") + .Input("components: input_types") + .Output("compressed: variant") + .Attr("input_types: list(type) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("UncompressElement") + .Input("compressed: variant") + .Output("components: output_types") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::DatasetIteratorShape); + REGISTER_OP("CSVDataset") .Input("filenames: string") .Input("compression_type: string") diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 418f1e20e37..e11f14b8538 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -756,6 +756,41 @@ REGISTER_OP("ExtractGlimpse") c->Dim(input, 3)); }); +REGISTER_OP("ExtractGlimpseV2") + .Input("input: float") + .Input("size: int32") + .Input("offsets: float") + .Output("glimpse: float") + .Attr("centered: bool = true") + .Attr("normalized: bool = true") + .Attr("uniform_noise: bool = true") + .Attr("noise: string = 'uniform'") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); + ShapeHandle offsets; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets)); + + DimensionHandle batch_dim; + TF_RETURN_IF_ERROR( + c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim)); + DimensionHandle unused; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused)); + + bool uniform_noise = false; + TF_RETURN_IF_ERROR(c->GetAttr("uniform_noise", &uniform_noise)); + string noise; + TF_RETURN_IF_ERROR(c->GetAttr("noise", &noise)); + if (uniform_noise && (!noise.empty() && noise != "uniform")) { + return errors::InvalidArgument( + "The uniform_noise and noise should not be specified at the same " + "time"); + } + + return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */, + c->Dim(input, 3)); + }); + // -------------------------------------------------------------------------- REGISTER_OP("CropAndResize") diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index cbf03d7b045..972d6e27b75 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -936,7 +936,7 @@ REGISTER_OP("_MklMatMul") .Output("product: T") .Attr("transpose_a: bool = false") .Attr("transpose_b: bool = false") - .Attr("T: {bfloat16, float, double, complex64, complex128}") + .Attr("T: {bfloat16, float}") .SetShapeFn(shape_inference::MatMulShape); #endif // INTEL_MKL diff --git a/tensorflow/core/ops/mkl_nn_ops.cc b/tensorflow/core/ops/mkl_nn_ops.cc index a625fb64ed3..248cf1d0e8a 100644 --- a/tensorflow/core/ops/mkl_nn_ops.cc +++ b/tensorflow/core/ops/mkl_nn_ops.cc @@ -1369,6 +1369,48 @@ REGISTER_OP("_MklFusedBatchNormGradV3") R"doc(MKL-DNN implementation of FusedBatchNormGradV3: Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke this operator.)doc"); +REGISTER_OP("_MklFusedBatchNormEx") + .Input("x: T") + .Input("scale: U") + .Input("offset: U") + .Input("mean: U") + .Input("variance: U") + .Input("side_input: num_side_inputs * T") + .Input("mkl_x: uint8") + .Input("mkl_scale: uint8") + .Input("mkl_offset: uint8") + .Input("mkl_mean: uint8") + .Input("mkl_variance: uint8") + .Input("mkl_side_input: num_side_inputs * uint8") + .Output("y: T") + .Output("batch_mean: U") + .Output("batch_variance: U") + .Output("reserve_space_1: U") + .Output("reserve_space_2: U") + .Output("reserve_space_3: U") + .Output("mkl_y: uint8") + .Output("mkl_batch_mean: uint8") + .Output("mkl_batch_variance: uint8") + .Output("mkl_reserve_space_1: uint8") + .Output("mkl_reserve_space_2: uint8") + .Output("mkl_reserve_space_3: uint8") + .Attr("T: {bfloat16, float}") + .Attr("U: {float}") + .Attr("epsilon: float = 0.0001") + .Attr("exponential_avg_factor: float = 1.0") + .Attr(GetConvnetDataFormatAttrString()) + .Attr("num_side_inputs: int >= 0 = 0") + .Attr("activation_mode: string = \"Identity\"") + .Attr("is_training: bool = true") + .SetShapeFn(shape_inference::FusedBatchNormShape) + .Doc(R"doc( +MKL version of FusedBatchNormEx operator. Uses MKL DNN APIs to perform fused +batch normalization and relu. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 9200547cf45..518972696f1 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -238,7 +238,11 @@ REGISTER_OP("_FusedBatchNormEx") .Output("reserve_space_1: U") .Output("reserve_space_2: U") .Output("reserve_space_3: U") +#ifdef ENABLE_MKLDNN_V1 + .Attr("T: {half, float, bfloat16}") +#else .Attr("T: {half, float}") +#endif .Attr("U: {float}") .Attr("epsilon: float = 0.0001") .Attr("exponential_avg_factor: float = 1.0") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 1ea06a2fdac..e2f2e3d00fa 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -7451,6 +7451,23 @@ op { } } } +op { + name: "CompressElement" + input_arg { + name: "components" + type_list_attr: "input_types" + } + output_arg { + name: "compressed" + type: DT_VARIANT + } + attr { + name: "input_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } +} op { name: "ComputeAccidentalHits" input_arg { @@ -11515,7 +11532,7 @@ op { } } attr { - name: "binary_count" + name: "binary_output" type: "bool" default_value { b: false @@ -11530,7 +11547,7 @@ op { } input_arg { name: "weights" - type: DT_FLOAT + type_attr: "output_type" } output_arg { name: "output_indices" @@ -11573,7 +11590,7 @@ op { minimum: -1 } attr { - name: "binary_count" + name: "binary_output" type: "bool" } attr { @@ -11581,8 +11598,10 @@ op { type: "type" allowed_values { list { + type: DT_INT32 type: DT_INT64 type: DT_FLOAT + type: DT_DOUBLE } } } @@ -15092,6 +15111,53 @@ op { } } } +op { + name: "ExtractGlimpseV2" + input_arg { + name: "input" + type: DT_FLOAT + } + input_arg { + name: "size" + type: DT_INT32 + } + input_arg { + name: "offsets" + type: DT_FLOAT + } + output_arg { + name: "glimpse" + type: DT_FLOAT + } + attr { + name: "centered" + type: "bool" + default_value { + b: true + } + } + attr { + name: "normalized" + type: "bool" + default_value { + b: true + } + } + attr { + name: "uniform_noise" + type: "bool" + default_value { + b: true + } + } + attr { + name: "noise" + type: "string" + default_value { + s: "uniform" + } + } +} op { name: "ExtractImagePatches" input_arg { @@ -33206,7 +33272,7 @@ op { } } attr { - name: "binary_count" + name: "binary_output" type: "bool" default_value { b: false @@ -33225,7 +33291,7 @@ op { } input_arg { name: "weights" - type: DT_FLOAT + type_attr: "output_type" } output_arg { name: "output_indices" @@ -33268,7 +33334,7 @@ op { minimum: -1 } attr { - name: "binary_count" + name: "binary_output" type: "bool" } attr { @@ -33276,8 +33342,10 @@ op { type: "type" allowed_values { list { + type: DT_INT32 type: DT_INT64 type: DT_FLOAT + type: DT_DOUBLE } } } @@ -44717,7 +44785,7 @@ op { } } attr { - name: "binary_count" + name: "binary_output" type: "bool" default_value { b: false @@ -44849,7 +44917,7 @@ op { } input_arg { name: "weights" - type: DT_FLOAT + type_attr: "output_type" } output_arg { name: "output_indices" @@ -44892,7 +44960,7 @@ op { minimum: -1 } attr { - name: "binary_count" + name: "binary_output" type: "bool" } attr { @@ -44900,8 +44968,10 @@ op { type: "type" allowed_values { list { + type: DT_INT32 type: DT_INT64 type: DT_FLOAT + type: DT_DOUBLE } } } @@ -44999,6 +45069,142 @@ op { } } } +op { + name: "SparseCrossHashed" + input_arg { + name: "indices" + type: DT_INT64 + number_attr: "N" + } + input_arg { + name: "values" + type_list_attr: "sparse_types" + } + input_arg { + name: "shapes" + type: DT_INT64 + number_attr: "N" + } + input_arg { + name: "dense_inputs" + type_list_attr: "dense_types" + } + input_arg { + name: "num_buckets" + type: DT_INT64 + } + input_arg { + name: "strong_hash" + type: DT_BOOL + } + input_arg { + name: "salt" + type: DT_INT64 + } + output_arg { + name: "output_indices" + type: DT_INT64 + } + output_arg { + name: "output_values" + type: DT_INT64 + } + output_arg { + name: "output_shape" + type: DT_INT64 + } + attr { + name: "N" + type: "int" + has_minimum: true + } + attr { + name: "sparse_types" + type: "list(type)" + has_minimum: true + allowed_values { + list { + type: DT_INT64 + type: DT_STRING + } + } + } + attr { + name: "dense_types" + type: "list(type)" + has_minimum: true + allowed_values { + list { + type: DT_INT64 + type: DT_STRING + } + } + } +} +op { + name: "SparseCrossV2" + input_arg { + name: "indices" + type: DT_INT64 + number_attr: "N" + } + input_arg { + name: "values" + type_list_attr: "sparse_types" + } + input_arg { + name: "shapes" + type: DT_INT64 + number_attr: "N" + } + input_arg { + name: "dense_inputs" + type_list_attr: "dense_types" + } + input_arg { + name: "sep" + type: DT_STRING + } + output_arg { + name: "output_indices" + type: DT_INT64 + } + output_arg { + name: "output_values" + type: DT_STRING + } + output_arg { + name: "output_shape" + type: DT_INT64 + } + attr { + name: "N" + type: "int" + has_minimum: true + } + attr { + name: "sparse_types" + type: "list(type)" + has_minimum: true + allowed_values { + list { + type: DT_INT64 + type: DT_STRING + } + } + } + attr { + name: "dense_types" + type: "list(type)" + has_minimum: true + allowed_values { + list { + type: DT_INT64 + type: DT_STRING + } + } + } +} op { name: "SparseDenseCwiseAdd" input_arg { @@ -49723,6 +49929,13 @@ op { i: -1 } } + attr { + name: "is_packed" + type: "bool" + default_value { + b: false + } + } } op { name: "TPUReplicatedOutput" @@ -52520,6 +52733,29 @@ op { type: "type" } } +op { + name: "UncompressElement" + input_arg { + name: "compressed" + type: DT_VARIANT + } + output_arg { + name: "components" + type_list_attr: "output_types" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "UnicodeDecode" input_arg { diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc index 85186c4a2d8..906cef1f5ec 100644 --- a/tensorflow/core/ops/sparse_ops.cc +++ b/tensorflow/core/ops/sparse_ops.cc @@ -272,6 +272,46 @@ REGISTER_OP("SparseCross") return Status::OK(); }); +REGISTER_OP("SparseCrossV2") + .Input("indices: N * int64") + .Input("values: sparse_types") + .Input("shapes: N * int64") + .Input("dense_inputs: dense_types") + .Input("sep: string") + .Output("output_indices: int64") + .Output("output_values: string") + .Output("output_shape: int64") + .Attr("N: int >= 0") + .Attr("sparse_types: list({int64, string}) >= 0") + .Attr("dense_types: list({int64, string}) >= 0") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Matrix(c->UnknownDim(), 2)); + c->set_output(1, c->Vector(c->UnknownDim())); + c->set_output(2, c->Vector(2)); + return Status::OK(); + }); + +REGISTER_OP("SparseCrossHashed") + .Input("indices: N * int64") + .Input("values: sparse_types") + .Input("shapes: N * int64") + .Input("dense_inputs: dense_types") + .Input("num_buckets: int64") + .Input("strong_hash: bool") + .Input("salt: int64") + .Output("output_indices: int64") + .Output("output_values: int64") + .Output("output_shape: int64") + .Attr("N: int >= 0") + .Attr("sparse_types: list({int64, string}) >= 0") + .Attr("dense_types: list({int64, string}) >= 0") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Matrix(c->UnknownDim(), 2)); + c->set_output(1, c->Vector(c->UnknownDim())); + c->set_output(2, c->Vector(2)); + return Status::OK(); + }); + REGISTER_OP("SparseSplit") .Input("split_dim: int64") .Input("indices: int64") diff --git a/tensorflow/core/ops/tpu_replication_ops.cc b/tensorflow/core/ops/tpu_replication_ops.cc index 3bb94044e14..a729d3c3b7b 100644 --- a/tensorflow/core/ops/tpu_replication_ops.cc +++ b/tensorflow/core/ops/tpu_replication_ops.cc @@ -44,6 +44,8 @@ REGISTER_OP("TPUReplicatedInput") .Attr("is_mirrored_variable: bool = false") // Index of the input. If is_mirrored_variable is true, this is ignored. .Attr("index: int = -1") + // All inputs are packed into one input + .Attr("is_packed: bool = false") .SetShapeFn([](InferenceContext* c) { ShapeHandle cur = c->input(c->num_inputs() - 1); for (int i = c->num_inputs() - 2; i >= 0; --i) { diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index f78b738247d..7f7ca0f06cd 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -386,6 +386,7 @@ py_test( name = "ram_file_system_test", srcs = ["ram_file_system_test.py"], python_version = "PY3", + tags = ["no_windows"], # TODO(b/156428279): reenable this test once the image is updated. deps = [ "//tensorflow:tensorflow_py", ], diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD index 101d7ac5807..2440549a353 100644 --- a/tensorflow/core/platform/cloud/BUILD +++ b/tensorflow/core/platform/cloud/BUILD @@ -20,6 +20,7 @@ package_group( packages = [ "//learning/brain/tfrc/...", "//tensorflow/...", + "//third_party/gstpufs/...", ], ) diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index e4047c78998..92210498b01 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -158,12 +158,17 @@ string JoinGcsPath(const string& path, const string& subpath) { /// For example: /// - for 'a/b/c/d' it will append 'a', 'a/b' and 'a/b/c' /// - for 'a/b/c/' it will append 'a', 'a/b' and 'a/b/c' +/// - for 'a//b/c/' it will append 'a', 'a//b' and 'a//b/c' +/// - for '/a/b/c/' it will append '/a', '/a/b' and '/a/b/c' std::set<string> AddAllSubpaths(const std::vector<string>& paths) { std::set<string> result; result.insert(paths.begin(), paths.end()); for (const string& path : paths) { StringPiece subpath = io::Dirname(path); - while (!subpath.empty()) { + // If `path` starts with `/`, `subpath` will be `/` and then we get into an + // infinite loop. Same behavior happens if there is a `//` pattern in + // `path`, so we check for that and leave the loop quicker. + while (!(subpath.empty() || subpath == "/")) { result.emplace(string(subpath)); subpath = io::Dirname(subpath); } @@ -1349,9 +1354,19 @@ Status GcsFileSystem::GetMatchingPaths(const string& pattern, const auto& files_and_folders = AddAllSubpaths(all_files); + // To handle `/` in the object names, we need to remove it from `dir` + // and then use `StrCat` to insert it back. + const StringPiece dir_no_slash = str_util::StripSuffix(dir, "/"); + // Match all obtained paths to the input pattern. for (const auto& path : files_and_folders) { - const string& full_path = this->JoinPath(dir, path); + // Manually construct the path instead of using `JoinPath` for the + // cases where `path` starts with a `/` (which is a valid character in + // the filenames of GCS objects). `JoinPath` canonicalizes the result, + // removing duplicate slashes. We know that `dir_no_slash` does not + // end in `/`, so we are safe inserting the new `/` here as the path + // separator. + const string full_path = strings::StrCat(dir_no_slash, "/", path); if (this->Match(full_path, pattern)) { results->push_back(full_path); } diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc index 802f18a31ae..14af9f979e6 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -1969,6 +1969,56 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SelfDirectoryMarker) { EXPECT_EQ(std::vector<string>({"gs://bucket/path/file3.txt"}), result); } +TEST(GcsFileSystemTest, GetMatchingPaths_SlashInObjectName) { + std::vector<HttpRequest*> requests({new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" + "Auth Token: fake_token\n" + "Timeouts: 5 1 10\n", + "{\"items\": [ " + " { \"name\": \"path/\" }," + " { \"name\": \"path//foo.txt\" }]}")}); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); + + std::vector<string> result; + TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*", &result)); + EXPECT_EQ(std::vector<string>(), result); +} + +TEST(GcsFileSystemTest, GetMatchingPaths_SlashInObjectNameEscaped) { + std::vector<HttpRequest*> requests({new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "fields=items%2Fname%2CnextPageToken&prefix=path%2F\n" + "Auth Token: fake_token\n" + "Timeouts: 5 1 10\n", + "{\"items\": [ " + " { \"name\": \"path/\" }," + " { \"name\": \"path//foo.txt\" }]}")}); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); + + std::vector<string> result; + TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/\\/*", &result)); + EXPECT_EQ(std::vector<string>({"gs://bucket/path//foo.txt"}), result); +} + TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) { std::vector<HttpRequest*> requests({new FakeHttpRequest( "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" diff --git a/tensorflow/core/platform/default/BUILD b/tensorflow/core/platform/default/BUILD index 49318fd0811..89231b0f206 100644 --- a/tensorflow/core/platform/default/BUILD +++ b/tensorflow/core/platform/default/BUILD @@ -509,6 +509,7 @@ filegroup( filegroup( name = "mobile_srcs_no_runtime", srcs = [ + "casts.h", "context.h", "dynamic_annotations.h", "env.cc", diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 369d26a92d9..390f94157c3 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -242,6 +242,7 @@ cc_library( "//tensorflow/core/profiler/utils:xplane_utils", "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc index 83673458d21..89b4939f5d0 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc @@ -752,5 +752,17 @@ std::string GetSummaryNextStep(absl::string_view input_classification, return summary_next_step; } +double HostToDeviceTransferAsPercentOfInputTime( + const InputTimeBreakdown& breakdown) { + // Thanks to the scaling trick we did in GenerateHostResult(), we can + // estimate the percentage of input-time spent on host-to-device transfer in + // the following way. + double total_input_time_us = + breakdown.demanded_file_read_us() + breakdown.advanced_file_read_us() + + breakdown.preprocessing_us() + breakdown.enqueue_us() + + breakdown.unclassified_non_enqueue_us(); + return 100.0 * SafeDivide(breakdown.enqueue_us(), total_input_time_us); +} + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h index 93b4df0b2c2..2191251ee88 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h +++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h @@ -31,6 +31,17 @@ limitations under the License. namespace tensorflow { namespace profiler { +// If the percent of input-time spent on host-to-device transfer is greater than +// kHostToDeviceTimePercentAsSignificant, we should advise the +// user to optimize this transfer. +constexpr double kHostToDeviceTimePercentAsSignificant = 10.0; + +// If the percent of input-time spent on host-to-device transfer is greater than +// kHostToDeviceTimePercentAsDominant, we should ONLY advise the +// user to optimize this transfer; we won't bother to suggest optimization for +// tf.data. +constexpr double kHostToDeviceTimePercentAsDominant = 90.0; + // Computes the summary of step time in milliseconds. StepSummary ComputeStepTimeSummaryInMs( const ::tensorflow::protobuf::RepeatedPtrField<PerCoreStepInfo>& @@ -62,6 +73,11 @@ void OutputAnalysis(double output_percent, std::string* output_classification, string GetSummaryNextStep(absl::string_view input_classification, const InputTimeBreakdown& breakdown); +// Returns the percentage of the input time that is spent on transferring the +// data from host to device. +double HostToDeviceTransferAsPercentOfInputTime( + const InputTimeBreakdown& breakdown); + void AddErrorMessages(const OpStats& op_stats, InputPipelineAnalysisResult* result); diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc index bec92e0d998..62f37c50155 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc @@ -97,6 +97,9 @@ void ComputeFaqTips(OverviewPageRecommendation* re) { } void ComputeDocumentationTips(OverviewPageRecommendation* re) { + *re->add_documentation_tips() = MakeOverviewPageTipDocLink( + "https://www.tensorflow.org/guide/data_performance_analysis", + "Analyze tf.data performance with the TF Profiler"); *re->add_documentation_tips() = MakeOverviewPageTipDocLink( "https://www.tensorflow.org/guide/" "data_performance", @@ -294,6 +297,7 @@ OverviewPage ConvertOpStatsToOverviewPage(const OpStats& op_stats, bottleneck.input_classification(), bottleneck.input_statement(), "", hardware_type, TfFunctionRecommendationHtml(op_stats.tf_function_db()), overview_page.mutable_recommendation()); + SetOverviewPageErrorMessage(op_stats, &overview_page); return overview_page; } @@ -310,5 +314,18 @@ void SetRemarks(const OpStats& op_stats, OverviewPageAnalysis* analysis) { } } +void SetOverviewPageErrorMessage(const OpStats& op_stats, + OverviewPage* overview_page) { + *overview_page->mutable_errors() = op_stats.errors(); + absl::c_sort(*overview_page->mutable_errors()); + if (overview_page->errors().empty()) { + // Shows run-environment error only if there is no other existing error. + if (op_stats.run_environment().device_type() != "CPU" && + op_stats.run_environment().device_core_count() <= 0) { + *overview_page->add_errors() = std::string(kNoDeviceTraceCollected); + } + } +} + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.h b/tensorflow/core/profiler/convert/op_stats_to_overview_page.h index b4b3991a18d..d4d75c03454 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.h +++ b/tensorflow/core/profiler/convert/op_stats_to_overview_page.h @@ -48,6 +48,9 @@ OverviewPageAnalysis ComputeAnalysisResult(const OpStats& op_stats); OverviewPageRunEnvironment ComputeRunEnvironment( const RunEnvironment& run_environment); +void SetOverviewPageErrorMessage(const OpStats& op_stats, + OverviewPage* overview_page); + OverviewPage ConvertOpStatsToOverviewPage(const OpStats& op_stats, HardwareType hardware_type); diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc index f008219cbd2..4d2a45747e0 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc @@ -18,6 +18,7 @@ limitations under the License. #include <vector> #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" #include "tensorflow/core/profiler/convert/step_events_to_steps_db.h" @@ -109,12 +110,20 @@ void ProcessHostPlane(const XPlane* host_plane, bool use_device_step_events, } // namespace +void PropagateXSpaceErrorsToOpStats(const XSpace& space, OpStats* op_stats) { + if (space.errors().empty()) return; + absl::flat_hash_set<std::string> unique_errors; + unique_errors.insert(space.errors().begin(), space.errors().end()); + *op_stats->mutable_errors() = {unique_errors.begin(), unique_errors.end()}; +} + OpStats ConvertXSpaceToOpStats(const XSpace& space) { const XPlane* host_plane = FindPlaneWithName(space, kHostThreads); std::vector<const XPlane*> device_planes = FindPlanesWithPrefix(space, kGpuPlanePrefix); OpStats op_stats; StepEvents step_events; + PropagateXSpaceErrorsToOpStats(space, &op_stats); // Convert device planes. OpMetricsDbCombiner op_metrics_db_combiner( op_stats.mutable_device_op_metrics_db()); diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.h b/tensorflow/core/profiler/convert/xplane_to_op_stats.h index 2d30a5d5fad..4708caa5aae 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.h +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.h @@ -25,6 +25,9 @@ namespace profiler { // NOTE: call GroupTfEvents before if OpStats.step_db needs to be generated. OpStats ConvertXSpaceToOpStats(const XSpace& space); +// Propagate and dedup the errors in XSpace and add to OpStats. +void PropagateXSpaceErrorsToOpStats(const XSpace& space, OpStats* op_stats); + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc index 7b4652f6c0b..67901e83dd3 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc @@ -185,6 +185,18 @@ TEST(ConcertXPlaneToOpStats, TfFunctionTest) { EXPECT_EQ(not_traced_mode.self_time_ps(), 20); } +TEST(ConvertXPlaneToOpStats, PropagateAndDedupErrors) { + XSpace space; + static constexpr char kError[] = "host: error"; + *space.add_errors() = kError; + *space.add_errors() = kError; + + OpStats op_stats = ConvertXSpaceToOpStats(space); + + EXPECT_EQ(1, op_stats.errors_size()); + EXPECT_EQ(kError, op_stats.errors(/*index=*/0)); +} + } // namespace } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_profile_response.cc b/tensorflow/core/profiler/convert/xplane_to_profile_response.cc index e6fe74942fc..70a07171310 100644 --- a/tensorflow/core/profiler/convert/xplane_to_profile_response.cc +++ b/tensorflow/core/profiler/convert/xplane_to_profile_response.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/profiler/convert/xplane_to_profile_response.h" +#include <string> + #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -78,14 +80,14 @@ Status ConvertProtoToJson(const Proto& proto_output, std::string* json_output) { // tensorflow::StringPiece. auto error_msg = status.message(); return errors::Internal( - strings::StrCat("Could not convert proto to JSON string: ", - StringPiece(error_msg.data(), error_msg.length()))); + "Could not convert proto to JSON string: ", + absl::string_view(error_msg.data(), error_msg.length())); } return Status::OK(); } // Returns the tool name with extension. -string ToolName(absl::string_view tool) { +std::string ToolName(absl::string_view tool) { if (tool == kTraceViewer) return "trace.json.gz"; if (tool == kMemoryProfile) return "memory_profile.json.gz"; return absl::StrCat(tool, ".pb"); diff --git a/tensorflow/core/profiler/internal/gpu/BUILD b/tensorflow/core/profiler/internal/gpu/BUILD index e6ee8514227..c6fe4d77031 100644 --- a/tensorflow/core/profiler/internal/gpu/BUILD +++ b/tensorflow/core/profiler/internal/gpu/BUILD @@ -55,7 +55,6 @@ tf_cc_test_gpu( linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + [ "nomac", - "notap", # b/154510273 "gpu_cupti", ], deps = [ diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc index 9119c3d5d0b..ab16693deae 100644 --- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc +++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc @@ -16,9 +16,11 @@ limitations under the License. #include "tensorflow/core/profiler/internal/gpu/cupti_tracer.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_map.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mem.h" @@ -614,15 +616,42 @@ class CuptiDriverApiHookWithActivityApi : public CuptiDriverApiHook { // Grab timestamp for API exit. API entry timestamp saved in cbdata. uint64 end_tsc = CuptiTracer::GetTimestamp(); uint64 start_tsc = *cbdata->correlationData; + TrackContext(cbid, cbdata->context); return AddDriverApiCallbackEvent(collector_, cupti_interface_, device_id, start_tsc, end_tsc, domain, cbid, cbdata); } - Status Flush() override { return Status::OK(); } + Status SyncAndFlush() override { + if (option_.sync_devices_before_stop) { + CuptiApiTracingDisabler disabler; + absl::MutexLock lock(&mutex_); + for (auto &ctx : contexts_) { + cuCtxPushCurrent(ctx); + cuCtxSynchronize(); // Ignore error here for best effort. + CUcontext current; + cuCtxPopCurrent(¤t); + } + } + return Status::OK(); + } private: + void TrackContext(CUpti_CallbackId cbid, CUcontext ctx) { + if (!option_.sync_devices_before_stop) return; + if (ctx == NULL) return; + absl::MutexLock lock(&mutex_); + if (cbid == CUPTI_DRIVER_TRACE_CBID_cuCtxDestroy_v2 || + cbid == CUPTI_DRIVER_TRACE_CBID_cuCtxDestroy) { + contexts_.erase(ctx); + } else { + contexts_.emplace(ctx); + } + } + const CuptiTracerOptions option_; CuptiInterface *cupti_interface_; CuptiTraceCollector *collector_; + absl::Mutex mutex_; + absl::flat_hash_set<CUcontext> contexts_ TF_GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(CuptiDriverApiHookWithActivityApi); }; @@ -1158,7 +1187,7 @@ class CuptiDriverApiHookWithCudaEvent : public CuptiDriverApiHook { return AddDriverApiCallbackEvent(collector_, cupti_interface_, device_id, start_tsc, end_tsc, domain, cbid, cbdata); } - Status Flush() override { + Status SyncAndFlush() override { for (auto &recorder : cuda_event_recorders_) { TF_RETURN_IF_ERROR(recorder->Stop()); } @@ -1236,6 +1265,11 @@ class CuptiDriverApiHookWithCudaEvent : public CuptiDriverApiHook { std::vector<std::unique_ptr<CudaEventRecorder>> cuda_event_recorders_; TF_DISALLOW_COPY_AND_ASSIGN(CuptiDriverApiHookWithCudaEvent); }; + +/*static*/ std::string ErrorWithHostname(absl::string_view error_message) { + return absl::StrCat(port::Hostname(), ": ", error_message); +} + } // namespace /*static*/ Status CuptiDriverApiHook::AddDriverApiCallbackEvent( @@ -1397,7 +1431,7 @@ void CuptiTracer::Disable() { } cupti_interface_->CleanUp(); Finalize().IgnoreError(); - cupti_driver_api_hook_->Flush().IgnoreError(); + cupti_driver_api_hook_->SyncAndFlush().IgnoreError(); collector_->Flush(); collector_ = nullptr; option_.reset(); @@ -1641,11 +1675,13 @@ Status CuptiTracer::ProcessActivityBuffer(CUcontext context, uint32_t stream_id, /*static*/ std::string CuptiTracer::ErrorIfAny() { if (CuptiTracer::NumGpus() == 0) { - return "No GPU detected."; + return ErrorWithHostname("No GPU detected."); } else if (CuptiTracer::GetCuptiTracerSingleton()->NeedRootAccess()) { - return "Insufficient privilege to run libcupti (you need root permission)."; + return ErrorWithHostname( + "Insufficient privilege to run libcupti (you need root permission)."); } else if (CuptiTracer::GetTimestamp() == 0) { - return "Failed to load libcupti (is it installed and accessible?)"; + return ErrorWithHostname( + "Failed to load libcupti (is it installed and accessible?)"); } return ""; } diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.h b/tensorflow/core/profiler/internal/gpu/cupti_tracer.h index e236afc5c41..a62c08013e8 100644 --- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.h +++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.h @@ -147,6 +147,8 @@ struct CuptiTracerOptions { std::vector<CUpti_ActivityKind> activities_selected; // Whether to call cuptiFinalize. bool cupti_finalize = false; + // Whether to call cuCtxSynchronize for each device before Stop(). + bool sync_devices_before_stop = false; }; struct CuptiTracerCollectorOptions { @@ -219,7 +221,7 @@ class CuptiDriverApiHook { virtual Status OnDriverApiExit(int device_id, CUpti_CallbackDomain domain, CUpti_CallbackId cbid, const CUpti_CallbackData* callback_info) = 0; - virtual Status Flush() = 0; + virtual Status SyncAndFlush() = 0; protected: static Status AddDriverApiCallbackEvent( diff --git a/tensorflow/core/profiler/internal/gpu/device_tracer.cc b/tensorflow/core/profiler/internal/gpu/device_tracer.cc index ac6662c8432..0370f6a51f9 100644 --- a/tensorflow/core/profiler/internal/gpu/device_tracer.cc +++ b/tensorflow/core/profiler/internal/gpu/device_tracer.cc @@ -659,12 +659,16 @@ Status GpuTracer::CollectData(XSpace* space) { case State::kStartedOk: return errors::FailedPrecondition("Cannot collect trace before stopping"); case State::kStartedError: - LOG(ERROR) << "Cannot collect, xprof failed to start"; + LOG(ERROR) << "Cannot collect, profiler failed to start"; return Status::OK(); case State::kStoppedError: VLOG(1) << "No trace data collected"; return Status::OK(); case State::kStoppedOk: { + std::string cupti_error = CuptiTracer::ErrorIfAny(); + if (!cupti_error.empty()) { + space->add_errors(cupti_error); + } if (cupti_collector_) { cupti_collector_->Export(space); } diff --git a/tensorflow/core/profiler/internal/traceme_recorder.h b/tensorflow/core/profiler/internal/traceme_recorder.h index 1da7d4cebb1..5fdea5bddbd 100644 --- a/tensorflow/core/profiler/internal/traceme_recorder.h +++ b/tensorflow/core/profiler/internal/traceme_recorder.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_CORE_PROFILER_INTERNAL_TRACEME_RECORDER_H_ #include <atomic> +#include <string> #include <vector> #include "absl/container/flat_hash_map.h" @@ -52,13 +53,13 @@ class TraceMeRecorder { // Times are in ns since the Unix epoch. struct Event { uint64 activity_id; - string name; + std::string name; uint64 start_time; // 0 = missing uint64 end_time; // 0 = missing }; struct ThreadInfo { uint32 tid; - string name; + std::string name; }; struct ThreadEvents { ThreadInfo thread; diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD index 6316fd118fc..e80b9fc9766 100644 --- a/tensorflow/core/profiler/lib/BUILD +++ b/tensorflow/core/profiler/lib/BUILD @@ -94,6 +94,7 @@ cc_library( hdrs = ["traceme.h"], visibility = ["//visibility:public"], deps = [ + ":traceme_encode", "@com_google_absl//absl/strings", "//tensorflow/core:lib", "//tensorflow/core/platform", @@ -102,6 +103,16 @@ cc_library( ]), ) +cc_library( + name = "traceme_encode", + hdrs = ["traceme_encode.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "annotated_traceme", hdrs = ["annotated_traceme.h"], @@ -115,10 +126,17 @@ cc_library( ], ) -tf_pybind_cc_library_wrapper( - name = "scoped_annotation_headers", - visibility = ["//tensorflow/python/profiler/internal:__pkg__"], - deps = [":scoped_annotation"], +cc_library( + name = "connected_traceme", + hdrs = ["connected_traceme.h"], + visibility = ["//visibility:public"], + deps = [ + ":traceme", + ":traceme_encode", + "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], ) cc_library( @@ -149,6 +167,7 @@ filegroup( "profiler_session.h", "scoped_annotation.h", "traceme.h", + "traceme_encode.h", ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/core/profiler/lib/connected_traceme.h b/tensorflow/core/profiler/lib/connected_traceme.h new file mode 100644 index 00000000000..5b16e2e3adf --- /dev/null +++ b/tensorflow/core/profiler/lib/connected_traceme.h @@ -0,0 +1,123 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_CONNECTED_TRACEME_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_CONNECTED_TRACEME_H_ + +#include <string> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/profiler/lib/traceme_encode.h" + +namespace tensorflow { +namespace profiler { + +enum class ContextType : int { + kGeneric, + kTfExecutor, +}; + +/* + * TraceMeProducer and TraceMeConsumer are used to correlate TraceMe events on + * different threads. TraceMeProducer generates the context information to be + * passed to TraceMeConsumer, which consists of the context id and optionally + * the context type. They may be provided by the user. Then, the events of the + * same context information can be correlated during the analysis. + * + * Example Usages: + * (1) Using the user-provided context type and id. The user is responsible for + * providing the same context type and id to TraceMeProducer and + * TraceMeConsumer. + * [Producer Thread] + * // user_context_id is provided by the user. + * TraceMeProducer producer( + * [&] { return TraceMeEncode("op_dispatch", {{"op_type", "matmul"}}); }, + * ContextType::kTfExecutor, user_context_id); + * [Consumer Thread] + * // user_context_id is provided by the user. + * TraceMeConsumer consumer( + * [&] { return "op_execute"; }, user_context_id, ContextType::kTfExecutor); + * + * (2) Using the user-provided context type and generic id. The user is + * responsible for passing the TraceMeProducer's context id to + * TraceMeConsumer as well as providing the same context type to + * TraceMeProducer and TraceMeConsumer. + * [Producer Thread] + * TraceMeProducer producer( + * [&] { return TraceMeEncode("op_dispatch", {{"op_type", "matmul"}}); }, + * ContextType::kTfExecutor); + * context_id = producer.GetContextId(); + * // Pass context_id to the consumer thread. + * [Consumer Thread] + * // context_id is passed from the producer thread. + * TraceMeConsumer consumer( + * [&] { return "op_execute"; }, context_id, ContextType::kTfExecutor); + * + * (3) Using the generic context information. The user is responsible for + * passing the TraceMeProducer's context id to TraceMeConsumer. + * [Producer Thread] + * TraceMeProducer producer( + * [&] { return TraceMeEncode("op_dispatch", {{"op_type", "matmul"}}); }); + * context_id = producer.GetContextId(); + * // Pass context_id to the consumer thread. + * [Consumer Thread] + * // context_id is passed from the producer thread. + * TraceMeConsumer consumer([&] { return "op_execute"; }, context_id); + */ +class TraceMeProducer { + public: + template <typename NameT> + explicit TraceMeProducer(NameT name, + ContextType context_type = ContextType::kGeneric, + absl::optional<uint64> context_id = absl::nullopt, + int level = 2) + : trace_me_(name, level) { + trace_me_.AppendMetadata([&] { + context_id_ = + context_id.has_value() ? *context_id : TraceMe::NewActivityId(); + return TraceMeEncode( + {{"$pt", static_cast<int>(context_type)}, {"$p", context_id_}}); + }); + } + + uint64 GetContextId() const { return context_id_; } + + private: + TraceMe trace_me_; + uint64 context_id_ = 0; +}; + +class TraceMeConsumer { + public: + template <typename NameT> + TraceMeConsumer(NameT name, uint64 context_id, + ContextType context_type = ContextType::kGeneric, + int level = 2) + : trace_me_(name, level) { + trace_me_.AppendMetadata([&] { + return TraceMeEncode( + {{"$ct", static_cast<int>(context_type)}, {"$c", context_id}}); + }); + } + + private: + TraceMe trace_me_; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_LIB_CONNECTED_TRACEME_H_ diff --git a/tensorflow/core/profiler/lib/traceme.h b/tensorflow/core/profiler/lib/traceme.h index af93ac11b1e..6df196bdba7 100644 --- a/tensorflow/core/profiler/lib/traceme.h +++ b/tensorflow/core/profiler/lib/traceme.h @@ -16,12 +16,10 @@ limitations under the License. #define TENSORFLOW_CORE_PROFILER_LIB_TRACEME_H_ #include <new> +#include <string> #include <utility> -#include "absl/strings/match.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/strings/strip.h" #include "tensorflow/core/platform/env_time.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -30,6 +28,7 @@ limitations under the License. #if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/profiler/internal/traceme_recorder.h" #endif +#include "tensorflow/core/profiler/lib/traceme_encode.h" // IWYU pragma: export namespace tensorflow { namespace profiler { @@ -78,20 +77,21 @@ inline int GetTFTraceMeLevel(bool is_expensive) { // auto id = ActivityStart("step"); // ... do some work ... // ActivityEnd(id); +// The two static methods should be called within the same thread. class TraceMe { public: - // Constructor that traces a user-defined activity labeled with activity_name + // Constructor that traces a user-defined activity labeled with name // in the UI. Level defines the trace priority, used for filtering TraceMe // events. By default, traces with TraceMe level <= 2 are recorded. Levels: // - Must be a positive integer. // - Can be a value in enum TraceMeLevel. // Users are welcome to use level > 3 in their code, if they wish to filter // out their host traces based on verbosity. - explicit TraceMe(absl::string_view activity_name, int level = 1) { + explicit TraceMe(absl::string_view name, int level = 1) { DCHECK_GE(level, 1); #if !defined(IS_MOBILE_PLATFORM) if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) { - new (&no_init_.name) string(activity_name); + new (&no_init_.name) std::string(name); start_time_ = EnvTime::NowNanos(); } #endif @@ -102,46 +102,55 @@ class TraceMe { // Note: We can't take the string by value because a) it would make the // overloads ambiguous, and b) we want lvalue strings to use the string_view // constructor so we avoid copying them when tracing is disabled. - explicit TraceMe(string &&activity_name, int level = 1) { + explicit TraceMe(std::string&& name, int level = 1) { DCHECK_GE(level, 1); #if !defined(IS_MOBILE_PLATFORM) if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) { - new (&no_init_.name) string(std::move(activity_name)); + new (&no_init_.name) std::string(std::move(name)); start_time_ = EnvTime::NowNanos(); } #endif } // Do not allow passing strings by reference or value since the caller - // may unintentionally maintain ownership of the activity_name. - // Explicitly std::move the activity_name or wrap it in a string_view if + // may unintentionally maintain ownership of the name. + // Explicitly std::move the name or wrap it in a string_view if // you really wish to maintain ownership. - explicit TraceMe(const string &activity_name, int level = 1) = delete; + explicit TraceMe(const std::string& name, int level = 1) = delete; // This overload is necessary to make TraceMe's with string literals work. // Otherwise, the string&& and the string_view constructor would be equally // good overload candidates. - explicit TraceMe(const char *raw, int level = 1) + explicit TraceMe(const char* raw, int level = 1) : TraceMe(absl::string_view(raw), level) {} - // This overload only generates the activity name if tracing is enabled. - // Useful for avoiding things like string concatenation when tracing is - // disabled. The |name_generator| may be a lambda or functor that returns a - // type that the string() constructor can take. + // This overload only generates the name (and possibly metadata) if tracing is + // enabled. Useful for avoiding expensive operations (e.g., string + // concatenation) when tracing is disabled. + // name_generator may be a lambda or functor that returns a type that the + // string() constructor can take, e.g., the result of TraceMeEncode. // name_generator is templated, rather than a std::function to avoid // allocations std::function might make even if never called. - // Usage: profiler::TraceMe([&]{ return StrCat(prefix, ":", postfix); }); + // Example Usage: + // TraceMe op_trace_me([&]() { + // return StrCat(op_name, ":", op_type); + // } + // TraceMe trace_me_with_metadata([&value1]() { + // return TraceMeEncode("my_trace", {{"key1", value1}, {"key2", 42}}); + // }); template <typename NameGeneratorT> explicit TraceMe(NameGeneratorT name_generator, int level = 1) { DCHECK_GE(level, 1); #if !defined(IS_MOBILE_PLATFORM) if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) { - new (&no_init_.name) string(name_generator()); + new (&no_init_.name) std::string(name_generator()); start_time_ = EnvTime::NowNanos(); } #endif } + ~TraceMe() { Stop(); } + // Stop tracing the activity. Called by the destructor, but exposed to allow // stopping tracing before the object goes out of scope. Only has an effect // the first time it is called. @@ -166,28 +175,27 @@ class TraceMe { #endif } - // Sets new_metadata in the metadata part of no_init_.name. - void SetMetadata(absl::string_view new_metadata) { + // Appends new_metadata to the TraceMe name passed to the constructor. + // metadata_generator may be a lambda or functor that returns a type that the + // string() constructor can take, e.g., the result of TraceMeEncode. + // metadata_generator is only evaluated when tracing is enabled. + // metadata_generator is templated, rather than a std::function to avoid + // allocations std::function might make even if never called. + // Example Usage: + // trace_me.AppendMetadata([&value1]() { + // return TraceMeEncode({{"key1", value1}, {"key2", 42}}); + // }); + template <typename MetadataGeneratorT> + void AppendMetadata(MetadataGeneratorT metadata_generator) { #if !defined(IS_MOBILE_PLATFORM) if (TF_PREDICT_FALSE(start_time_ != kUntracedActivity)) { if (TF_PREDICT_TRUE(TraceMeRecorder::Active())) { - absl::string_view orig = no_init_.name; - if (absl::EndsWith(orig, "#")) { - // orig does have metadata. - absl::ConsumeSuffix(&orig, "#"); - absl::ConsumePrefix(&new_metadata, "#"); - no_init_.name = absl::StrCat(orig, ",", new_metadata); - } else { - // orig does not have metadata. - absl::StrAppend(&no_init_.name, new_metadata); - } + traceme_internal::AppendMetadata(&no_init_.name, metadata_generator()); } } #endif } - ~TraceMe() { Stop(); } - // Static API, for use when scoped objects are inconvenient. // Record the start time of an activity. @@ -196,7 +204,7 @@ class TraceMe { #if !defined(IS_MOBILE_PLATFORM) if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level))) { uint64 activity_id = TraceMeRecorder::NewActivityId(); - TraceMeRecorder::Record({activity_id, string(name), + TraceMeRecorder::Record({activity_id, std::string(name), /*start_time=*/EnvTime::NowNanos(), /*end_time=*/0}); return activity_id; @@ -211,7 +219,8 @@ class TraceMe { // We don't check the level again (see TraceMe::Stop()). if (TF_PREDICT_FALSE(activity_id != kUntracedActivity)) { if (TF_PREDICT_TRUE(TraceMeRecorder::Active())) { - TraceMeRecorder::Record({activity_id, /*name=*/"", /*start_time=*/0, + TraceMeRecorder::Record({activity_id, /*name=*/std::string(), + /*start_time=*/0, /*end_time=*/EnvTime::NowNanos()}); } } @@ -226,6 +235,14 @@ class TraceMe { #endif } + static uint64 NewActivityId() { +#if !defined(IS_MOBILE_PLATFORM) + return TraceMeRecorder::NewActivityId(); +#else + return 0; +#endif + } + private: // Activity ID or start time used when tracing is disabled. constexpr static uint64 kUntracedActivity = 0; @@ -239,7 +256,7 @@ class TraceMe { union NoInit { NoInit() {} ~NoInit() {} - string name; + std::string name; } no_init_; uint64 start_time_ = kUntracedActivity; diff --git a/tensorflow/core/profiler/lib/traceme_encode.h b/tensorflow/core/profiler/lib/traceme_encode.h new file mode 100644 index 00000000000..2e23c6d878b --- /dev/null +++ b/tensorflow/core/profiler/lib/traceme_encode.h @@ -0,0 +1,127 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_TRACEME_ENCODE_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_TRACEME_ENCODE_H_ + +#include <string.h> + +#include <initializer_list> +#include <string> +#include <utility> + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace profiler { +namespace traceme_internal { + +// Copies the contents of str to the address pointed by out. +// Returns the address after the copy. +// REQUIRED: The address range [out, out + str.size()] must have been allocated. +TF_ATTRIBUTE_ALWAYS_INLINE inline char* Append(char* out, + absl::string_view str) { + const size_t str_size = str.size(); + if (TF_PREDICT_TRUE(str_size > 0)) { + memcpy(out, str.data(), str_size); + out += str_size; + } + return out; +} + +// Appends args encoded as TraceMe metadata to name. +TF_ATTRIBUTE_ALWAYS_INLINE inline std::string AppendArgs( + std::string name, + const std::initializer_list<std::pair<absl::string_view, absl::AlphaNum>>& + args) { + if (TF_PREDICT_TRUE(args.size() > 0)) { + const auto old_size = name.size(); + auto new_size = old_size + args.size() * 2 + 1; + for (const auto& arg : args) { + new_size += arg.first.size() + arg.second.size(); + } + name.resize(new_size); + char* const begin = &name[0]; + char* out = begin + old_size; + *out++ = '#'; + for (const auto& arg : args) { + out = Append(out, arg.first); + *out++ = '='; + out = Append(out, arg.second.Piece()); + *out++ = ','; + } + *(out - 1) = '#'; + DCHECK_EQ(out, begin + new_size); + } + return name; +} + +// Appends new_metadata to the metadata part of name. +TF_ATTRIBUTE_ALWAYS_INLINE inline void AppendMetadata( + std::string* name, absl::string_view new_metadata) { + if (!TF_PREDICT_FALSE(new_metadata.empty())) { + if (!name->empty() && name->back() == '#') { // name already has metadata + name->back() = ','; + if (TF_PREDICT_TRUE(new_metadata.front() == '#')) { + new_metadata.remove_prefix(1); + } + } + name->append(new_metadata.data(), new_metadata.size()); + } +} + +} // namespace traceme_internal + +// Encodes an event name and arguments into TraceMe metadata. +// Use within a lambda to avoid expensive operations when tracing is disabled. +// Example Usage: +// TraceMe trace_me([value1]() { +// return TraceMeEncode("my_trace", {{"key1", value1}, {"key2", 42}}); +// }); +inline std::string TraceMeEncode( + std::string name, + std::initializer_list<std::pair<absl::string_view, absl::AlphaNum>> args) { + return traceme_internal::AppendArgs(std::move(name), args); +} +inline std::string TraceMeEncode( + absl::string_view name, + std::initializer_list<std::pair<absl::string_view, absl::AlphaNum>> args) { + return traceme_internal::AppendArgs(std::string(name), args); +} +inline std::string TraceMeEncode( + const char* name, + std::initializer_list<std::pair<absl::string_view, absl::AlphaNum>> args) { + return traceme_internal::AppendArgs(std::string(name), args); +} + +// Encodes arguments into TraceMe metadata. +// Use within a lambda to avoid expensive operations when tracing is disabled. +// Example Usage: +// TraceMe trace_me("my_trace"); +// ... +// trace_me.AppendMetadata([value1]() { +// return TraceMeEncode({{"key1", value1}, {"key2", 42}}); +// }); +inline std::string TraceMeEncode( + std::initializer_list<std::pair<absl::string_view, absl::AlphaNum>> args) { + return traceme_internal::AppendArgs(std::string(), args); +} + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_LIB_TRACEME_ENCODE_H_ diff --git a/tensorflow/core/profiler/protobuf/op_metrics.proto b/tensorflow/core/profiler/protobuf/op_metrics.proto index c0f34773e02..af38795b7b2 100644 --- a/tensorflow/core/profiler/protobuf/op_metrics.proto +++ b/tensorflow/core/profiler/protobuf/op_metrics.proto @@ -26,7 +26,7 @@ message LayoutAnalysis { } // Metrics for an operation (accumulated over all occurrences). -// Next ID: 19 +// Next ID: 20 message OpMetrics { // HLO module id. 0 for TF ops. uint64 hlo_module_id = 13; @@ -50,6 +50,19 @@ message OpMetrics { uint64 flops = 2; // Total bytes accessed. uint64 bytes_accessed = 5; + // Breakdown of memory accessed by operation type and memory space. + message MemoryAccessed { + enum OperationType { + UNKNOWN = 0; + READ = 1; + WRITE = 2; + } + OperationType operation_type = 1; + // Device-specific id of memory space. + uint64 memory_space = 2; + uint64 bytes_accessed = 3; + } + repeated MemoryAccessed memory_accessed_breakdown = 19; // Total dma stall time in picoseconds. uint64 dma_stall_ps = 10; // The data layout for this op. Only set for convolution ops for now. diff --git a/tensorflow/core/profiler/protobuf/overview_page.proto b/tensorflow/core/profiler/protobuf/overview_page.proto index 018aa759cc5..1590076d55f 100644 --- a/tensorflow/core/profiler/protobuf/overview_page.proto +++ b/tensorflow/core/profiler/protobuf/overview_page.proto @@ -81,6 +81,8 @@ message OverviewPageRecommendation { // A statement for input that recommends the next steps for investigating the // bottleneck. string statement = 2; + // A list of tips for tackling input bottleneck. + repeated OverviewPageTip input_tips = 11; // A statement for output that recommends the next steps for investigating the // bottleneck. string output_statement = 9; diff --git a/tensorflow/core/profiler/utils/errors.cc b/tensorflow/core/profiler/utils/errors.cc index 9c678e98a43..1851c624e5c 100644 --- a/tensorflow/core/profiler/utils/errors.cc +++ b/tensorflow/core/profiler/utils/errors.cc @@ -33,5 +33,10 @@ const absl::string_view kErrorNoStepMarker = " than the step time. For (1), you need to add step instrumentation;" " for (2), you may try to profile longer."; +const absl::string_view kNoDeviceTraceCollected = + "No device trace was collected. This might happen if your job hadn't been " + "run on the device when sampling was turned on. You could try the sampling" + " again later."; + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/errors.h b/tensorflow/core/profiler/utils/errors.h index b213fd05c71..2dcb60e6899 100644 --- a/tensorflow/core/profiler/utils/errors.h +++ b/tensorflow/core/profiler/utils/errors.h @@ -28,6 +28,8 @@ ABSL_CONST_INIT extern const absl::string_view kErrorIncompleteStep; // step info. ABSL_CONST_INIT extern const absl::string_view kErrorNoStepMarker; +ABSL_CONST_INIT extern const absl::string_view kNoDeviceTraceCollected; + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/xplane_schema.cc b/tensorflow/core/profiler/utils/xplane_schema.cc index f8ff31b078a..3705a4786fa 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.cc +++ b/tensorflow/core/profiler/utils/xplane_schema.cc @@ -147,6 +147,11 @@ const StatTypeMap& GetStatTypeMap() { {"region_type", kRegionType}, {"data_type", kDataType}, {"shape", kTensorShapes}, + // XPlane semantics related. + {"$pt", kProducerType}, + {"$ct", kConsumerType}, + {"$p", kProducerId}, + {"$c", kConsumerId}, // Device trace arguments. {"device_id", kDeviceId}, {"context_id", kContextId}, @@ -158,6 +163,7 @@ const StatTypeMap& GetStatTypeMap() { {"stream", kStream}, // Stats added when processing traces. {"group_id", kGroupId}, + {"flow", kFlow}, {"step_name", kStepName}, {"level 0", kLevel0}, {"tf_op", kTfOp}, diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h index 31ff90155f5..de8dc32a4f1 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.h +++ b/tensorflow/core/profiler/utils/xplane_schema.h @@ -139,6 +139,11 @@ enum StatType { kRegionType, kDataType, kTensorShapes, + // XPlane semantics related. + kProducerType, + kConsumerType, + kProducerId, + kConsumerId, // Device trace arguments. kDeviceId, kContextId, @@ -150,6 +155,7 @@ enum StatType { kStream, // Stats added when processing traces. kGroupId, + kFlow, kStepName, kLevel0, kTfOp, @@ -204,6 +210,38 @@ inline bool IsInternalStat(absl::optional<int64> stat_type) { stat_type == StatType::kLevel0; } +// Support for flow events: +// This class enables encoding/decoding the flow id and direction, stored as +// XStat value. +class XFlow { + public: + enum FlowDirection { + kFlowUnspecified = 0x0, + kFlowIn = 0x1, + kFlowOut = 0x2, + kFlowInOut = 0x3, + }; + + XFlow(uint64 flow_id, FlowDirection direction) + : encoded_((flow_id << 2) | (direction & 0x3)) { + DCHECK_NE(Direction(), kFlowUnspecified); + } + + // Encoding + uint64 ToStatValue() const { return encoded_; } + + // Decoding + static XFlow FromStatValue(uint64 encoded) { return XFlow(encoded); } + + uint64 Id() const { return (encoded_ >> 2); } + FlowDirection Direction() const { return FlowDirection(encoded_ & 0x3); } + + private: + explicit XFlow(uint64 encoded) : encoded_(encoded) {} + + uint64 encoded_; +}; + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index a534c0cf827..7131d1f7227 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 401 // Updated: 2020/5/14 +#define TF_GRAPH_DEF_VERSION 414 // Updated: 2020/5/27 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index 46a8759a257..5d1b7e1101f 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -68,6 +68,13 @@ cc_library( deps = ["//tensorflow/core:protos_all_cc"], ) +cc_library( + name = "tpu_configuration", + srcs = ["tpu_configuration.cc"], + hdrs = ["tpu_configuration.h"], + deps = ["//tensorflow/core:framework"], +) + cc_library( name = "tpu_init_mode", srcs = ["tpu_init_mode.cc"], @@ -84,3 +91,11 @@ cc_library( "//tensorflow/c:tf_status", ], ) + +cc_library( + name = "tpu_library_loader", + srcs = ["tpu_library_loader.cc"], + hdrs = ["tpu_library_loader.h"], + visibility = ["//tensorflow:__subpackages__"], + deps = ["//tensorflow/core/platform:status"], +) diff --git a/tensorflow/core/tpu/tpu_configuration.cc b/tensorflow/core/tpu/tpu_configuration.cc new file mode 100644 index 00000000000..3788d5cc6c2 --- /dev/null +++ b/tensorflow/core/tpu/tpu_configuration.cc @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/tpu/tpu_configuration.h" + +namespace tensorflow { + +namespace { + +ResourceMgr* GetGlobalResourceMgr() { + static ResourceMgr* const rmgr = new ResourceMgr(); + return rmgr; +} + +} // namespace + +#if !defined(PLATFORM_GOOGLE) +// Used only by Google-internal tests, so deliberately left empty. +void MaybeInitializeTPUSystemForTests() {} +#endif + +ResourceMgr* GetTPUConfigResourceMgr() { + MaybeInitializeTPUSystemForTests(); + + // Put all TPU-related state in the global ResourceMgr. This includes the + // TpuPodState, compilation cache, etc. We don't use the TPU_SYSTEM + // ResourceMgr because there may be more than one TPU_SYSTEM ResourceMgr when + // DirectSession or isolate_session_state are used. + return GetGlobalResourceMgr(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/tpu/tpu_configuration.h b/tensorflow/core/tpu/tpu_configuration.h new file mode 100644 index 00000000000..6c337bd0fe7 --- /dev/null +++ b/tensorflow/core/tpu/tpu_configuration.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_CONFIGURATION_H_ +#define TENSORFLOW_CORE_TPU_TPU_CONFIGURATION_H_ + +#include "tensorflow/core/framework/resource_mgr.h" + +namespace tensorflow { + +void MaybeInitializeTPUSystemForTests(); + +// Returns a process-wide global ResourceMgr. +ResourceMgr* GetTPUConfigResourceMgr(); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_TPU_CONFIGURATION_H_ diff --git a/tensorflow/core/tpu/tpu_library_loader.cc b/tensorflow/core/tpu/tpu_library_loader.cc new file mode 100644 index 00000000000..bfd9fe29efe --- /dev/null +++ b/tensorflow/core/tpu/tpu_library_loader.cc @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/tpu/tpu_library_loader.h" + +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace tpu { + +Status InitializeTPULibrary(void* library) { + // TODO(frankchn): dlsym the loaded library and populate a struct with the + // relevant C APIs necessary for TPUs. + return Status::OK(); +} + +} // namespace tpu +} // namespace tensorflow diff --git a/tensorflow/core/tpu/tpu_library_loader.h b/tensorflow/core/tpu/tpu_library_loader.h new file mode 100644 index 00000000000..35a7dd7c9be --- /dev/null +++ b/tensorflow/core/tpu/tpu_library_loader.h @@ -0,0 +1,29 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_LIBRARY_LOADER_H_ +#define TENSORFLOW_CORE_TPU_TPU_LIBRARY_LOADER_H_ + +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace tpu { + +Status InitializeTPULibrary(void* library); + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_TPU_LIBRARY_LOADER_H_ diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index de2dce9c0c2..8e878c2464d 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -505,6 +505,16 @@ cc_library( ], ) +cc_library( + name = "incremental_barrier", + srcs = ["incremental_barrier.cc"], + hdrs = ["incremental_barrier.h"], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/functional:bind_front", + ], +) + # Tests. tf_cc_test( @@ -632,6 +642,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "incremental_barrier_test", + srcs = ["incremental_barrier_test.cc"], + deps = [ + ":incremental_barrier", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/platform", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/time", + ], +) + # Proto libraries. tf_proto_library( name = "test_log_proto_impl", diff --git a/tensorflow/core/util/incremental_barrier.cc b/tensorflow/core/util/incremental_barrier.cc new file mode 100644 index 00000000000..cbea7f25cc5 --- /dev/null +++ b/tensorflow/core/util/incremental_barrier.cc @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/util/incremental_barrier.h" + +#include <atomic> +#include <functional> + +#include "absl/functional/bind_front.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +class InternalIncrementalBarrier { + public: + explicit InternalIncrementalBarrier(IncrementalBarrier::DoneCallback callback) + : left_(1), done_callback_(std::move(callback)) {} + + void operator()() { + DCHECK_GE(left_.load(std::memory_order_relaxed), 0); + + if (left_.fetch_sub(1, std::memory_order_acq_rel) - 1 == 0) { + IncrementalBarrier::DoneCallback done_callback = + std::move(done_callback_); + delete this; + done_callback(); + } + } + + IncrementalBarrier::BarrierCallback Inc() { + left_.fetch_add(1, std::memory_order_acq_rel); + + // std::bind_front is only available ever since C++20. + return absl::bind_front(&InternalIncrementalBarrier::operator(), this); + } + + private: + std::atomic<int> left_; + IncrementalBarrier::DoneCallback done_callback_; +}; + +IncrementalBarrier::IncrementalBarrier(DoneCallback done_callback) + : internal_barrier_( + new InternalIncrementalBarrier(std::move(done_callback))) {} + +IncrementalBarrier::~IncrementalBarrier() { (*internal_barrier_)(); } + +IncrementalBarrier::BarrierCallback IncrementalBarrier::Inc() { + return internal_barrier_->Inc(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/incremental_barrier.h b/tensorflow/core/util/incremental_barrier.h new file mode 100644 index 00000000000..be45e9d4d8b --- /dev/null +++ b/tensorflow/core/util/incremental_barrier.h @@ -0,0 +1,81 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_ + +#include <atomic> +#include <functional> + +namespace tensorflow { + +class InternalIncrementalBarrier; + +// BarrierClosure (see +// https://github.com/chromium/chromium/blob/master/base/barrier_closure.h) +// executes a callback after it has been invoked |num_closures| times. +// Plus, `BarrierClosure` is a continuation-passing style abstraction and self- +// deleting. + +// IncrementalBarrier is a convenience class to be used in place of a barrier +// closure, which is particularly helpful (e.g. simplify code) because callers +// don't need to calculate the |num_closures| beforehand. +// +// Example Usage: +// void MakeCalls() { +// typedef std::function<void()> Callback; +// typedef std::function<void(Callback)> OtherCallback; +// Callback done_callback = ... +// OtherCallback cb1 = ... +// OtherCallback cb2 = ... +// std::thread threads[2]; +// { +// IncrementalBarrier barrier(done_callback); +// threads[0] = std::thread(cb1(barrier.Inc()); +// threads[1] = std::thread(cb2(barrier.Inc()); +// ... at this moment, `barrier` is incremented twice, and then +// destructed.... +// } +// threads[0].join(); +// threads[1].join(); +// } +// +// `done_callback` will be called when both conditions are true: +// 1) after `barrier` is destructed. +// 2) Each `BarrierCallback` returned by `Inc` is called. +// This class is thread-safe. +class IncrementalBarrier { + public: + typedef std::function<void()> DoneCallback; + typedef std::function<void()> BarrierCallback; + explicit IncrementalBarrier(DoneCallback callback); + + ~IncrementalBarrier(); + + // Returns a BarrierCallback (std::function) that individual task call to + // signal its completeness. + // The returned BarrierCallback outlives this `IncrementalBarrier` instance. + // Furthermore, each task should eventually call the returned function, or + // else done_callback wouldn't be called. + BarrierCallback Inc(); + + private: + // self-deleting, thereby not owned by 'IncrementalBarrier'. + InternalIncrementalBarrier* internal_barrier_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_ diff --git a/tensorflow/core/util/incremental_barrier_test.cc b/tensorflow/core/util/incremental_barrier_test.cc new file mode 100644 index 00000000000..020cb9ece32 --- /dev/null +++ b/tensorflow/core/util/incremental_barrier_test.cc @@ -0,0 +1,133 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/util/incremental_barrier.h" + +#include <atomic> + +#include "absl/functional/bind_front.h" +#include "absl/time/time.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/threadpool.h" + +namespace tensorflow { +namespace { + +// A thread-safe counter class. +class Counter { + public: + void Increment() TF_LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + ++count_; + } + + int GetCount() TF_LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + return count_; + } + + private: + mutex mu_; + int count_ = 0; +}; + +TEST(IncrementalBarrierTest, RunInstantlyWhenZeroClosure) { + Counter counter; + EXPECT_EQ(counter.GetCount(), 0); + { + IncrementalBarrier::DoneCallback done_callback = + absl::bind_front(&Counter::Increment, &counter); + IncrementalBarrier barrier(done_callback); + EXPECT_EQ(counter.GetCount(), 0); + } + EXPECT_EQ(counter.GetCount(), 1); +} + +TEST(IncrementalBarrierTest, RunAfterNumClosuresOneNowTwoLater) { + Counter counter; + + IncrementalBarrier::BarrierCallback bc1, bc2; + { + IncrementalBarrier::DoneCallback done_callback = + absl::bind_front(&Counter::Increment, &counter); + IncrementalBarrier barrier(done_callback); + + CHECK_EQ(counter.GetCount(), 0); + + bc1 = barrier.Inc(); + bc2 = barrier.Inc(); + + IncrementalBarrier::BarrierCallback bc3 = barrier.Inc(); + bc3(); + + CHECK_EQ(counter.GetCount(), 0); + } + + CHECK_EQ(counter.GetCount(), 0); + bc1(); + CHECK_EQ(counter.GetCount(), 0); + bc2(); + CHECK_EQ(counter.GetCount(), 1); +} + +TEST(IncrementalBarrierTest, RunAfterNumClosuresConcurrency) { + const int num_closure = 100, num_thread = 2; + std::atomic<int> schedule_count{0}; + Counter counter; + + { + IncrementalBarrier::DoneCallback done_callback = + absl::bind_front(&Counter::Increment, &counter); + IncrementalBarrier barrier(done_callback); + + CHECK_EQ(counter.GetCount(), 0); + + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), + "BarrierClosure", num_thread); + for (int i = 0; i < num_closure; ++i) { + pool.Schedule([&barrier, &schedule_count]() { + schedule_count.fetch_add(1); + IncrementalBarrier::BarrierCallback bc = barrier.Inc(); + + Env::Default()->SleepForMicroseconds(100); + bc(); + }); + } + + CHECK_EQ(counter.GetCount(), 0); + } + + CHECK_EQ(schedule_count.load(std::memory_order_relaxed), 100); + CHECK_EQ(counter.GetCount(), 1); +} + +#if defined(PLATFORM_GOOGLE) +void BM_FunctionInc(benchmark::State& state) { + IncrementalBarrier barrier([] {}); + for (auto _ : state) { + barrier.Inc()(); + } +} + +BENCHMARK(BM_FunctionInc); +#endif // PLATFORM_GOOGLE + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 04c36ed3399..33eba9a734f 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -11417,6 +11417,32 @@ func DynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged return op.Output(0) } +// Uncompresses a compressed dataset element. +func UncompressElement(scope *Scope, compressed tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} + opspec := tf.OpSpec{ + Type: "UncompressElement", + Input: []tf.Input{ + compressed, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if components, idx, err = makeOutputList(op, idx, "components"); err != nil { + scope.UpdateErr("UncompressElement", err) + return + } + return components +} + // Records the bytes size of each element of `input_dataset` in a StatsAggregator. func BytesProducedStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { @@ -11909,6 +11935,108 @@ func CropAndResizeGradBoxes(scope *Scope, grads tf.Output, image tf.Output, boxe return op.Output(0) } +// ExtractGlimpseV2Attr is an optional argument to ExtractGlimpseV2. +type ExtractGlimpseV2Attr func(optionalAttr) + +// ExtractGlimpseV2Centered sets the optional centered attribute to value. +// +// value: indicates if the offset coordinates are centered relative to +// the image, in which case the (0, 0) offset is relative to the center +// of the input images. If false, the (0,0) offset corresponds to the +// upper left corner of the input images. +// If not specified, defaults to true +func ExtractGlimpseV2Centered(value bool) ExtractGlimpseV2Attr { + return func(m optionalAttr) { + m["centered"] = value + } +} + +// ExtractGlimpseV2Normalized sets the optional normalized attribute to value. +// +// value: indicates if the offset coordinates are normalized. +// If not specified, defaults to true +func ExtractGlimpseV2Normalized(value bool) ExtractGlimpseV2Attr { + return func(m optionalAttr) { + m["normalized"] = value + } +} + +// ExtractGlimpseV2UniformNoise sets the optional uniform_noise attribute to value. +// +// value: indicates if the noise should be generated using a +// uniform distribution or a Gaussian distribution. +// If not specified, defaults to true +func ExtractGlimpseV2UniformNoise(value bool) ExtractGlimpseV2Attr { + return func(m optionalAttr) { + m["uniform_noise"] = value + } +} + +// ExtractGlimpseV2Noise sets the optional noise attribute to value. +// +// value: indicates if the noise should `uniform`, `gaussian`, or +// `zero`. The default is `uniform` which means the the noise type +// will be decided by `uniform_noise`. +// If not specified, defaults to "uniform" +func ExtractGlimpseV2Noise(value string) ExtractGlimpseV2Attr { + return func(m optionalAttr) { + m["noise"] = value + } +} + +// Extracts a glimpse from the input tensor. +// +// Returns a set of windows called glimpses extracted at location +// `offsets` from the input tensor. If the windows only partially +// overlaps the inputs, the non overlapping areas will be filled with +// random noise. +// +// The result is a 4-D tensor of shape `[batch_size, glimpse_height, +// glimpse_width, channels]`. The channels and batch dimensions are the +// same as that of the input tensor. The height and width of the output +// windows are specified in the `size` parameter. +// +// The argument `normalized` and `centered` controls how the windows are built: +// +// * If the coordinates are normalized but not centered, 0.0 and 1.0 +// correspond to the minimum and maximum of each height and width +// dimension. +// * If the coordinates are both normalized and centered, they range from +// -1.0 to 1.0. The coordinates (-1.0, -1.0) correspond to the upper +// left corner, the lower right corner is located at (1.0, 1.0) and the +// center is at (0, 0). +// * If the coordinates are not normalized they are interpreted as +// numbers of pixels. +// +// Arguments: +// input: A 4-D float tensor of shape `[batch_size, height, width, channels]`. +// size: A 1-D tensor of 2 elements containing the size of the glimpses +// to extract. The glimpse height must be specified first, following +// by the glimpse width. +// offsets: A 2-D integer tensor of shape `[batch_size, 2]` containing +// the y, x locations of the center of each window. +// +// Returns A tensor representing the glimpses `[batch_size, +// glimpse_height, glimpse_width, channels]`. +func ExtractGlimpseV2(scope *Scope, input tf.Output, size tf.Output, offsets tf.Output, optional ...ExtractGlimpseV2Attr) (glimpse tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ExtractGlimpseV2", + Input: []tf.Input{ + input, size, offsets, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ExtractGlimpseAttr is an optional argument to ExtractGlimpse. type ExtractGlimpseAttr func(optionalAttr) @@ -26103,6 +26231,173 @@ func FusedPadConv2D(scope *Scope, input tf.Output, paddings tf.Output, filter tf return op.Output(0) } +// Adjust the hue of one or more images. +// +// `images` is a tensor of at least 3 dimensions. The last dimension is +// interpreted as channels, and must be three. +// +// The input image is considered in the RGB colorspace. Conceptually, the RGB +// colors are first mapped into HSV. A delta is then applied all the hue values, +// and then remapped back to RGB colorspace. +// +// Arguments: +// images: Images to adjust. At least 3-D. +// delta: A float delta to add to the hue. +// +// Returns The hue-adjusted image or images. +func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "AdjustHue", + Input: []tf.Input{ + images, delta, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// List of the given size with empty elements. +// +// element_shape: the shape of the future elements of the list +// num_elements: the number of elements to reserve +// handle: the output list +// element_dtype: the desired type of elements in the list. +func TensorListReserve(scope *Scope, element_shape tf.Output, num_elements tf.Output, element_dtype tf.DataType) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"element_dtype": element_dtype} + opspec := tf.OpSpec{ + Type: "TensorListReserve", + Input: []tf.Input{ + element_shape, num_elements, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Clips tensor values to a specified min and max. +// +// Given a tensor `t`, this operation returns a tensor of the same type and +// shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`. +// Any values less than `clip_value_min` are set to `clip_value_min`. Any values +// greater than `clip_value_max` are set to `clip_value_max`. +// +// Arguments: +// t: A `Tensor`. +// clip_value_min: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape +// as `t`. The minimum value to clip by. +// clip_value_max: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape +// as `t`. The maximum value to clip by. +// +// Returns A clipped `Tensor` with the same shape as input 't'. +func ClipByValue(scope *Scope, t tf.Output, clip_value_min tf.Output, clip_value_max tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ClipByValue", + Input: []tf.Input{ + t, clip_value_min, clip_value_max, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Conv2DBackpropFilterAttr is an optional argument to Conv2DBackpropFilter. +type Conv2DBackpropFilterAttr func(optionalAttr) + +// Conv2DBackpropFilterUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. +// If not specified, defaults to true +func Conv2DBackpropFilterUseCudnnOnGpu(value bool) Conv2DBackpropFilterAttr { + return func(m optionalAttr) { + m["use_cudnn_on_gpu"] = value + } +} + +// Conv2DBackpropFilterExplicitPaddings sets the optional explicit_paddings attribute to value. +// +// value: If `padding` is `"EXPLICIT"`, the list of explicit padding amounts. For the ith +// dimension, the amount of padding inserted before and after the dimension is +// `explicit_paddings[2 * i]` and `explicit_paddings[2 * i + 1]`, respectively. If +// `padding` is not `"EXPLICIT"`, `explicit_paddings` must be empty. +// If not specified, defaults to <> +func Conv2DBackpropFilterExplicitPaddings(value []int64) Conv2DBackpropFilterAttr { + return func(m optionalAttr) { + m["explicit_paddings"] = value + } +} + +// Conv2DBackpropFilterDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Conv2DBackpropFilterDilations sets the optional dilations attribute to value. +// +// value: 1-D tensor of length 4. The dilation factor for each dimension of +// `input`. If set to k > 1, there will be k-1 skipped cells between each filter +// element on that dimension. The dimension order is determined by the value of +// `data_format`, see above for details. Dilations in the batch and depth +// dimensions must be 1. +// If not specified, defaults to <i:1 i:1 i:1 i:1 > +func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { + return func(m optionalAttr) { + m["dilations"] = value + } +} + +// Computes the gradients of convolution with respect to the filter. +// +// Arguments: +// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. +// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 4-D +// `[filter_height, filter_width, in_channels, out_channels]` tensor. +// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. +// Gradients w.r.t. the output of the convolution. +// strides: The stride of the sliding window for each dimension of the input +// of the convolution. Must be in the same order as the dimension specified with +// format. +// padding: The type of padding algorithm to use. +// +// Returns 4-D with shape +// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. +// the `filter` input of the convolution. +func Conv2DBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropFilterAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Conv2DBackpropFilter", + Input: []tf.Input{ + input, filter_sizes, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ConfigureDistributedTPUAttr is an optional argument to ConfigureDistributedTPU. type ConfigureDistributedTPUAttr func(optionalAttr) @@ -30243,6 +30538,21 @@ func ResourceScatterMul(scope *Scope, resource tf.Output, indices tf.Output, upd return scope.AddOperation(opspec) } +// Compresses a dataset element. +func CompressElement(scope *Scope, components []tf.Output) (compressed tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "CompressElement", + Input: []tf.Input{ + tf.OutputList(components), + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // MatMulAttr is an optional argument to MatMul. type MatMulAttr func(optionalAttr) @@ -30655,57 +30965,6 @@ func QuantizedAvgPool(scope *Scope, input tf.Output, min_input tf.Output, max_in return op.Output(0), op.Output(1), op.Output(2) } -// Clips tensor values to a specified min and max. -// -// Given a tensor `t`, this operation returns a tensor of the same type and -// shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`. -// Any values less than `clip_value_min` are set to `clip_value_min`. Any values -// greater than `clip_value_max` are set to `clip_value_max`. -// -// Arguments: -// t: A `Tensor`. -// clip_value_min: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape -// as `t`. The minimum value to clip by. -// clip_value_max: A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape -// as `t`. The maximum value to clip by. -// -// Returns A clipped `Tensor` with the same shape as input 't'. -func ClipByValue(scope *Scope, t tf.Output, clip_value_min tf.Output, clip_value_max tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ClipByValue", - Input: []tf.Input{ - t, clip_value_min, clip_value_max, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// List of the given size with empty elements. -// -// element_shape: the shape of the future elements of the list -// num_elements: the number of elements to reserve -// handle: the output list -// element_dtype: the desired type of elements in the list. -func TensorListReserve(scope *Scope, element_shape tf.Output, num_elements tf.Output, element_dtype tf.DataType) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"element_dtype": element_dtype} - opspec := tf.OpSpec{ - Type: "TensorListReserve", - Input: []tf.Input{ - element_shape, num_elements, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // VariableShapeAttr is an optional argument to VariableShape. type VariableShapeAttr func(optionalAttr) @@ -33006,6 +33265,14 @@ func TPUReplicatedInputIndex(value int64) TPUReplicatedInputAttr { } } +// TPUReplicatedInputIsPacked sets the optional is_packed attribute to value. +// If not specified, defaults to false +func TPUReplicatedInputIsPacked(value bool) TPUReplicatedInputAttr { + return func(m optionalAttr) { + m["is_packed"] = value + } +} + // Connects N inputs to an N-way replicated TPU computation. // // This operation holds a replicated input to a `tpu.replicate()` computation subgraph. @@ -34196,6 +34463,74 @@ func SparseReduceMax(scope *Scope, input_indices tf.Output, input_values tf.Outp return op.Output(0) } +// Generates sparse cross from a list of sparse and dense tensors. +// +// The op takes two lists, one of 2D `SparseTensor` and one of 2D `Tensor`, each +// representing features of one feature column. It outputs a 2D `SparseTensor` with +// the batchwise crosses of these features. +// +// For example, if the inputs are +// +// inputs[0]: SparseTensor with shape = [2, 2] +// [0, 0]: "a" +// [1, 0]: "b" +// [1, 1]: "c" +// +// inputs[1]: SparseTensor with shape = [2, 1] +// [0, 0]: "d" +// [1, 0]: "e" +// +// inputs[2]: Tensor [["f"], ["g"]] +// +// then the output will be +// +// shape = [2, 2] +// [0, 0]: "a_X_d_X_f" +// [1, 0]: "b_X_e_X_g" +// [1, 1]: "c_X_e_X_g" +// +// if hashed_output=true then the output will be +// +// shape = [2, 2] +// [0, 0]: FingerprintCat64( +// Fingerprint64("f"), FingerprintCat64( +// Fingerprint64("d"), Fingerprint64("a"))) +// [1, 0]: FingerprintCat64( +// Fingerprint64("g"), FingerprintCat64( +// Fingerprint64("e"), Fingerprint64("b"))) +// [1, 1]: FingerprintCat64( +// Fingerprint64("g"), FingerprintCat64( +// Fingerprint64("e"), Fingerprint64("c"))) +// +// Arguments: +// indices: 2-D. Indices of each input `SparseTensor`. +// values: 1-D. values of each `SparseTensor`. +// shapes: 1-D. Shapes of each `SparseTensor`. +// dense_inputs: 2-D. Columns represented by dense `Tensor`. +// num_buckets: It is used if hashed_output is true. +// output = hashed_value%num_buckets if num_buckets > 0 else hashed_value. +// strong_hash: boolean, if true, siphash with salt will be used instead of farmhash. +// salt: Specify the salt that will be used by the siphash function. +// +// Returns: +// output_indices: 2-D. Indices of the concatenated `SparseTensor`. +// output_values: 1-D. Non-empty values of the concatenated or hashed +// `SparseTensor`. +// output_shape: 1-D. Shape of the concatenated `SparseTensor`. +func SparseCrossHashed(scope *Scope, indices []tf.Output, values []tf.Output, shapes []tf.Output, dense_inputs []tf.Output, num_buckets tf.Output, strong_hash tf.Output, salt tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseCrossHashed", + Input: []tf.Input{ + tf.OutputList(indices), tf.OutputList(values), tf.OutputList(shapes), tf.OutputList(dense_inputs), num_buckets, strong_hash, salt, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + // QuantizedInstanceNormAttr is an optional argument to QuantizedInstanceNorm. type QuantizedInstanceNormAttr func(optionalAttr) @@ -34457,6 +34792,71 @@ func BiasAddV1(scope *Scope, value tf.Output, bias tf.Output) (output tf.Output) return op.Output(0) } +// Generates sparse cross from a list of sparse and dense tensors. +// +// The op takes two lists, one of 2D `SparseTensor` and one of 2D `Tensor`, each +// representing features of one feature column. It outputs a 2D `SparseTensor` with +// the batchwise crosses of these features. +// +// For example, if the inputs are +// +// inputs[0]: SparseTensor with shape = [2, 2] +// [0, 0]: "a" +// [1, 0]: "b" +// [1, 1]: "c" +// +// inputs[1]: SparseTensor with shape = [2, 1] +// [0, 0]: "d" +// [1, 0]: "e" +// +// inputs[2]: Tensor [["f"], ["g"]] +// +// then the output will be +// +// shape = [2, 2] +// [0, 0]: "a_X_d_X_f" +// [1, 0]: "b_X_e_X_g" +// [1, 1]: "c_X_e_X_g" +// +// if hashed_output=true then the output will be +// +// shape = [2, 2] +// [0, 0]: FingerprintCat64( +// Fingerprint64("f"), FingerprintCat64( +// Fingerprint64("d"), Fingerprint64("a"))) +// [1, 0]: FingerprintCat64( +// Fingerprint64("g"), FingerprintCat64( +// Fingerprint64("e"), Fingerprint64("b"))) +// [1, 1]: FingerprintCat64( +// Fingerprint64("g"), FingerprintCat64( +// Fingerprint64("e"), Fingerprint64("c"))) +// +// Arguments: +// indices: 2-D. Indices of each input `SparseTensor`. +// values: 1-D. values of each `SparseTensor`. +// shapes: 1-D. Shapes of each `SparseTensor`. +// dense_inputs: 2-D. Columns represented by dense `Tensor`. +// sep: string used when joining a list of string inputs, can be used as separator later. +// +// Returns: +// output_indices: 2-D. Indices of the concatenated `SparseTensor`. +// output_values: 1-D. Non-empty values of the concatenated or hashed +// `SparseTensor`. +// output_shape: 1-D. Shape of the concatenated `SparseTensor`. +func SparseCrossV2(scope *Scope, indices []tf.Output, values []tf.Output, shapes []tf.Output, dense_inputs []tf.Output, sep tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "SparseCrossV2", + Input: []tf.Input{ + tf.OutputList(indices), tf.OutputList(values), tf.OutputList(shapes), tf.OutputList(dense_inputs), sep, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1), op.Output(2) +} + // Pads a tensor with mirrored values. // // This operation pads a `input` with mirrored values according to the `paddings` @@ -36887,34 +37287,6 @@ func QueueDequeueUpToV2(scope *Scope, handle tf.Output, n tf.Output, component_t return components } -// Adjust the hue of one or more images. -// -// `images` is a tensor of at least 3 dimensions. The last dimension is -// interpreted as channels, and must be three. -// -// The input image is considered in the RGB colorspace. Conceptually, the RGB -// colors are first mapped into HSV. A delta is then applied all the hue values, -// and then remapped back to RGB colorspace. -// -// Arguments: -// images: Images to adjust. At least 3-D. -// delta: A float delta to add to the hue. -// -// Returns The hue-adjusted image or images. -func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "AdjustHue", - Input: []tf.Input{ - images, delta, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Says whether the targets are in the top `K` predictions. // // This outputs a `batch_size` bool array, an entry `out[i]` is `true` if the @@ -48489,94 +48861,6 @@ func RetrieveTPUEmbeddingFTRLParameters(scope *Scope, num_shards int64, shard_id return op.Output(0), op.Output(1), op.Output(2) } -// Conv2DBackpropFilterAttr is an optional argument to Conv2DBackpropFilter. -type Conv2DBackpropFilterAttr func(optionalAttr) - -// Conv2DBackpropFilterUseCudnnOnGpu sets the optional use_cudnn_on_gpu attribute to value. -// If not specified, defaults to true -func Conv2DBackpropFilterUseCudnnOnGpu(value bool) Conv2DBackpropFilterAttr { - return func(m optionalAttr) { - m["use_cudnn_on_gpu"] = value - } -} - -// Conv2DBackpropFilterExplicitPaddings sets the optional explicit_paddings attribute to value. -// -// value: If `padding` is `"EXPLICIT"`, the list of explicit padding amounts. For the ith -// dimension, the amount of padding inserted before and after the dimension is -// `explicit_paddings[2 * i]` and `explicit_paddings[2 * i + 1]`, respectively. If -// `padding` is not `"EXPLICIT"`, `explicit_paddings` must be empty. -// If not specified, defaults to <> -func Conv2DBackpropFilterExplicitPaddings(value []int64) Conv2DBackpropFilterAttr { - return func(m optionalAttr) { - m["explicit_paddings"] = value - } -} - -// Conv2DBackpropFilterDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Conv2DBackpropFilterDilations sets the optional dilations attribute to value. -// -// value: 1-D tensor of length 4. The dilation factor for each dimension of -// `input`. If set to k > 1, there will be k-1 skipped cells between each filter -// element on that dimension. The dimension order is determined by the value of -// `data_format`, see above for details. Dilations in the batch and depth -// dimensions must be 1. -// If not specified, defaults to <i:1 i:1 i:1 i:1 > -func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr { - return func(m optionalAttr) { - m["dilations"] = value - } -} - -// Computes the gradients of convolution with respect to the filter. -// -// Arguments: -// input: 4-D with shape `[batch, in_height, in_width, in_channels]`. -// filter_sizes: An integer vector representing the tensor shape of `filter`, -// where `filter` is a 4-D -// `[filter_height, filter_width, in_channels, out_channels]` tensor. -// out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`. -// Gradients w.r.t. the output of the convolution. -// strides: The stride of the sliding window for each dimension of the input -// of the convolution. Must be in the same order as the dimension specified with -// format. -// padding: The type of padding algorithm to use. -// -// Returns 4-D with shape -// `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t. -// the `filter` input of the convolution. -func Conv2DBackpropFilter(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv2DBackpropFilterAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Conv2DBackpropFilter", - Input: []tf.Input{ - input, filter_sizes, out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // LRNGradAttr is an optional argument to LRNGrad. type LRNGradAttr func(optionalAttr) diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 14babee2da7..6477c0491f9 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -246,6 +246,7 @@ cc_library( ":minimal_logging", ":simple_memory_arena", ":string", + ":tflite_with_xnnpack_optional", ":type_to_tflitetype", ":util", ":version", @@ -311,6 +312,8 @@ cc_library( ], ) +# Link this library to inject XNNPACK delegate to TFLite runtime automatically +# by utilizing the weak symbols if they're supported by the platform. cc_library( name = "tflite_with_xnnpack", srcs = ["tflite_with_xnnpack.cc"], @@ -323,6 +326,35 @@ cc_library( alwayslink = 1, ) +# Enables applying XNNPACK delegate for float models in TFLite runtime. +# WARNING: This build flag is experimental and subject to change. +config_setting( + name = "tflite_with_xnnpack_enabled", + values = {"define": "tflite_with_xnnpack=true"}, +) + +cc_library( + name = "tflite_with_xnnpack_optional", + srcs = ["tflite_with_xnnpack_optional.cc"], + hdrs = [ + "core/macros.h", + "tflite_with_xnnpack_optional.h", + ], + copts = tflite_copts() + TFLITE_DEFAULT_COPTS, + defines = select({ + ":tflite_with_xnnpack_enabled": ["TFLITE_BUILD_WITH_XNNPACK_DELEGATE"], + "//conditions:default": [], + }), + deps = [ + "//tensorflow/lite/c:common", + ] + select({ + ":tflite_with_xnnpack_enabled": [ + "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", + ], + "//conditions:default": [], + }), +) + cc_test( name = "string_util_test", size = "small", @@ -344,7 +376,9 @@ cc_test( cc_test( name = "interpreter_test", size = "small", - srcs = ["interpreter_test.cc"], + srcs = [ + "interpreter_test.cc", + ], features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs tags = [ "tflite_not_portable_ios", # TODO(b/117786830) diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 4af4bd4aae8..f6cdb981328 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -702,7 +702,6 @@ def gen_model_coverage_test(src, model_name, data, failure_type, tags, size = "m "//tensorflow/lite/python:lite", "//tensorflow/python:client_testlib", ] + flex_dep(target_op_sets), - timeout = "long", ) def if_tflite_experimental_runtime(if_eager, if_non_eager, if_none = []): diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD index e1702d40d5a..1aa043b7c0c 100644 --- a/tensorflow/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -22,6 +22,9 @@ package( tflite_cc_shared_object( name = "tensorflowlite_c", linkopts = select({ + "//tensorflow:ios": [ + "-Wl,-exported_symbols_list,$(location //tensorflow/lite/c:exported_symbols.lds)", + ], "//tensorflow:macos": [ "-Wl,-exported_symbols_list,$(location //tensorflow/lite/c:exported_symbols.lds)", ], diff --git a/tensorflow/lite/core/api/BUILD b/tensorflow/lite/core/api/BUILD index 6681a3ed550..419a3b2486d 100644 --- a/tensorflow/lite/core/api/BUILD +++ b/tensorflow/lite/core/api/BUILD @@ -26,6 +26,7 @@ cc_library( deps = [ "//tensorflow/lite/c:common", "//tensorflow/lite/schema:schema_fbs", + "@flatbuffers//:runtime_cc", ], ) diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 63e04899ca3..c52fc9f690b 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -15,10 +15,14 @@ limitations under the License. #include "tensorflow/lite/core/api/flatbuffer_conversions.h" -#include <cstdlib> +#include <cstddef> +#include <cstdint> +#include <memory> +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/schema/schema_generated.h" namespace tflite { diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h index d774afe8e85..2feddfaa8e6 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.h +++ b/tensorflow/lite/core/api/flatbuffer_conversions.h @@ -19,9 +19,12 @@ limitations under the License. // flatbuffer serialization format into in-memory values that are used by the // runtime API and interpreter. +#include <cstddef> +#include <new> +#include <type_traits> + #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/schema/schema_generated.h" namespace tflite { diff --git a/tensorflow/lite/core/api/op_resolver.cc b/tensorflow/lite/core/api/op_resolver.cc index 6424071f371..c239d9ed23e 100644 --- a/tensorflow/lite/core/api/op_resolver.cc +++ b/tensorflow/lite/core/api/op_resolver.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/lite/core/api/op_resolver.h" +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/error_reporter.h" + namespace tflite { TfLiteStatus GetRegistrationFromOpCode( diff --git a/tensorflow/lite/core/api/tensor_utils.cc b/tensorflow/lite/core/api/tensor_utils.cc index d8d6fc46a18..3aac16b6878 100644 --- a/tensorflow/lite/core/api/tensor_utils.cc +++ b/tensorflow/lite/core/api/tensor_utils.cc @@ -17,6 +17,8 @@ limitations under the License. #include <string.h> +#include "tensorflow/lite/c/common.h" + namespace tflite { TfLiteStatus ResetVariableTensor(TfLiteTensor* tensor) { diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 7f4e0e286ea..81710df128b 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -533,6 +533,11 @@ void Subgraph::SetCancellationFunction(void* data, check_cancelled_func_ = check_cancelled_func; } +bool Subgraph::IsCancelled() { + return (check_cancelled_func_ != nullptr) && + (*check_cancelled_func_)(cancellation_data_); +} + void Subgraph::ReserveNodes(int count) { nodes_and_registration_.reserve(count); } @@ -1316,6 +1321,8 @@ TfLiteStatus Subgraph::RemoveAllDelegates() { return kTfLiteOk; } +bool Subgraph::HasDelegates() { return !delegates_applied_.empty(); } + TfLiteStatus Subgraph::EnsureMemoryAllocations() { if (memory_planner_) { state_ = kStateUninvokable; diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h index 0b0c1e31e89..d6067daaa6a 100644 --- a/tensorflow/lite/core/subgraph.h +++ b/tensorflow/lite/core/subgraph.h @@ -553,6 +553,9 @@ class Subgraph { // afterwards. TfLiteStatus RemoveAllDelegates(); + // Returns true if the subgraph has delegates applied. + bool HasDelegates(); + // Cleanups up data reserved for the given node. Does not remove the {node, // registration} pair from nodes_and_registrations_. void CleanupNode(int node_index); @@ -578,6 +581,9 @@ class Subgraph { // Ensures the memory required is planned and allocated. TfLiteStatus EnsureMemoryAllocations(); + // Returns true if cancellation function returns true. + bool IsCancelled(); + // The state of the Interpreter. enum State { // The interpreter isn't ready to be invoked. diff --git a/tensorflow/lite/delegates/BUILD b/tensorflow/lite/delegates/BUILD index df671675ec9..8a05298d01a 100644 --- a/tensorflow/lite/delegates/BUILD +++ b/tensorflow/lite/delegates/BUILD @@ -32,6 +32,16 @@ cc_library( ], ) +cc_library( + name = "interpreter_utils", + srcs = ["interpreter_utils.cc"], + hdrs = ["interpreter_utils.h"], + copts = tflite_copts(), + deps = [ + "//tensorflow/lite:framework", + ], +) + cc_test( name = "utils_test", srcs = ["utils_test.cc"], @@ -43,3 +53,25 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_test( + name = "delegate_test", + size = "small", + srcs = ["delegate_test.cc"], + features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs + tags = [ + "tflite_not_portable_ios", # TODO(b/117786830) + ], + deps = [ + ":interpreter_utils", + "//tensorflow/lite:framework", + "//tensorflow/lite:version", + "//tensorflow/lite/core/api", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels/internal:compatibility", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/lite/delegates/delegate_test.cc b/tensorflow/lite/delegates/delegate_test.cc new file mode 100644 index 00000000000..1efe6e44d54 --- /dev/null +++ b/tensorflow/lite/delegates/delegate_test.cc @@ -0,0 +1,1051 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <stdint.h> + +#include <memory> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow/lite/delegates/interpreter_utils.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/testing/util.h" +#include "tensorflow/lite/version.h" + +namespace tflite { +namespace { + +// Build a kernel registration for an op that copies its one input +// to an output +TfLiteRegistration AddOpRegistration() { + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + + reg.custom_name = "my_add"; + reg.builtin_code = tflite::BuiltinOperator_CUSTOM; + + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + // Set output size to input size + const TfLiteTensor* input1 = GetInput(context, node, 0); + const TfLiteTensor* input2 = GetInput(context, node, 1); + TfLiteTensor* output = GetOutput(context, node, 0); + + TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size); + for (int i = 0; i < input1->dims->size; ++i) { + TF_LITE_ENSURE_EQ(context, input1->dims->data[i], input2->dims->data[i]); + } + + TF_LITE_ENSURE_STATUS(context->ResizeTensor( + context, output, TfLiteIntArrayCopy(input1->dims))); + return kTfLiteOk; + }; + + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + // Copy input data to output data. + const TfLiteTensor* a0 = GetInput(context, node, 0); + TF_LITE_ENSURE(context, a0); + TF_LITE_ENSURE(context, a0->data.f); + const TfLiteTensor* a1 = GetInput(context, node, 1); + TF_LITE_ENSURE(context, a1); + TF_LITE_ENSURE(context, a1->data.f); + TfLiteTensor* out = GetOutput(context, node, 0); + TF_LITE_ENSURE(context, out); + TF_LITE_ENSURE(context, out->data.f); + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + out->data.f[i] = a0->data.f[i] + a1->data.f[i]; + } + return kTfLiteOk; + }; + return reg; +} + +} // namespace + +// TestDelegate is a friend of Interpreter to access RemoveAllDelegates(). +class TestDelegate : public ::testing::Test { + protected: + void SetUp() override { + interpreter_.reset(new Interpreter); + interpreter_->AddTensors(5); + interpreter_->SetInputs({0, 1}); + interpreter_->SetOutputs({3, 4}); + TfLiteQuantizationParams quant; + interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, + quant); + interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, + quant); + interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3}, + quant); + interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3}, + quant); + interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3}, + quant); + TfLiteRegistration reg = AddOpRegistration(); + interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, ®); + interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, ®); + interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, ®); + } + + void TearDown() override { + // Interpreter relies on delegate to free the resources properly. Thus + // the life cycle of delegate must be longer than interpreter. + interpreter_.reset(); + delegate_.reset(); + } + + TfLiteBufferHandle last_allocated_handle_ = kTfLiteNullBufferHandle; + + TfLiteBufferHandle AllocateBufferHandle() { return ++last_allocated_handle_; } + + TfLiteStatus RemoveAllDelegates() { + return interpreter_->RemoveAllDelegates(); + } + + protected: + class SimpleDelegate { + public: + // Create a simple implementation of a TfLiteDelegate. We use the C++ class + // SimpleDelegate and it can produce a handle TfLiteDelegate that is + // value-copyable and compatible with TfLite. + // fail_node_prepare: To simulate failure of Delegate node's Prepare(). + // min_ops_per_subset: If >0, partitioning preview is used to choose only + // those subsets with min_ops_per_subset number of nodes. + // fail_node_invoke: To simulate failure of Delegate node's Invoke(). + explicit SimpleDelegate( + const std::vector<int>& nodes, + TfLiteDelegateFlags delegate_flags = kTfLiteDelegateFlagsNone, + bool fail_node_prepare = false, int min_ops_per_subset = 0, + bool fail_node_invoke = false) + : nodes_(nodes), + fail_delegate_node_prepare_(fail_node_prepare), + min_ops_per_subset_(min_ops_per_subset), + fail_delegate_node_invoke_(fail_node_invoke) { + delegate_.Prepare = [](TfLiteContext* context, + TfLiteDelegate* delegate) -> TfLiteStatus { + auto* simple = static_cast<SimpleDelegate*>(delegate->data_); + TfLiteIntArray* nodes_to_separate = + TfLiteIntArrayCreate(simple->nodes_.size()); + // Mark nodes that we want in TfLiteIntArray* structure. + int index = 0; + for (auto node_index : simple->nodes_) { + nodes_to_separate->data[index++] = node_index; + // make sure node is added + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM); + TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0); + } + // Check that all nodes are available + TfLiteIntArray* execution_plan; + TF_LITE_ENSURE_STATUS( + context->GetExecutionPlan(context, &execution_plan)); + for (int exec_index = 0; exec_index < execution_plan->size; + exec_index++) { + int node_index = execution_plan->data[exec_index]; + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + if (exec_index == node_index) { + // Check op details only if it wasn't delegated already. + TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM); + TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0); + } + } + + // Get preview of delegate partitioning from the context. + TfLiteDelegateParams* params_array; + int num_partitions; + TFLITE_CHECK_EQ( + context->PreviewDelegatePartitioning( + context, nodes_to_separate, ¶ms_array, &num_partitions), + kTfLiteOk); + + if (simple->min_ops_per_subset() > 0) { + // Build a new vector of ops from subsets with atleast the minimum + // size. + std::vector<int> allowed_ops; + for (int idx = 0; idx < num_partitions; ++idx) { + const auto* nodes_in_subset = params_array[idx].nodes_to_replace; + if (nodes_in_subset->size < simple->min_ops_per_subset()) continue; + allowed_ops.insert(allowed_ops.end(), nodes_in_subset->data, + nodes_in_subset->data + nodes_in_subset->size); + } + + // Free existing nodes_to_separate & initialize a new array with + // allowed_ops. + TfLiteIntArrayFree(nodes_to_separate); + nodes_to_separate = TfLiteIntArrayCreate(allowed_ops.size()); + memcpy(nodes_to_separate->data, allowed_ops.data(), + sizeof(int) * nodes_to_separate->size); + } + + // Another call to PreviewDelegateParitioning should be okay, since + // partitioning memory is managed by context. + TFLITE_CHECK_EQ( + context->PreviewDelegatePartitioning( + context, nodes_to_separate, ¶ms_array, &num_partitions), + kTfLiteOk); + + context->ReplaceNodeSubsetsWithDelegateKernels( + context, simple->FakeFusedRegistration(), nodes_to_separate, + delegate); + TfLiteIntArrayFree(nodes_to_separate); + return kTfLiteOk; + }; + delegate_.CopyToBufferHandle = [](TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + TfLiteTensor* tensor) -> TfLiteStatus { + // TODO(b/156586986): Implement tests to test buffer copying logic. + return kTfLiteOk; + }; + delegate_.CopyFromBufferHandle = + [](TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + TfLiteTensor* output) -> TfLiteStatus { + TFLITE_CHECK_GE(buffer_handle, -1); + TFLITE_CHECK_EQ(output->buffer_handle, buffer_handle); + const float floats[] = {6., 6., 6.}; + int num = output->dims->data[0]; + for (int i = 0; i < num; i++) { + output->data.f[i] = floats[i]; + } + return kTfLiteOk; + }; + + delegate_.FreeBufferHandle = + [](TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle* handle) { *handle = kTfLiteNullBufferHandle; }; + // Store type-punned data SimpleDelegate structure. + delegate_.data_ = static_cast<void*>(this); + delegate_.flags = delegate_flags; + } + + TfLiteRegistration FakeFusedRegistration() { + TfLiteRegistration reg = {nullptr}; + reg.custom_name = "fake_fused_op"; + + reg.invoke = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + // Copy input data to output data. + const TfLiteTensor* a0; + const TfLiteTensor* a1; + if (node->inputs->size == 2) { + a0 = GetInput(context, node, 0); + a1 = GetInput(context, node, 1); + } else { + a0 = GetInput(context, node, 0); + a1 = a0; + } + TfLiteTensor* out = GetOutput(context, node, 0); + int num = 1; + for (int i = 0; i < a0->dims->size; ++i) { + num *= a0->dims->data[i]; + } + for (int i = 0; i < num; i++) { + out->data.f[i] = a0->data.f[i] + a1->data.f[i]; + } + if (out->buffer_handle != kTfLiteNullBufferHandle) { + // Make the data stale so that CopyFromBufferHandle can be invoked + out->data_is_stale = true; + } + return kTfLiteOk; + }; + if (fail_delegate_node_invoke_) { + reg.invoke = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + return kTfLiteError; + }; + } + + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + // Set output size to input size + const TfLiteTensor* input1; + const TfLiteTensor* input2; + if (node->inputs->size == 2) { + input1 = GetInput(context, node, 0); + input2 = GetInput(context, node, 1); + } else { + input1 = GetInput(context, node, 0); + input2 = input1; + } + TfLiteTensor* output = GetOutput(context, node, 0); + + TF_LITE_ENSURE_STATUS(context->ResizeTensor( + context, output, TfLiteIntArrayCopy(input1->dims))); + return kTfLiteOk; + }; + if (fail_delegate_node_prepare_) { + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + return kTfLiteError; + }; + } + + return reg; + } + + TfLiteDelegate* get_tf_lite_delegate() { return &delegate_; } + + int min_ops_per_subset() { return min_ops_per_subset_; } + + private: + std::vector<int> nodes_; + TfLiteDelegate delegate_; + bool fail_delegate_node_prepare_ = false; + int min_ops_per_subset_ = 0; + bool fail_delegate_node_invoke_ = false; + }; + + std::unique_ptr<Interpreter> interpreter_; + std::unique_ptr<SimpleDelegate> delegate_, delegate2_; +}; +namespace { + +TEST_F(TestDelegate, BasicDelegate) { + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()); + + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + int node = interpreter_->execution_plan()[0]; + const auto* node_and_reg = interpreter_->node_and_registration(node); + EXPECT_EQ(node_and_reg->second.custom_name, + delegate_->FakeFusedRegistration().custom_name); + + const TfLiteDelegateParams* params = static_cast<const TfLiteDelegateParams*>( + node_and_reg->first.builtin_data); + ASSERT_EQ(params->nodes_to_replace->size, 3); + EXPECT_EQ(params->nodes_to_replace->data[0], 0); + EXPECT_EQ(params->nodes_to_replace->data[1], 1); + EXPECT_EQ(params->nodes_to_replace->data[2], 2); + + ASSERT_EQ(params->input_tensors->size, 2); + EXPECT_EQ(params->input_tensors->data[0], 0); + EXPECT_EQ(params->input_tensors->data[1], 1); + + ASSERT_EQ(params->output_tensors->size, 2); + EXPECT_EQ(params->output_tensors->data[0], 3); + EXPECT_EQ(params->output_tensors->data[1], 4); +} + +TEST_F(TestDelegate, DelegateNodePrepareFailure) { + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate( + {0, 1, 2}, kTfLiteDelegateFlagsNone, true /**fail_node_prepare**/)); + // ModifyGraphWithDelegate fails, since the Prepare() method in the node's + // TfLiteRegistration returns an error status. + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteDelegateError); + // Execution plan should remain unchanged. + ASSERT_EQ(interpreter_->execution_plan().size(), 3); + + std::vector<float> input = {1.0f, 2.0f, 3.0f}; + std::vector<float> expected_output = {2.0f, 4.0f, 6.0f}; + constexpr int kOutputTensorIndex = 3; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + + // Verify Invoke() behavior. + memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); + memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); + interpreter_->Invoke(); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; + } +} + +TEST_F(TestDelegate, DelegateNodeInvokeFailure) { + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate( + {0, 1, 2}, kTfLiteDelegateFlagsNone, false /**fail_node_prepare**/, + 0 /**min_ops_per_subset**/, true /**fail_node_invoke**/)); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + // Delegation modified execution plan. + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + + std::vector<float> input = {1.0f, 2.0f, 3.0f}; + std::vector<float> expected_output = {2.0f, 4.0f, 6.0f}; + constexpr int kOutputTensorIndex = 3; + + // Verify Invoke() behavior: fails first, succeeds after RemoveAllDelegates(). + memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); + memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); + EXPECT_EQ(interpreter_->Invoke(), kTfLiteError); + ASSERT_EQ(RemoveAllDelegates(), kTfLiteOk); + // Delegation removed, returning to original execution plan. + ASSERT_EQ(interpreter_->execution_plan().size(), 3); + + memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); + memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; + } +} + +TEST_F(TestDelegate, DelegateNodeInvokeFailureFallback) { + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate( + {0, 1, 2}, kTfLiteDelegateFlagsNone, false /**fail_node_prepare**/, + 0 /**min_ops_per_subset**/, true /**fail_node_invoke**/)); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + // Delegation modified execution plan. + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + + std::vector<float> input = {1.0f, 2.0f, 3.0f}; + std::vector<float> expected_output = {2.0f, 4.0f, 6.0f}; + constexpr int kOutputTensorIndex = 3; + + memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); + memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); + EXPECT_EQ( + delegates::InterpreterUtils::InvokeWithCPUFallback(interpreter_.get()), + kTfLiteDelegateError); + // Delegation removed, returning to original execution plan. + ASSERT_EQ(interpreter_->execution_plan().size(), 3); + // Check outputs. + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; + } +} + +TEST_F(TestDelegate, SecondDelegationPrepareFailure) { + // First delegate only supports nodes 1, 2. Gets applied successfully. + // This delegate should support dynamic tensors, otherwise the second won't be + // applied. + delegate_ = std::unique_ptr<SimpleDelegate>( + new SimpleDelegate({1, 2}, kTfLiteDelegateFlagsAllowDynamicTensors)); + // Second delegate supports node 0, but fails during the delegate-node's + // Prepare. + delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate( + {0}, kTfLiteDelegateFlagsNone, true /**fail_node_prepare**/)); + + // Initially, execution plan has 3 nodes. + ASSERT_EQ(interpreter_->execution_plan().size(), 3); + // First delegate should be applied successfully, yielding a plan with 2 + // nodes. + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + ASSERT_EQ(interpreter_->execution_plan().size(), 2); + // Second delegate won't get applied. + // As a result, previous delegate should also get undone, restoring the + // execution plan to its original state. + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()), + kTfLiteDelegateError); + ASSERT_EQ(interpreter_->execution_plan().size(), 3); + + std::vector<float> input = {1.0f, 2.0f, 3.0f}; + std::vector<float> expected_output = {2.0f, 4.0f, 6.0f}; + constexpr int kOutputTensorIndex = 3; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + + // Verify Invoke() behavior. + memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); + memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); + interpreter_->Invoke(); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; + } +} + +TEST_F(TestDelegate, SecondDelegationInvokeFailure) { + delegate_ = std::unique_ptr<SimpleDelegate>( + new SimpleDelegate({1, 2}, kTfLiteDelegateFlagsAllowDynamicTensors)); + delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate( + {0}, kTfLiteDelegateFlagsNone, false /**fail_node_prepare**/, + 0 /**min_ops_per_subset**/, true /**fail_node_invoke**/)); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()), + kTfLiteOk); + ASSERT_EQ(interpreter_->execution_plan().size(), 2); + + std::vector<float> input = {1.0f, 2.0f, 3.0f}; + // Outputs match the AddOp path, rather than delegate path. + std::vector<float> expected_output = {2.0f, 4.0f, 6.0f}; + constexpr int kOutputTensorIndex = 3; + + // Verify Invoke() behavior to ensure Interpreter isn't broken. + memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); + memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); + EXPECT_EQ(interpreter_->Invoke(), kTfLiteError); + EXPECT_EQ(RemoveAllDelegates(), kTfLiteOk); + ASSERT_EQ(interpreter_->execution_plan().size(), 3); + ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; + } +} + +TEST_F(TestDelegate, StaticDelegateMakesGraphImmutable) { + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + + // Deliberately try to set tensor params with quantization while immutable, + // ensuring quantization is properly freed. + TfLiteQuantization quant = {}; + quant.type = kTfLiteAffineQuantization; + auto quant_params = static_cast<TfLiteAffineQuantization*>( + malloc(sizeof(TfLiteAffineQuantization))); + quant_params->scale = nullptr; + quant_params->zero_point = nullptr; + quant_params->quantized_dimension = 0; + quant.params = quant_params; + ASSERT_NE(interpreter_->SetTensorParametersReadWrite(0, kTfLiteInt8, "", {3}, + quant), + kTfLiteOk); +} + +TEST_F(TestDelegate, ComplexDelegate) { + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({1, 2})); + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()); + + ASSERT_EQ(interpreter_->execution_plan().size(), 2); + // 0th should be a non-delegated original op + ASSERT_EQ(interpreter_->execution_plan()[0], 0); + // 1st should be a new macro op (3) which didn't exist) + ASSERT_EQ(interpreter_->execution_plan()[1], 3); + const auto* node_and_reg = interpreter_->node_and_registration(3); + ASSERT_EQ(node_and_reg->second.custom_name, + delegate_->FakeFusedRegistration().custom_name); +} + +TEST_F(TestDelegate, SetBufferHandleToInput) { + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); + TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); + interpreter_->ModifyGraphWithDelegate(delegate); + + constexpr int kOutputTensorIndex = 0; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + ASSERT_EQ(tensor->delegate, nullptr); + ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); + + TfLiteBufferHandle handle = AllocateBufferHandle(); + TfLiteStatus status = + interpreter_->SetBufferHandle(kOutputTensorIndex, handle, delegate); + ASSERT_EQ(status, kTfLiteOk); + EXPECT_EQ(tensor->delegate, delegate); + EXPECT_EQ(tensor->buffer_handle, handle); +} + +TEST_F(TestDelegate, SetBufferHandleToOutput) { + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); + TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); + interpreter_->ModifyGraphWithDelegate(delegate); + + constexpr int kOutputTensorIndex = 3; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + // Before setting the buffer handle, the tensor's `delegate` is already set + // because it will be written by the delegate. + ASSERT_EQ(tensor->delegate, delegate); + ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); + + TfLiteBufferHandle handle = AllocateBufferHandle(); + TfLiteStatus status = + interpreter_->SetBufferHandle(kOutputTensorIndex, handle, delegate); + ASSERT_EQ(status, kTfLiteOk); + EXPECT_EQ(tensor->delegate, delegate); + EXPECT_EQ(tensor->buffer_handle, handle); +} + +TEST_F(TestDelegate, SetInvalidHandleToTensor) { + interpreter_->Invoke(); + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); + TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); + interpreter_->ModifyGraphWithDelegate(delegate); + + SimpleDelegate another_simple_delegate({0, 1, 2}); + + constexpr int kOutputTensorIndex = 3; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + // Before setting the buffer handle, the tensor's `delegate` is already set + // because it will be written by the delegate. + ASSERT_EQ(tensor->delegate, delegate); + ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); + + TfLiteBufferHandle handle = AllocateBufferHandle(); + TfLiteStatus status = interpreter_->SetBufferHandle( + kOutputTensorIndex, handle, + another_simple_delegate.get_tf_lite_delegate()); + // Setting a buffer handle to a tensor with another delegate will fail. + ASSERT_EQ(status, kTfLiteError); + EXPECT_EQ(tensor->delegate, delegate); + EXPECT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); +} + +// We utilize delegation in such a way as to allow node subsets with a minimum +// number of ops only. +TEST_F(TestDelegate, TestDelegationWithPartitionPreview) { + // We set kTfLiteDelegateFlagsAllowDynamicTensors to ensure the second + // delegate can be applied. + // Ops 0 and 2 are delegated but end up in the same partition (based on + // dependency analysis). However, since min_ops_per_subset = 3, no delegation + // takes place. + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate( + {0, 2}, kTfLiteDelegateFlagsAllowDynamicTensors, + false /**fail_node_prepare**/, 3 /**min_ops_per_subset**/)); + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()); + + // Original execution plan remains. + ASSERT_EQ(interpreter_->execution_plan().size(), 3); + ASSERT_EQ(interpreter_->execution_plan()[0], 0); + ASSERT_EQ(interpreter_->execution_plan()[1], 1); + ASSERT_EQ(interpreter_->execution_plan()[2], 2); + + // Same ops supported, but min_ops_per_subset = 2. + delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate( + {0, 2}, kTfLiteDelegateFlagsAllowDynamicTensors, + false /**fail_node_prepare**/, 2 /**min_ops_per_subset**/)); + interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()); + + ASSERT_EQ(interpreter_->execution_plan().size(), 2); + ASSERT_EQ(interpreter_->execution_plan()[0], 3); + const auto* node_and_reg = interpreter_->node_and_registration(3); + ASSERT_EQ(node_and_reg->second.custom_name, + delegate2_->FakeFusedRegistration().custom_name); + ASSERT_EQ(interpreter_->execution_plan()[1], 1); +} + +TEST_F(TestDelegate, TestResizeInputWithNonDynamicDelegate) { + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + + // Try resizing input to same shape as before (which should be a No-op). + ASSERT_EQ(interpreter_->ResizeInputTensor(0, {3}), kTfLiteOk); + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + + ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 3}), kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(1, {1, 3}), kTfLiteOk); + ASSERT_EQ(interpreter_->execution_plan().size(), 3); + // This should fail, since the previous application of the delegate will be + // re-done automatically, making the graph immutable again. + ASSERT_NE( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + // Ensure graph has been restored to its valid delegated state. + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + + std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f}; + std::vector<float> expected_output = {2.0f, 4.0f, 6.0f, 8.0f}; + constexpr int kOutputTensorIndex = 3; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + + // Verify Invoke() behavior. + memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); + memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); + interpreter_->Invoke(); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; + } + + // Resize again, but call AllocateTensors as usual afterwards. + ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 4}), kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(1, {1, 4}), kTfLiteOk); + ASSERT_EQ(interpreter_->execution_plan().size(), 3); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + + memcpy(interpreter_->typed_tensor<float>(0), input.data(), 4 * sizeof(float)); + memcpy(interpreter_->typed_tensor<float>(1), input.data(), 4 * sizeof(float)); + interpreter_->Invoke(); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; + } +} + +TEST_F(TestDelegate, TestResizeInputWithMultipleDelegates) { + // First delegate only supports node 0. + // This delegate should support dynamic tensors, otherwise the second won't be + // applied. + delegate_ = std::unique_ptr<SimpleDelegate>( + new SimpleDelegate({0}, kTfLiteDelegateFlagsAllowDynamicTensors)); + // Second delegate supports nodes 1 & 2, and makes the graph immutable. + delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({1, 2})); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()), + kTfLiteOk); + // Should be two delegates nodes. + ASSERT_EQ(interpreter_->execution_plan().size(), 2); + + // Try resizing input to same shape as before (which should be a No-op). + ASSERT_EQ(interpreter_->ResizeInputTensor(0, {3}), kTfLiteOk); + ASSERT_EQ(interpreter_->execution_plan().size(), 2); + + // Resizing input tensors should temporarily restore original execution plan + // of 3 nodes. + ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 3}), kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(1, {1, 3}), kTfLiteOk); + ASSERT_EQ(interpreter_->execution_plan().size(), 3); + // This should fail, since the previous application of the delegate will be + // re-done automatically, making the graph immutable again. + ASSERT_NE( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + // Ensure graph has been restored to its valid delegated state. + ASSERT_EQ(interpreter_->execution_plan().size(), 2); + + std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f}; + std::vector<float> expected_output = {2.0f, 4.0f, 6.0f, 8.0f}; + constexpr int kOutputTensorIndex = 2; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + + // Verify Invoke() behavior. + memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); + memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); + interpreter_->Invoke(); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; + } + + // Resize again, but call AllocateTensors as usual afterwards. + ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 4}), kTfLiteOk); + ASSERT_EQ(interpreter_->ResizeInputTensor(1, {1, 4}), kTfLiteOk); + ASSERT_EQ(interpreter_->execution_plan().size(), 3); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter_->execution_plan().size(), 2); + + memcpy(interpreter_->typed_tensor<float>(0), input.data(), 4 * sizeof(float)); + memcpy(interpreter_->typed_tensor<float>(1), input.data(), 4 * sizeof(float)); + interpreter_->Invoke(); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; + } +} + +TEST_F(TestDelegate, TestFallbackWithMultipleDelegates) { + // First delegate only supports node 0. + // This delegate should support dynamic tensors, otherwise the second won't be + // applied. + delegate_ = std::unique_ptr<SimpleDelegate>( + new SimpleDelegate({0}, kTfLiteDelegateFlagsAllowDynamicTensors)); + // Second delegate supports nodes 1 & 2, and makes the graph immutable. + delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate( + {1, 2}, kTfLiteDelegateFlagsNone, false /**fail_node_prepare**/, + 0 /**min_ops_per_subset**/, true /**fail_node_invoke**/)); + // Pre-delegation execution plan should have three nodes. + ASSERT_EQ(interpreter_->execution_plan().size(), 3); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()), + kTfLiteOk); + // Should be two delegates nodes. + ASSERT_EQ(interpreter_->execution_plan().size(), 2); + + std::vector<float> input = {1.0f, 2.0f, 3.0f}; + std::vector<float> expected_output = {2.0f, 4.0f, 6.0f}; + constexpr int kOutputTensorIndex = 2; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + + memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); + memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); + EXPECT_EQ( + delegates::InterpreterUtils::InvokeWithCPUFallback(interpreter_.get()), + kTfLiteDelegateError); + // All delegates should be undone. + EXPECT_EQ(interpreter_->execution_plan().size(), 3); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; + } +} + +TEST_F(TestDelegate, ReleaseNonPersistentMemoryWithDelegates) { + // First delegate only supports node 0. + // This delegate should support dynamic tensors, otherwise the second won't be + // applied. + delegate_ = std::unique_ptr<SimpleDelegate>( + new SimpleDelegate({0}, kTfLiteDelegateFlagsAllowDynamicTensors)); + // Second delegate supports nodes 1 & 2, and makes the graph immutable. + delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({1, 2})); + + // No-op. + ASSERT_EQ(interpreter_->ReleaseNonPersistentMemory(), kTfLiteOk); + + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + ASSERT_EQ( + interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()), + kTfLiteOk); + // Should be two delegates nodes. + ASSERT_EQ(interpreter_->execution_plan().size(), 2); + + ASSERT_EQ(interpreter_->ReleaseNonPersistentMemory(), kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + // This should fail, since the graph is immutable. + ASSERT_NE( + interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), + kTfLiteOk); + + std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f}; + std::vector<float> expected_output = {2.0f, 4.0f, 6.0f, 8.0f}; + constexpr int kOutputTensorIndex = 2; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + + // Verify Invoke() behavior. + memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); + memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); + interpreter_->Invoke(); + for (int i = 0; i < 3; ++i) { + EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; + } + + ASSERT_EQ(interpreter_->ReleaseNonPersistentMemory(), kTfLiteOk); +} + +TEST_F(TestDelegate, TestCopyFromBufferInvoke) { + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); + TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); + interpreter_->ModifyGraphWithDelegate(delegate); + + constexpr int kOutputTensorIndex = 3; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + std::vector<float> floats = {1.0f, 2.0f, 3.0f}; + memcpy(interpreter_->typed_tensor<float>(0), floats.data(), + floats.size() * sizeof(float)); + + memcpy(interpreter_->typed_tensor<float>(1), floats.data(), + floats.size() * sizeof(float)); + + // Before setting the buffer handle, the tensor's `delegate` is already set + // because it will be written by the delegate. + ASSERT_EQ(tensor->delegate, delegate); + ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); + + // Called Invoke without setting the buffer will not call the CopyFromBuffer + interpreter_->Invoke(); + std::vector<float> res = {2.0f, 4.0f, 6.0f}; + for (int i = 0; i < tensor->dims->data[0]; ++i) { + ASSERT_EQ(tensor->data.f[i], res[i]); + } +} + +TEST_F(TestDelegate, TestCopyFromBuffer) { + delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); + TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); + interpreter_->ModifyGraphWithDelegate(delegate); + + constexpr int kOutputTensorIndex = 3; + TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); + std::vector<float> floats = {1.0f, 2.0f, 3.0f}; + memcpy(interpreter_->typed_tensor<float>(0), floats.data(), + floats.size() * sizeof(float)); + + memcpy(interpreter_->typed_tensor<float>(1), floats.data(), + floats.size() * sizeof(float)); + + // Before setting the buffer handle, the tensor's `delegate` is already set + // because it will be written by the delegate. + ASSERT_EQ(tensor->delegate, delegate); + ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); + + TfLiteBufferHandle handle = AllocateBufferHandle(); + TfLiteStatus status = + interpreter_->SetBufferHandle(kOutputTensorIndex, handle, delegate); + interpreter_->Invoke(); + ASSERT_EQ(status, kTfLiteOk); + EXPECT_EQ(tensor->delegate, delegate); + EXPECT_EQ(tensor->buffer_handle, handle); + for (int i = 0; i < tensor->dims->data[0]; ++i) { + ASSERT_EQ(tensor->data.f[i], 6.0f); + } +} + +TEST_F(TestDelegate, DelegateCustomOpResolution) { + // Build a flatbuffer model that contains the "my_add" custom op which gets + // resolved only after SimpleDelegate is applied. + flatbuffers::FlatBufferBuilder builder; + // Tensors. + const int32_t shape[1] = {3}; + flatbuffers::Offset<Tensor> tensors[3] = { + CreateTensor(builder, builder.CreateVector<int32_t>(shape, 1), + TensorType_FLOAT32, /*buffer=*/0, builder.CreateString("X")), + CreateTensor(builder, builder.CreateVector<int32_t>(shape, 1), + TensorType_FLOAT32, /*buffer=*/0, builder.CreateString("Y")), + CreateTensor(builder, builder.CreateVector<int32_t>(shape, 1), + TensorType_FLOAT32, /*buffer=*/0, builder.CreateString("Z")), + }; + // Custom op definition. + flatbuffers::Offset<OperatorCode> op_code = + CreateOperatorCodeDirect(builder, BuiltinOperator_CUSTOM, "my_add"); + const int32_t inputs[2] = {0, 1}; + const int32_t outputs[1] = {2}; + flatbuffers::Offset<Operator> op = CreateOperator( + builder, /*opcode_index=*/0, builder.CreateVector<int32_t>(inputs, 2), + builder.CreateVector<int32_t>(outputs, 1), BuiltinOptions_NONE, + /*builtin_options=*/0, + /*custom_options=*/0, tflite::CustomOptionsFormat_FLEXBUFFERS); + // Subgraph & Model. + flatbuffers::Offset<SubGraph> subgraph = + CreateSubGraph(builder, builder.CreateVector(tensors, 3), + builder.CreateVector<int32_t>(inputs, 2), + builder.CreateVector<int32_t>(outputs, 1), + builder.CreateVector(&op, 1), /*name=*/0); + flatbuffers::Offset<Buffer> buffers[1] = { + CreateBuffer(builder, builder.CreateVector({})), + }; + flatbuffers::Offset<Model> model_buffer = CreateModel( + builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&op_code, 1), + builder.CreateVector(&subgraph, 1), builder.CreateString("test_model"), + builder.CreateVector(buffers, 1)); + builder.Finish(model_buffer); + std::vector<char> buffer = + std::vector<char>(builder.GetBufferPointer(), + builder.GetBufferPointer() + builder.GetSize()); + const Model* model = GetModel(buffer.data()); + + // Build an interpreter with the model. Initialization should work fine. + std::unique_ptr<Interpreter> interpreter; + ASSERT_EQ( + InterpreterBuilder( + model, ::tflite::ops::builtin::BuiltinOpResolver())(&interpreter), + kTfLiteOk); + // AllocateTensors should fail, since my_add hasn't been resolved. + ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteError); + + // Applying static delegate won't work, since the interpreter will first try + // to Prepare all original nodes. + std::unique_ptr<SimpleDelegate> static_delegate(new SimpleDelegate({0})); + ASSERT_EQ(interpreter->ModifyGraphWithDelegate( + static_delegate->get_tf_lite_delegate()), + kTfLiteError); + + // Applying delegate that supports dynamic tensors should work. + std::unique_ptr<SimpleDelegate> dynamic_delegate( + new SimpleDelegate({0}, kTfLiteDelegateFlagsAllowDynamicTensors)); + ASSERT_EQ(interpreter->ModifyGraphWithDelegate( + dynamic_delegate->get_tf_lite_delegate()), + kTfLiteOk); + // AllocateTensors will now work. + ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk); +} + +class TestDelegateWithDynamicTensors : public ::testing::Test { + protected: + void SetUp() override { + interpreter_.reset(new Interpreter); + + interpreter_->AddTensors(2); + interpreter_->SetInputs({0}); + interpreter_->SetOutputs({1}); + TfLiteQuantizationParams quant; + interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, + quant); + interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, + quant); + TfLiteRegistration reg = DynamicCopyOpRegistration(); + interpreter_->AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®); + + delegate_.Prepare = [](TfLiteContext* context, + TfLiteDelegate* delegate) -> TfLiteStatus { + // In this test, the delegate replaces all the nodes if this function is + // called. + TfLiteIntArray* execution_plan; + TF_LITE_ENSURE_STATUS( + context->GetExecutionPlan(context, &execution_plan)); + context->ReplaceNodeSubsetsWithDelegateKernels( + context, DelegateRegistration(), execution_plan, delegate); + return kTfLiteOk; + }; + delegate_.flags = kTfLiteDelegateFlagsNone; + } + + static TfLiteRegistration DynamicCopyOpRegistration() { + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + SetTensorToDynamic(output); + return kTfLiteOk; + }; + + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + // Not implemented since this isn't required in testing. + return kTfLiteOk; + }; + return reg; + } + + static TfLiteRegistration DelegateRegistration() { + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + return reg; + } + + std::unique_ptr<Interpreter> interpreter_; + TfLiteDelegate delegate_; +}; + +TEST_F(TestDelegateWithDynamicTensors, DisallowDynamicTensors) { + interpreter_->ModifyGraphWithDelegate(&delegate_); + + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + // The interpreter should not call delegate's `Prepare` when dynamic tensors + // exist. So the node ID isn't changed. + ASSERT_EQ(interpreter_->execution_plan()[0], 0); +} + +TEST_F(TestDelegateWithDynamicTensors, AllowDynamicTensors) { + delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors; + interpreter_->ModifyGraphWithDelegate(&delegate_); + + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + // The node should be replaced because dynamic tensors are allowed. Therefore + // only node ID in the execution plan is changed from 0 to 1. + ASSERT_EQ(interpreter_->execution_plan()[0], 1); +} + +TEST_F(TestDelegateWithDynamicTensors, ModifyGraphAfterAllocate) { + // Trigger allocation *before* delegate application. + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + + delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors; + ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(&delegate_), kTfLiteOk); + ASSERT_EQ(interpreter_->execution_plan().size(), 1); + ASSERT_EQ(interpreter_->execution_plan()[0], 1); + + // Allocation should still succeed. + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index 2581232bc2b..bb509610c7a 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -32,7 +32,11 @@ cc_library( linkopts = select({ "//tensorflow:android": [ "-lEGL", - "-lGLESv3", + # We don't need to link libGLESv3, because if it exists, + # it is a symlink to libGLESv2. + # See Compatibility Definition Document: + # https://source.android.com/compatibility/10/android-10-cdd#7_1_4_1_opengl_es + "-lGLESv2", ], "//conditions:default": [], }), @@ -76,6 +80,7 @@ objc_library( name = "metal_delegate", srcs = ["metal_delegate.mm"], hdrs = ["metal_delegate.h"], + module_name = "TensorFlowLiteCMetal", sdk_frameworks = ["Metal"], deps = [ "//tensorflow/lite:kernel_api", @@ -220,7 +225,11 @@ cc_library( linkopts = select({ "//tensorflow:android": [ "-lEGL", - "-lGLESv3", + # We don't need to link libGLESv3, because if it exists, + # it is a symlink to libGLESv2. + # See Compatibility Definition Document: + # https://source.android.com/compatibility/10/android-10-cdd#7_1_4_1_opengl_es + "-lGLESv2", ], "//conditions:default": [], }), diff --git a/tensorflow/lite/delegates/gpu/api.cc b/tensorflow/lite/delegates/gpu/api.cc index 6c299e4965c..cddd14b6855 100644 --- a/tensorflow/lite/delegates/gpu/api.cc +++ b/tensorflow/lite/delegates/gpu/api.cc @@ -31,6 +31,12 @@ struct ObjectTypeGetter { ObjectType operator()(OpenClTexture) const { return ObjectType::OPENCL_TEXTURE; } + ObjectType operator()(VulkanBuffer) const { + return ObjectType::VULKAN_BUFFER; + } + ObjectType operator()(VulkanTexture) const { + return ObjectType::VULKAN_TEXTURE; + } ObjectType operator()(CpuMemory) const { return ObjectType::CPU_MEMORY; } }; @@ -42,6 +48,8 @@ struct ObjectValidityChecker { } bool operator()(OpenClBuffer obj) const { return obj.memobj; } bool operator()(OpenClTexture obj) const { return obj.memobj; } + bool operator()(VulkanBuffer obj) const { return obj.memory; } + bool operator()(VulkanTexture obj) const { return obj.memory; } bool operator()(CpuMemory obj) const { return obj.data != nullptr && obj.size_bytes > 0 && (data_type == DataType::UNKNOWN || @@ -72,15 +80,19 @@ bool IsValid(const TensorObjectDef& def, const TensorObject& object) { bool IsObjectPresent(ObjectType type, const TensorObject& obj) { switch (type) { case ObjectType::CPU_MEMORY: - return absl::get_if<CpuMemory>(&obj); + return absl::holds_alternative<CpuMemory>(obj); case ObjectType::OPENGL_SSBO: - return absl::get_if<OpenGlBuffer>(&obj); + return absl::holds_alternative<OpenGlBuffer>(obj); case ObjectType::OPENGL_TEXTURE: - return absl::get_if<OpenGlTexture>(&obj); + return absl::holds_alternative<OpenGlTexture>(obj); case ObjectType::OPENCL_BUFFER: - return absl::get_if<OpenClBuffer>(&obj); + return absl::holds_alternative<OpenClBuffer>(obj); case ObjectType::OPENCL_TEXTURE: - return absl::get_if<OpenClTexture>(&obj); + return absl::holds_alternative<OpenClTexture>(obj); + case ObjectType::VULKAN_BUFFER: + return absl::holds_alternative<VulkanBuffer>(obj); + case ObjectType::VULKAN_TEXTURE: + return absl::holds_alternative<VulkanTexture>(obj); case ObjectType::UNKNOWN: return false; } diff --git a/tensorflow/lite/delegates/gpu/api.h b/tensorflow/lite/delegates/gpu/api.h index 2a531f1f81b..1dfeeebd700 100644 --- a/tensorflow/lite/delegates/gpu/api.h +++ b/tensorflow/lite/delegates/gpu/api.h @@ -71,6 +71,8 @@ enum class ObjectType { CPU_MEMORY, OPENCL_TEXTURE, OPENCL_BUFFER, + VULKAN_BUFFER, + VULKAN_TEXTURE }; struct OpenGlBuffer { @@ -104,11 +106,37 @@ struct OpenClTexture { // TODO(akulik): should it specify texture format? }; +struct VulkanBuffer { + VulkanBuffer() = default; + explicit VulkanBuffer(VkBuffer buffer_, VkDeviceSize size_, + VkDeviceMemory memory_, VkDeviceSize offset_) + : buffer(buffer_), size(size_), memory(memory_), offset(offset_) {} + + VkBuffer buffer; + VkDeviceSize size; + VkDeviceMemory memory; + VkDeviceSize offset; +}; + +struct VulkanTexture { + VulkanTexture() = default; + explicit VulkanTexture(VkDeviceMemory new_memory) : memory(new_memory) {} + + VkImage image; + VkImageView image_view; + VkFormat format; + VkExtent3D extent; + VkDeviceMemory memory; + VkDeviceSize offset; +}; + struct VulkanMemory { VulkanMemory() = default; explicit VulkanMemory(VkDeviceMemory new_memory) : memory(new_memory) {} VkDeviceMemory memory; + VkDeviceSize size; + VkDeviceSize offset; }; struct CpuMemory { @@ -195,8 +223,9 @@ bool IsValid(const TensorObjectDef& def); // @return the number of elements in a tensor object. uint32_t NumElements(const TensorObjectDef& def); -using TensorObject = absl::variant<absl::monostate, OpenGlBuffer, OpenGlTexture, - CpuMemory, OpenClBuffer, OpenClTexture>; +using TensorObject = + absl::variant<absl::monostate, OpenGlBuffer, OpenGlTexture, CpuMemory, + OpenClBuffer, OpenClTexture, VulkanBuffer, VulkanTexture>; // @return true if object is set and corresponding values are defined. bool IsValid(const TensorObjectDef& def, const TensorObject& object); diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index 2e686810767..c149479ae4c 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -38,6 +38,20 @@ cc_library( ], ) +cc_library( + name = "arguments", + srcs = ["arguments.cc"], + hdrs = ["arguments.h"], + deps = [ + ":opencl_wrapper", + ":util", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/common:util", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "buffer", srcs = ["buffer.cc"], diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.cc b/tensorflow/lite/delegates/gpu/cl/arguments.cc new file mode 100644 index 00000000000..26d9fc778b3 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/arguments.cc @@ -0,0 +1,173 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/gpu/cl/arguments.h" + +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" + +namespace tflite { +namespace gpu { +namespace cl { +namespace { +std::string GetNextWord(const std::string& code, size_t first_position) { + size_t pos = first_position; + char t = code[pos]; + while (absl::ascii_isalnum(t) || t == '_') { + pos++; + t = code[pos]; + } + return code.substr(first_position, pos - first_position); +} +} // namespace + +Arguments::Arguments(Arguments&& args) + : int_values_(std::move(args.int_values_)), + shared_int4s_data_(std::move(args.shared_int4s_data_)), + float_values_(std::move(args.float_values_)), + shared_float4s_data_(std::move(args.shared_float4s_data_)) {} +Arguments& Arguments::operator=(Arguments&& args) { + if (this != &args) { + int_values_ = std::move(args.int_values_); + shared_int4s_data_ = std::move(args.shared_int4s_data_); + float_values_ = std::move(args.float_values_); + shared_float4s_data_ = std::move(args.shared_float4s_data_); + } + return *this; +} + +void Arguments::AddFloat(const std::string& name, float value) { + float_values_[name].value = value; +} +void Arguments::AddInt(const std::string& name, int value) { + int_values_[name].value = value; +} + +absl::Status Arguments::SetInt(const std::string& name, int value) { + auto ii = int_values_.find(name); + if (ii == int_values_.end()) { + return absl::NotFoundError(absl::StrCat("No argument with name - ", name)); + } + ii->second.value = value; + if (ii->second.active) { + shared_int4s_data_[ii->second.offset] = value; + } + return absl::OkStatus(); +} + +absl::Status Arguments::SetFloat(const std::string& name, float value) { + auto fi = float_values_.find(name); + if (fi == float_values_.end()) { + return absl::NotFoundError(absl::StrCat("No argument with name - ", name)); + } + fi->second.value = value; + if (fi->second.active) { + shared_float4s_data_[fi->second.offset] = value; + } + return absl::OkStatus(); +} + +std::string Arguments::GetListOfArgs() { + std::string result; + for (int i = 0; i < shared_int4s_data_.size() / 4; ++i) { + absl::StrAppend(&result, ",\n int4 shared_int4_", i); + } + for (int i = 0; i < shared_float4s_data_.size() / 4; ++i) { + absl::StrAppend(&result, ",\n float4 shared_float4_", i); + } + return result; +} + +absl::Status Arguments::Bind(cl_kernel kernel, int offset) { + for (int i = 0; i < shared_int4s_data_.size() / 4; ++i) { + const int error_code = clSetKernelArg(kernel, offset, sizeof(int32_t) * 4, + &shared_int4s_data_[i * 4]); + if (error_code != CL_SUCCESS) { + return absl::UnknownError(absl::StrCat( + "Failed to set kernel arguments - ", CLErrorCodeToString(error_code), + "(at index - ", offset, ")")); + } + offset++; + } + for (int i = 0; i < shared_float4s_data_.size() / 4; ++i) { + const int error_code = clSetKernelArg(kernel, offset, sizeof(int32_t) * 4, + &shared_float4s_data_[i * 4]); + if (error_code != CL_SUCCESS) { + return absl::UnknownError(absl::StrCat( + "Failed to set kernel arguments - ", CLErrorCodeToString(error_code), + "(at index - ", offset, ")")); + } + offset++; + } + return absl::OkStatus(); +} + +std::string Arguments::AddActiveArgument(const std::string& arg_name) { + if (auto it = int_values_.find(arg_name); it != int_values_.end()) { + int int_index; + if (it->second.active) { + int_index = it->second.offset; + } else { + it->second.active = true; + it->second.offset = shared_int4s_data_.size(); + int_index = it->second.offset; + shared_int4s_data_.push_back(it->second.value); + } + std::string index = std::to_string(int_index / 4); + std::string postfixes[4] = {"x", "y", "z", "w"}; + return "shared_int4_" + index + "." + postfixes[int_index % 4]; + } + if (auto it = float_values_.find(arg_name); it != float_values_.end()) { + int float_index; + if (it->second.active) { + float_index = it->second.offset; + } else { + it->second.active = true; + it->second.offset = shared_float4s_data_.size(); + float_index = it->second.offset; + shared_float4s_data_.push_back(it->second.value); + } + std::string index = std::to_string(float_index / 4); + std::string postfixes[4] = {"x", "y", "z", "w"}; + return "shared_float4_" + index + "." + postfixes[float_index % 4]; + } + return arg_name; +} + +void Arguments::ResolveArgsPass(std::string* code) { + std::string result; + constexpr char kPrefix[] = "args."; + size_t position = 0; + size_t next_position = code->find(kPrefix); + while (next_position != std::string::npos) { + size_t arg_pos = next_position; + next_position += strlen(kPrefix); + std::string object_name = GetNextWord(*code, next_position); + std::string new_name = AddActiveArgument(object_name); + code->replace(arg_pos, object_name.size() + strlen(kPrefix), new_name); + position = arg_pos + new_name.size(); + next_position = code->find(kPrefix, position); + } + + int shared_int4s_aligned_size = AlignByN(shared_int4s_data_.size(), 4); + shared_int4s_data_.resize(shared_int4s_aligned_size); + int shared_float4s_aligned_size = AlignByN(shared_float4s_data_.size(), 4); + shared_float4s_data_.resize(shared_float4s_aligned_size); +} + +} // namespace cl +} // namespace gpu +} // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/cl/arguments.h b/tensorflow/lite/delegates/gpu/cl/arguments.h new file mode 100644 index 00000000000..274532d0199 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/arguments.h @@ -0,0 +1,88 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_ARGUMENTS_H_ +#define TENSORFLOW_LITE_DELEGATES_GPU_CL_ARGUMENTS_H_ + +#include <map> +#include <string> +#include <vector> + +#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h" +#include "tensorflow/lite/delegates/gpu/cl/util.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/common/util.h" + +namespace tflite { +namespace gpu { +namespace cl { + +class Arguments { + public: + Arguments() = default; + void AddFloat(const std::string& name, float value = 0.0f); + void AddInt(const std::string& name, int value = 0); + + absl::Status SetInt(const std::string& name, int value); + absl::Status SetFloat(const std::string& name, float value); + + std::string GetListOfArgs(); + + absl::Status Bind(cl_kernel kernel, int offset); + + void ResolveArgsPass(std::string* code); + + // Move only + Arguments(Arguments&& args); + Arguments& operator=(Arguments&& args); + Arguments(const Arguments&) = delete; + Arguments& operator=(const Arguments&) = delete; + + private: + std::string AddActiveArgument(const std::string& arg_name); + + struct IntValue { + int value; + + // many uniforms generated automatically and not used + // to reduce amount of data transferred we adding this optimization + bool active = false; + + // offset to shared uniform storage. + uint32_t offset = -1; + }; + std::map<std::string, IntValue> int_values_; + std::vector<int32_t> shared_int4s_data_; + + struct FloatValue { + float value; + + // many uniforms generated automatically and not used + // to reduce amount of data transferred we adding this optimization + bool active = false; + + // offset to shared uniform storage. + uint32_t offset = -1; + }; + std::map<std::string, FloatValue> float_values_; + std::vector<float> shared_float4s_data_; +}; + +} // namespace cl +} // namespace gpu +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_ARGUMENTS_H_ diff --git a/tensorflow/lite/delegates/gpu/cl/cl_kernel.h b/tensorflow/lite/delegates/gpu/cl/cl_kernel.h index b575684d2b4..be9dc6dbf03 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_kernel.h +++ b/tensorflow/lite/delegates/gpu/cl/cl_kernel.h @@ -65,6 +65,7 @@ class CLKernel { int GetPrivateMemorySize() const { return private_memory_size_; } int GetMaxWorkGroupSize() const { return max_work_group_size_; } + int GetBindingCounter() const { return binding_counter_; } void ResetBindingCounter() { binding_counter_ = 0; } // Do not use this function diff --git a/tensorflow/lite/delegates/gpu/cl/egl_sync.cc b/tensorflow/lite/delegates/gpu/cl/egl_sync.cc index ddc373bce31..f50bc75b8be 100644 --- a/tensorflow/lite/delegates/gpu/cl/egl_sync.cc +++ b/tensorflow/lite/delegates/gpu/cl/egl_sync.cc @@ -22,8 +22,15 @@ namespace gpu { namespace cl { absl::Status EglSync::NewFence(EGLDisplay display, EglSync* sync) { + static auto* egl_create_sync_khr = + reinterpret_cast<decltype(&eglCreateSyncKHR)>( + eglGetProcAddress("eglCreateSyncKHR")); + if (egl_create_sync_khr == nullptr) { + // Needs extension: EGL_KHR_fence_sync (EGL) / GL_OES_EGL_sync (OpenGL ES). + return absl::InternalError("Not supported: eglCreateSyncKHR."); + } EGLSyncKHR egl_sync; - RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(eglCreateSyncKHR, &egl_sync, display, + RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(*egl_create_sync_khr, &egl_sync, display, EGL_SYNC_FENCE_KHR, nullptr)); if (egl_sync == EGL_NO_SYNC_KHR) { return absl::InternalError("Returned empty KHR EGL sync"); @@ -43,25 +50,46 @@ EglSync& EglSync::operator=(EglSync&& sync) { void EglSync::Invalidate() { if (sync_ != EGL_NO_SYNC_KHR) { - eglDestroySyncKHR(display_, sync_); + static auto* egl_destroy_sync_khr = + reinterpret_cast<decltype(&eglDestroySyncKHR)>( + eglGetProcAddress("eglDestroySyncKHR")); + // Needs extension: EGL_KHR_fence_sync (EGL) / GL_OES_EGL_sync (OpenGL ES). + if (egl_destroy_sync_khr) { + // Note: we're doing nothing when the function pointer is nullptr, or the + // call returns EGL_FALSE. + (*egl_destroy_sync_khr)(display_, sync_); + } sync_ = EGL_NO_SYNC_KHR; } } absl::Status EglSync::ServerWait() { + static auto* egl_wait_sync_khr = reinterpret_cast<decltype(&eglWaitSyncKHR)>( + eglGetProcAddress("eglWaitSyncKHR")); + if (egl_wait_sync_khr == nullptr) { + // Needs extension: EGL_KHR_wait_sync + return absl::InternalError("Not supported: eglWaitSyncKHR."); + } EGLint result; RETURN_IF_ERROR( - TFLITE_GPU_CALL_EGL(eglWaitSyncKHR, &result, display_, sync_, 0)); + TFLITE_GPU_CALL_EGL(*egl_wait_sync_khr, &result, display_, sync_, 0)); return result == EGL_TRUE ? absl::OkStatus() : absl::InternalError("eglWaitSync failed"); } absl::Status EglSync::ClientWait() { + static auto* egl_client_wait_sync_khr = + reinterpret_cast<decltype(&eglClientWaitSyncKHR)>( + eglGetProcAddress("eglClientWaitSyncKHR")); + if (egl_client_wait_sync_khr == nullptr) { + // Needs extension: EGL_KHR_fence_sync (EGL) / GL_OES_EGL_sync (OpenGL ES). + return absl::InternalError("Not supported: eglClientWaitSyncKHR."); + } EGLint result; // TODO(akulik): make it active wait for better performance - RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(eglClientWaitSyncKHR, &result, display_, - sync_, EGL_SYNC_FLUSH_COMMANDS_BIT_KHR, - EGL_FOREVER_KHR)); + RETURN_IF_ERROR( + TFLITE_GPU_CALL_EGL(*egl_client_wait_sync_khr, &result, display_, sync_, + EGL_SYNC_FLUSH_COMMANDS_BIT_KHR, EGL_FOREVER_KHR)); return result == EGL_CONDITION_SATISFIED_KHR ? absl::OkStatus() : absl::InternalError("eglClientWaitSync failed"); diff --git a/tensorflow/lite/delegates/gpu/cl/gl_interop.cc b/tensorflow/lite/delegates/gpu/cl/gl_interop.cc index eaeff2cda07..599e6766301 100644 --- a/tensorflow/lite/delegates/gpu/cl/gl_interop.cc +++ b/tensorflow/lite/delegates/gpu/cl/gl_interop.cc @@ -273,7 +273,7 @@ GlClBufferCopier::GlClBufferCopier(const TensorObjectDef& input_def, absl::Status GlClBufferCopier::Convert(const TensorObject& input_obj, const TensorObject& output_obj) { - if (absl::get_if<OpenGlBuffer>(&input_obj)) { + if (absl::holds_alternative<OpenGlBuffer>(input_obj)) { auto ssbo = absl::get_if<OpenGlBuffer>(&input_obj); auto cl_mem = absl::get_if<OpenClBuffer>(&output_obj); RETURN_IF_ERROR( diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD index ff6f06eeb68..b5510b3e8df 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/kernels/BUILD @@ -1290,8 +1290,10 @@ cc_library( ":gpu_operation", ":util", ":work_group_picking", + "//tensorflow/lite/delegates/gpu/cl:arguments", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:types", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add_test.cc index 2adb6a20bc4..444a380c2e9 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/multiply_add_test.cc @@ -38,7 +38,7 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMul) { src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f}; MultiplyAttributes attr; - Tensor<Linear, DataType::FLOAT32> parameters; + ::tflite::gpu::Tensor<Linear, DataType::FLOAT32> parameters; parameters.shape = Linear(2); parameters.data = {0.5f, 2.0f}; attr.param = parameters; @@ -68,7 +68,7 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorAdd) { src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f}; AddAttributes attr; - Tensor<Linear, DataType::FLOAT32> parameters; + ::tflite::gpu::Tensor<Linear, DataType::FLOAT32> parameters; parameters.shape = Linear(2); parameters.data = {0.5f, 2.0f}; attr.param = parameters; @@ -152,7 +152,7 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMad) { src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f}; MultiplyAttributes mul_attr; - Tensor<Linear, DataType::FLOAT32> parameters; + ::tflite::gpu::Tensor<Linear, DataType::FLOAT32> parameters; parameters.shape = Linear(2); parameters.data = {0.5f, 2.0f}; mul_attr.param = parameters; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/prelu_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/prelu_test.cc index 01b603b5961..4b0006c7f32 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/prelu_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/prelu_test.cc @@ -37,7 +37,7 @@ TEST_F(OpenCLOperationTest, PReLUAlpha) { src_tensor.data = {0.0f, -1.0f, -2.0f, 3.0f}; PReLUAttributes attr; - Tensor<Linear, DataType::FLOAT32> parameters; + ::tflite::gpu::Tensor<Linear, DataType::FLOAT32> parameters; parameters.shape = Linear(2); parameters.data = {0.5f, -2.0f}; attr.alpha = parameters; @@ -68,7 +68,7 @@ TEST_F(OpenCLOperationTest, PReLUAlphaClip) { src_tensor.data = {0.0f, -1.0f, -2.0f, 3.0f}; PReLUAttributes attr; - Tensor<Linear, DataType::FLOAT32> parameters; + ::tflite::gpu::Tensor<Linear, DataType::FLOAT32> parameters; parameters.shape = Linear(2); parameters.data = {0.5f, -2.0f}; attr.alpha = parameters; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc index 66a272fa2da..fc3efe32c3b 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.cc @@ -17,6 +17,8 @@ limitations under the License. #include <string> +#include "absl/strings/substitute.h" +#include "tensorflow/lite/delegates/gpu/cl/arguments.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/util.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h" @@ -27,37 +29,45 @@ namespace { std::string GetTransposeCode( const OperationDef& op_def, const TransposeAttributes& attr, - const std::vector<ElementwiseOperation*>& linked_operations) { - TensorCodeGenerator src_tensor( - "src_data", - WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"}, - op_def.src_tensors[0]); - TensorCodeGenerator dst_tensor( - "dst_data", - WHSBPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"}, - op_def.dst_tensors[0]); + const std::vector<ElementwiseOperation*>& linked_operations, + Arguments* args) { + TensorCodeGenerator src_tensor("src_data", + WHSBPoint{"args.src_width", "args.src_height", + "args.src_slices", "args.src_batch"}, + op_def.src_tensors[0]); + TensorCodeGenerator dst_tensor("dst_data", + WHSBPoint{"args.dst_width", "args.dst_height", + "args.dst_slices", "args.dst_batch"}, + op_def.dst_tensors[0]); + + args->AddInt("src_width"); + args->AddInt("src_height"); + args->AddInt("src_slices"); + args->AddInt("src_batch"); + args->AddInt("dst_width"); + args->AddInt("dst_height"); + args->AddInt("dst_slices"); + args->AddInt("dst_batch"); + args->AddInt("dst_channels"); const std::string batch_id = op_def.IsBatchSupported() ? "B" : ""; std::string c = GetCommonDefines(op_def.precision); c += "__kernel void main_function(\n"; c += src_tensor.GetDeclaration(AccessType::READ); c += GetArgsDeclaration(linked_operations); - c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n"; - c += " int4 src_size, \n"; - c += " int4 dst_size, \n"; - c += " int src_channels, \n"; - c += " int dst_channels \n"; - c += ") {\n"; + c += dst_tensor.GetDeclaration(AccessType::WRITE); + c += "$0) {\n"; if (op_def.IsBatchSupported()) { c += " int linear_id = get_global_id(0);\n"; - c += " int X = linear_id / dst_size.w;\n"; - c += " int B = linear_id % dst_size.w;\n"; + c += " int X = linear_id / args.dst_batch;\n"; + c += " int B = linear_id % args.dst_batch;\n"; } else { c += " int X = get_global_id(0);\n"; } c += " int Y = get_global_id(1);\n"; c += " int Z = get_global_id(2);\n"; - c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) { \n"; + c += " if (X >= args.dst_width || Y >= args.dst_height || Z >= " + "args.dst_slices) { \n"; c += " return; \n"; c += " } \n"; c += " FLT temps[4];\n"; @@ -83,7 +93,7 @@ std::string GetTransposeCode( } else { c += " for (int i = 0; i < 4; ++i) {\n"; c += " int dst_channel = Z * 4 + i;\n"; - c += " if (dst_channel < dst_channels) {;\n"; + c += " if (dst_channel < args.dst_channels) {;\n"; const std::string bhwc[] = {"B", "Y", "X", "dst_channel"}; std::string src_b = op_def.IsBatchSupported() ? bhwc[remap[0]] : ""; c += " int s_y = " + bhwc[remap[1]] + ";\n"; @@ -100,24 +110,27 @@ std::string GetTransposeCode( } c += " FLT4 result = (FLT4)(temps[0], temps[1], temps[2], temps[3]);\n"; std::string x_3dcoord = - op_def.IsBatchSupported() ? "X * dst_size.w + B" : "X"; + op_def.IsBatchSupported() ? "X * args.dst_batch + B" : "X"; const LinkingContext context{"result", x_3dcoord, "Y", "Z"}; c += PostProcess(linked_operations, context); c += " " + dst_tensor.WriteWHSB("result", "X", "Y", "Z", batch_id); c += "}\n"; - return c; + args->ResolveArgsPass(&c); + return absl::Substitute(c, args->GetListOfArgs()); } } // namespace Transpose::Transpose(Transpose&& operation) : GPUOperation(std::move(operation)), attr_(operation.attr_), + args_(std::move(operation.args_)), kernel_(std::move(operation.kernel_)), work_group_size_(operation.work_group_size_) {} Transpose& Transpose::operator=(Transpose&& operation) { if (this != &operation) { attr_ = operation.attr_; + args_ = std::move(operation.args_); kernel_ = std::move(operation.kernel_); std::swap(work_group_size_, operation.work_group_size_); GPUOperation::operator=(std::move(operation)); @@ -126,21 +139,28 @@ Transpose& Transpose::operator=(Transpose&& operation) { } absl::Status Transpose::Compile(const CreationContext& creation_context) { - const auto code = GetTransposeCode(definition_, attr_, linked_operations_); + const auto code = + GetTransposeCode(definition_, attr_, linked_operations_, &args_); return creation_context.cache->GetOrCreateCLKernel( code, "main_function", *creation_context.context, *creation_context.device, &kernel_); } absl::Status Transpose::BindArguments() { + RETURN_IF_ERROR(args_.SetInt("src_width", src_[0]->Width())); + RETURN_IF_ERROR(args_.SetInt("src_height", src_[0]->Height())); + RETURN_IF_ERROR(args_.SetInt("src_slices", src_[0]->Slices())); + RETURN_IF_ERROR(args_.SetInt("src_batch", src_[0]->Batch())); + RETURN_IF_ERROR(args_.SetInt("dst_width", dst_[0]->Width())); + RETURN_IF_ERROR(args_.SetInt("dst_height", dst_[0]->Height())); + RETURN_IF_ERROR(args_.SetInt("dst_slices", dst_[0]->Slices())); + RETURN_IF_ERROR(args_.SetInt("dst_batch", dst_[0]->Batch())); + RETURN_IF_ERROR(args_.SetInt("dst_channels", dst_[0]->Channels())); kernel_.ResetBindingCounter(); RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Channels())); - RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Channels())); + RETURN_IF_ERROR(args_.Bind(kernel_.kernel(), kernel_.GetBindingCounter())); return absl::OkStatus(); } diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h index 61038b1e0ca..13f06281012 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h +++ b/tensorflow/lite/delegates/gpu/cl/kernels/transpose.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_TRANSPOSE_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_TRANSPOSE_H_ +#include "tensorflow/lite/delegates/gpu/cl/arguments.h" #include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h" #include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/types.h" @@ -43,6 +44,7 @@ class Transpose : public GPUOperation { int3 GetGridSize() const; TransposeAttributes attr_; + Arguments args_; CLKernel kernel_; int3 work_group_size_; }; diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/winograd_test.cc b/tensorflow/lite/delegates/gpu/cl/kernels/winograd_test.cc index aff64dd48f3..1dada33ae04 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/winograd_test.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/winograd_test.cc @@ -111,7 +111,7 @@ TEST_F(OpenCLOperationTest, Winograd36To4x4) { src_tensor.data[i] = sin(i); } - Tensor<Linear, DataType::FLOAT32> biases; + ::tflite::gpu::Tensor<Linear, DataType::FLOAT32> biases; biases.shape = Linear(1); biases.data.resize(biases.shape.DimensionsProduct()); for (int i = 0; i < biases.data.size(); ++i) { diff --git a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc index be551bc9973..bdaa807d83c 100644 --- a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc +++ b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.cc @@ -17,6 +17,8 @@ limitations under the License. #include <dlfcn.h> +#include <string> + #include "absl/strings/str_cat.h" #include "tensorflow/lite/delegates/gpu/common/status.h" @@ -24,42 +26,51 @@ namespace tflite { namespace gpu { namespace cl { +#ifdef __ANDROID__ #define LoadFunction(function) \ if (is_pixel) { \ function = reinterpret_cast<PFN_##function>(loadOpenCLPointer(#function)); \ } else { \ function = reinterpret_cast<PFN_##function>(dlsym(libopencl, #function)); \ } +#else +#define LoadFunction(function) \ + function = reinterpret_cast<PFN_##function>(dlsym(libopencl, #function)); +#endif absl::Status LoadOpenCL() { void* libopencl = dlopen("libOpenCL.so", RTLD_NOW | RTLD_LOCAL); if (libopencl) { LoadOpenCLFunctions(libopencl, false); return absl::OkStatus(); - } else { - // Pixel phone? - libopencl = dlopen("libOpenCL-pixel.so", RTLD_NOW | RTLD_LOCAL); - if (libopencl) { - typedef void (*enableOpenCL_t)(); - enableOpenCL_t enableOpenCL = - reinterpret_cast<enableOpenCL_t>(dlsym(libopencl, "enableOpenCL")); - enableOpenCL(); - LoadOpenCLFunctions(libopencl, true); - return absl::OkStatus(); - } else { - return absl::UnknownError( - absl::StrCat("OpenCL library not loaded - ", dlerror())); - } } + // record error + std::string error(dlerror()); +#ifdef __ANDROID__ + // Pixel phone? + libopencl = dlopen("libOpenCL-pixel.so", RTLD_NOW | RTLD_LOCAL); + if (libopencl) { + typedef void (*enableOpenCL_t)(); + enableOpenCL_t enableOpenCL = + reinterpret_cast<enableOpenCL_t>(dlsym(libopencl, "enableOpenCL")); + enableOpenCL(); + LoadOpenCLFunctions(libopencl, true); + return absl::OkStatus(); + } +#endif + return absl::UnknownError( + absl::StrCat("Can not open OpenCL library on this device - ", error)); } void LoadOpenCLFunctions(void* libopencl, bool is_pixel) { +#ifdef __ANDROID__ typedef void* (*loadOpenCLPointer_t)(const char* name); loadOpenCLPointer_t loadOpenCLPointer; if (is_pixel) { loadOpenCLPointer = reinterpret_cast<loadOpenCLPointer_t>( dlsym(libopencl, "loadOpenCLPointer")); } +#endif LoadFunction(clGetPlatformIDs); LoadFunction(clGetPlatformInfo); diff --git a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc index 2a04a04460d..12a1d726368 100644 --- a/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc +++ b/tensorflow/lite/delegates/gpu/cl/selectors/fully_connected_selector.cc @@ -27,6 +27,22 @@ namespace tflite { namespace gpu { namespace cl { +absl::Status SelectFullyConnectedGeneric( + const FullyConnectedAttributes& attr, + const CreationContext& creation_context, const OperationDef& op_def, + int batch_size, std::unique_ptr<GPUOperation>* ptr) { + if (op_def.IsBatchSupported()) { + ConvTexture conv; + RETURN_IF_ERROR(CreateConvTexture(creation_context, op_def, attr, &conv)); + *ptr = absl::make_unique<ConvTexture>(std::move(conv)); + } else { + FullyConnected fc; + RETURN_IF_ERROR(CreateFullyConnected(creation_context, op_def, attr, &fc)); + *ptr = absl::make_unique<FullyConnected>(std::move(fc)); + } + return absl::OkStatus(); +} + absl::Status SelectFullyConnectedAdreno(const FullyConnectedAttributes& attr, const CreationContext& creation_context, const OperationDef& op_def, @@ -38,8 +54,7 @@ absl::Status SelectFullyConnectedAdreno(const FullyConnectedAttributes& attr, *ptr = absl::make_unique<ConvTexture>(std::move(conv)); } else { FullyConnected fc; - RETURN_IF_ERROR( - CreateFullyConnected(creation_context, op_def, attr, &fc)); + RETURN_IF_ERROR(CreateFullyConnected(creation_context, op_def, attr, &fc)); *ptr = absl::make_unique<FullyConnected>(std::move(fc)); } return absl::OkStatus(); @@ -55,8 +70,7 @@ absl::Status SelectFullyConnectedPowerVR( *ptr = absl::make_unique<ConvPowerVR>(std::move(conv)); } else { FullyConnected fc; - RETURN_IF_ERROR( - CreateFullyConnected(creation_context, op_def, attr, &fc)); + RETURN_IF_ERROR(CreateFullyConnected(creation_context, op_def, attr, &fc)); *ptr = absl::make_unique<FullyConnected>(std::move(fc)); } return absl::OkStatus(); @@ -80,8 +94,7 @@ absl::Status SelectFullyConnectedMali(const FullyConnectedAttributes& attr, } } else { FullyConnected fc; - RETURN_IF_ERROR( - CreateFullyConnected(creation_context, op_def, attr, &fc)); + RETURN_IF_ERROR(CreateFullyConnected(creation_context, op_def, attr, &fc)); *ptr = absl::make_unique<FullyConnected>(std::move(fc)); } return absl::OkStatus(); @@ -102,8 +115,8 @@ absl::Status SelectFullyConnected(const FullyConnectedAttributes& attr, return SelectFullyConnectedMali(attr, creation_context, op_def, batch_size, ptr); default: - return SelectFullyConnectedAdreno(attr, creation_context, op_def, - batch_size, ptr); + return SelectFullyConnectedGeneric(attr, creation_context, op_def, + batch_size, ptr); } } diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD index 94d79182a92..b7120605902 100644 --- a/tensorflow/lite/delegates/gpu/common/BUILD +++ b/tensorflow/lite/delegates/gpu/common/BUILD @@ -116,6 +116,7 @@ cc_library( ":status", ":tensor", "@com_google_absl//absl/strings", + "//tensorflow/lite/delegates:utils", "//tensorflow/lite:context", "//tensorflow/lite:kernel_api", "//tensorflow/lite:util", diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc index 46856a70a7c..daedc277869 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" #include "tensorflow/lite/delegates/gpu/common/transformations/general_transformations.h" +#include "tensorflow/lite/delegates/utils.h" #include "tensorflow/lite/kernels/internal/reference/dequantize.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -1347,6 +1348,17 @@ class PadOperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1, /*outputs=*/1)); RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1)); + auto pad_tensor = tflite::GetInput(context, tflite_node, 1); + if (pad_tensor->dims->size != 2) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid paddings tensor dimension: expected 2 dim, got ", + pad_tensor->dims->size, " dim")); + } + if (pad_tensor->dims->data[0] != 4 || pad_tensor->dims->data[1] != 2) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid paddings tensor shape: expected 4x2, got ", + pad_tensor->dims->data[0], "x", pad_tensor->dims->data[1])); + } return absl::OkStatus(); } @@ -1370,6 +1382,7 @@ class PadOperationParser : public TFLiteOperationParser { // 4x2 tensor with paddings. if (paddings.shape.h != 4 || paddings.shape.w != 2) { + // It shouldn't fail here since it's checked at IsSupported(). return absl::InvalidArgumentError( "Paddings tensor has unexpected shape."); } @@ -2350,7 +2363,7 @@ class TransformTensorOperationParser : public TFLiteOperationParser { private: }; -class TransformTensorV2OperationParser : public TFLiteOperationParser { +class TransformTensorBilinearV2OperationParser : public TFLiteOperationParser { public: absl::Status IsSupported(const TfLiteContext* context, const TfLiteNode* tflite_node, @@ -2368,7 +2381,7 @@ class TransformTensorV2OperationParser : public TFLiteOperationParser { RETURN_IF_ERROR(reader->AddInput(node, 1)); // bbox RETURN_IF_ERROR(reader->AddOutputs(node)); - std::string op_name = "transform_tensor_v2"; + std::string op_name = "transform_tensor_bilinear_v2"; node->operation.type = op_name; BHWC output_shape; RETURN_IF_ERROR( @@ -2731,8 +2744,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser( if (custom_name == "TransformTensor") { return std::make_unique<TransformTensorOperationParser>(); } - if (custom_name == "TransformTensorV2") { - return std::make_unique<TransformTensorV2OperationParser>(); + if (custom_name == "TransformTensorBilinearV2") { + return std::make_unique<TransformTensorBilinearV2OperationParser>(); } if (custom_name == "TransformLandmarks") { return std::make_unique<TransformLandmarksOperationParser>(); @@ -2762,10 +2775,13 @@ absl::Status IsSupported(const TfLiteContext* context, TfLiteNode* node, ->IsSupported(context, node, registration); } -bool IsAllAllowedTensors(TfLiteContext* context, const TfLiteIntArray* array, +bool IsAllAllowedTensors(TfLiteContext* context, + const TfLiteIntArray* tensor_indices, bool allow_quant_ops = false) { - for (int i = 0; i < array->size; ++i) { - const TfLiteTensor* t = context->tensors + array->data[i]; + for (int i = 0; i < tensor_indices->size; ++i) { + int tensor_idx = tensor_indices->data[i]; + if (tensor_idx == kTfLiteOptionalTensor) continue; + const TfLiteTensor* t = &context->tensors[tensor_idx]; bool type_supported = (t->type == kTfLiteFloat32 || t->type == kTfLiteFloat16); if (allow_quant_ops) { @@ -2809,7 +2825,8 @@ TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops, return true; }; - GraphWithDequantPartitionHelper partition_helper(context, node_supported_fn); + delegates::FP16GraphPartitionHelper partition_helper(context, + node_supported_fn); std::set<std::string> unsupported_nodes_info; if (partition_helper.Partition(&unsupported_nodes_info) != kTfLiteOk) { return TfLiteIntArrayCreate(0); diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc index 65e2b6f0d47..4973a8179cd 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc +++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.cc @@ -15,9 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h" -#include <set> #include <string> -#include <unordered_map> #include <fp16.h> #include "absl/strings/str_cat.h" @@ -33,157 +31,6 @@ limitations under the License. namespace tflite { namespace gpu { -TfLiteStatus GraphWithDequantPartitionHelper::Partition( - std::set<std::string>* unsupported_nodes_info) { - const auto status = GraphPartitionHelper::Partition(unsupported_nodes_info); - // Clean up those partitions that have a single dequant op. NoteThose - // removed dequant ops have to be reserved in the graph and should not be - // delegated. - RemoveSingleDequantNodePartitions(); - return status; -} - -std::vector<int> -GraphWithDequantPartitionHelper::GetNodesOfFirstNLargestPartitions(int n) { - // We first get partitions to reduce the number of nodes to be checked in - // deciding which dequant ops could actually be replaced. And then we - // remap input-tensor to dequant nodes' inputs and remove those - // to-be-reserved dequant nodes. - auto first_nps = GetFirstNLargestPartitions(n); - std::vector<int> ops_to_replace; - for (const auto p : first_nps) { - auto nodes = p->nodes_to_replace; - ops_to_replace.insert(ops_to_replace.end(), nodes->data, - nodes->data + nodes->size); - } - RemapInputTensors(ops_to_replace); - RemoveReservedDequantsFromNodes(&ops_to_replace); - return ops_to_replace; -} - -bool GraphWithDequantPartitionHelper::IsNodeSupported( - TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration, - int node_id, std::string* unsupported_details) { - // If we need to handle dequant nodes, we have to remap input tensors of - // this node if some of them come from a dequant node before testing if - // the node is supported. - std::vector<int> orig_inputs; - if (RecordAndRemapInputTensors(registration->builtin_code, node_id, node, - &orig_inputs)) { - // We have a dequant op here. Note that we retrun an Ok status because a - // dequant node is first added as supported. Later, this dequant node - // will be removed if it has to be preserved in the graph which happens - // when its immediate downstream nodes cannot be supported. - return true; - } - const auto status = GraphPartitionHelper::IsNodeSupported( - context, node, registration, node_id, unsupported_details); - RestoreToOrigInputTensors(node, orig_inputs); - return status; -} - -bool GraphWithDequantPartitionHelper::RecordAndRemapInputTensors( - int32_t op_code, int node_id, TfLiteNode* node, - std::vector<int>* orig_inputs) { - orig_inputs->clear(); - // Record the dequant node. - if (op_code == kTfLiteBuiltinDequantize && - context_->tensors[node->inputs->data[0]].type == - TfLiteType::kTfLiteFloat16) { - dequant_nodes_[node->outputs->data[0]] = node->inputs->data[0]; - return true; - } - // For a dequantize op, there's no need to remap its input tensors. - if (dequant_nodes_.empty()) return false; - RemapInputTensors(node, orig_inputs); - return false; -} - -void GraphWithDequantPartitionHelper::RestoreToOrigInputTensors( - TfLiteNode* node, const std::vector<int>& orig_inputs) { - if (node->inputs->size != orig_inputs.size()) return; - for (int j = 0; j < node->inputs->size; ++j) { - node->inputs->data[j] = orig_inputs[j]; - } -} - -void GraphWithDequantPartitionHelper::RemapInputTensors( - const std::vector<int>& nodes) const { - for (int node_id : nodes) { - TfLiteNode* node; - TfLiteRegistration* registration; - GetNodeAndRegistration(context_, node_id, &node, ®istration) - .IgnoreError(); - RemapInputTensors(node, nullptr /* orig_inputs*/); - } -} - -void GraphWithDequantPartitionHelper::RemoveSingleDequantNodePartitions() { - auto it = partitions_.begin(); - while (it != partitions_.end()) { - auto p = *it; - if (p->nodes_to_replace->size != 1) { - ++it; - continue; - } - int node_id = p->nodes_to_replace->data[0]; - TfLiteNode* node = nullptr; - TfLiteRegistration* registration = nullptr; - GetNodeAndRegistration(context_, node_id, &node, ®istration) - .IgnoreError(); - if (registration->builtin_code != kTfLiteBuiltinDequantize || - context_->tensors[node->inputs->data[0]].type != - TfLiteType::kTfLiteFloat16) { - ++it; - continue; - } - // Note such dequant nodes have to be preserved in the graph as dequant - // ops are not actually supported in the GPU delegate. - dequant_nodes_to_save_.insert(node_id); - it = partitions_.erase(it); - } -} - -void GraphWithDequantPartitionHelper::RemoveReservedDequantsFromNodes( - std::vector<int>* nodes) { - if (dequant_nodes_to_save_.empty()) return; - auto it = nodes->begin(); - while (it != nodes->end()) { - if (dequant_nodes_to_save_.find(*it) == dequant_nodes_to_save_.end()) { - ++it; - continue; - } - it = nodes->erase(it); - } -} - -void GraphWithDequantPartitionHelper::RemapInputTensors( - TfLiteNode* node, std::vector<int>* orig_inputs) const { - TfLiteIntArray* inputs = node->inputs; - auto inputs_view = TfLiteIntArrayView(inputs); - // Prepopulate 'orig_inputs' first and clear it if there's no input from a - // dequant op. - if (orig_inputs) { - orig_inputs->clear(); - orig_inputs->reserve(inputs->size); - for (auto tid : inputs_view) { - orig_inputs->push_back(tid); - } - } - // Fix this node's inputs (i.e. prune out the preceding dequantize node) in - // order to test if it is supported. - bool is_remapped = false; - for (int j = 0; j < inputs->size; ++j) { - const int input_tid = inputs->data[j]; - const auto it = dequant_nodes_.find(input_tid); - if (it != dequant_nodes_.end()) { - inputs->data[j] = it->second; - is_remapped = true; - } - } - if (!is_remapped && orig_inputs) orig_inputs->clear(); -} - absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id, TfLiteNode** tflite_node, TfLiteRegistration** registration) { diff --git a/tensorflow/lite/delegates/gpu/common/model_builder_helper.h b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h index 54ae19e890a..9caa5630037 100644 --- a/tensorflow/lite/delegates/gpu/common/model_builder_helper.h +++ b/tensorflow/lite/delegates/gpu/common/model_builder_helper.h @@ -16,17 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_ #define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_ -#include <set> -#include <string> -#include <unordered_map> - #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" -#include "tensorflow/lite/delegates/utils.h" #include "tensorflow/lite/kernels/internal/reference/dequantize.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/types.h" @@ -35,61 +30,6 @@ limitations under the License. namespace tflite { namespace gpu { -class GraphWithDequantPartitionHelper : public delegates::GraphPartitionHelper { - public: - GraphWithDequantPartitionHelper( - TfLiteContext* context, delegates::IsNodeSupportedFn is_node_supported_fn) - : GraphPartitionHelper(context, std::move(is_node_supported_fn)) {} - - TfLiteStatus Partition( - std::set<std::string>* unsupported_nodes_info) override; - - // Returns a list of node indices of all nodes from the first n largest - // partitions. If there are fewer paritions than n, all nodes will be - // returned. The partition is ranked according to the number of nodes. - std::vector<int> GetNodesOfFirstNLargestPartitions(int n); - - protected: - bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node, - TfLiteRegistration* registration, int node_id, - std::string* unsupported_details) override; - - private: - // Record 'node' if it is a dequant op (i.e. a fp16 one here) and return true. - // When it's not a dequant op, remap its inputs to the inputs of the preceding - // dequant if there's a one and returns false. 'orig_inputs' records original - // input tensor ids of this node if any input is remapped. - bool RecordAndRemapInputTensors(int32_t op_code, int node_id, - TfLiteNode* node, - std::vector<int>* orig_inputs); - - // Restore inputs of 'node' to 'orig_inputs' only if two sizes match. - void RestoreToOrigInputTensors(TfLiteNode* node, - const std::vector<int>& orig_inputs); - - // Remap input tensors of every node in 'nodes' (i.e. node indices) if some of - // them are from dequant ops. - void RemapInputTensors(const std::vector<int>& nodes) const; - - void RemoveSingleDequantNodePartitions(); - - void RemoveReservedDequantsFromNodes(std::vector<int>* nodes); - - // Remap input tensors of a single 'node' if some of come from a dequant op. - // If 'orig_inputs' isn't nullptr, it records original input tensor ids of - // this node if any input is remapped. - void RemapInputTensors(TfLiteNode* node, std::vector<int>* orig_inputs) const; - - // A map recording dequantize nodes's input/output tensors of this selected - // graph. The key is the output tensor id, and the value is the input tensor - // id. - std::unordered_map<int, int> dequant_nodes_; - - // A set of dequant nodes as in node indices that have to be preserved in the - // graph. - std::set<int> dequant_nodes_to_save_; -}; - absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id, TfLiteNode** tflite_node, TfLiteRegistration** registration); diff --git a/tensorflow/lite/delegates/gpu/common/operations.cc b/tensorflow/lite/delegates/gpu/common/operations.cc index bdcf6f605cc..c3861ca2baa 100644 --- a/tensorflow/lite/delegates/gpu/common/operations.cc +++ b/tensorflow/lite/delegates/gpu/common/operations.cc @@ -499,6 +499,14 @@ BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr) { StridedSize(attr.ends.c - attr.starts.c, attr.strides.c)); } +BHWDC CalculateOutputShape(const BHWDC& input, const Slice3DAttributes& attr) { + return BHWDC(StridedSize(attr.ends.b - attr.starts.b, attr.strides.b), + StridedSize(attr.ends.h - attr.starts.h, attr.strides.h), + StridedSize(attr.ends.w - attr.starts.w, attr.strides.w), + StridedSize(attr.ends.d - attr.starts.d, attr.strides.d), + StridedSize(attr.ends.c - attr.starts.c, attr.strides.c)); +} + BHWC CalculateOutputShape(const BHWC& input, const PadAttributes& attr) { return BHWC(attr.appended.b + attr.prepended.b + input.b, attr.appended.h + attr.prepended.h + input.h, @@ -534,9 +542,10 @@ absl::Status CalculateOutputShape(const std::vector<BHWC>& input, switch (attr.axis) { case Axis::CHANNELS: for (int i = 1; i < input.size(); i++) { - if (input[i].h != new_shape.h || input[i].w != new_shape.w) { + if (input[i].h != new_shape.h || input[i].w != new_shape.w || + input[i].b != new_shape.b) { return absl::InvalidArgumentError( - "Height and Width must be the same when concatenating " + "Height, Width and Batch must be the same when concatenating " "by channels axis"); } new_shape.c += input[i].c; @@ -544,9 +553,10 @@ absl::Status CalculateOutputShape(const std::vector<BHWC>& input, break; case Axis::HEIGHT: for (int i = 1; i < input.size(); i++) { - if (input[i].w != new_shape.w || input[i].c != new_shape.c) { + if (input[i].w != new_shape.w || input[i].c != new_shape.c || + input[i].b != new_shape.b) { return absl::InvalidArgumentError( - "Channels and Width must be the same when concatenating " + "Channels, Width and Batch must be the same when concatenating " "by height axis"); } new_shape.h += input[i].h; @@ -554,14 +564,26 @@ absl::Status CalculateOutputShape(const std::vector<BHWC>& input, break; case Axis::WIDTH: for (int i = 1; i < input.size(); i++) { - if (input[i].h != new_shape.h || input[i].c != new_shape.c) { + if (input[i].h != new_shape.h || input[i].c != new_shape.c || + input[i].b != new_shape.b) { return absl::InvalidArgumentError( - "Height and Channels must be the same when concatenating " + "Height, Channels and Batch must be the same when concatenating " "by width axis"); } new_shape.w += input[i].w; } break; + case Axis::BATCH: + for (int i = 1; i < input.size(); i++) { + if (input[i].h != new_shape.h || input[i].c != new_shape.c || + input[i].w != new_shape.w) { + return absl::InvalidArgumentError( + "Width, Height and Channels must be the same when concatenating " + "by batch axis"); + } + new_shape.b += input[i].b; + } + break; default: return absl::InvalidArgumentError("Invalid axis"); break; @@ -578,9 +600,10 @@ absl::Status CalculateOutputShape(const std::vector<BHWDC>& input, case Axis::CHANNELS: for (int i = 1; i < input.size(); ++i) { if (input[i].h != new_shape.h || input[i].w != new_shape.w || - input[i].d != new_shape.d) { + input[i].d != new_shape.d || input[i].b != new_shape.b) { return absl::InvalidArgumentError( - "Height, Width and Depth must be the same when concatenating " + "Height, Width, Batch and Depth must be the same when " + "concatenating " "by channels axis"); } new_shape.c += input[i].c; @@ -589,9 +612,10 @@ absl::Status CalculateOutputShape(const std::vector<BHWDC>& input, case Axis::HEIGHT: for (int i = 1; i < input.size(); ++i) { if (input[i].w != new_shape.w || input[i].c != new_shape.c || - input[i].d != new_shape.d) { + input[i].d != new_shape.d || input[i].b != new_shape.b) { return absl::InvalidArgumentError( - "Width, Depth and Channels must be the same when concatenating " + "Width, Depth, Batch and Channels must be the same when " + "concatenating " "by height axis"); } new_shape.h += input[i].h; @@ -600,9 +624,10 @@ absl::Status CalculateOutputShape(const std::vector<BHWDC>& input, case Axis::WIDTH: for (int i = 1; i < input.size(); ++i) { if (input[i].h != new_shape.h || input[i].c != new_shape.c || - input[i].d != new_shape.d) { + input[i].d != new_shape.d || input[i].b != new_shape.b) { return absl::InvalidArgumentError( - "Height, Depth and Channels must be the same when concatenating " + "Height, Depth, Batch and Channels must be the same when " + "concatenating " "by width axis"); } new_shape.w += input[i].w; @@ -611,14 +636,27 @@ absl::Status CalculateOutputShape(const std::vector<BHWDC>& input, case Axis::DEPTH: for (int i = 1; i < input.size(); ++i) { if (input[i].w != new_shape.w || input[i].h != new_shape.h || - input[i].c != new_shape.c) { + input[i].c != new_shape.c || input[i].b != new_shape.b) { return absl::InvalidArgumentError( - "Width, Height and Channels must be the same when concatenating " + "Width, Height, Batch and Channels must be the same when " + "concatenating " "by depth axis"); } new_shape.d += input[i].d; } break; + case Axis::BATCH: + for (int i = 1; i < input.size(); ++i) { + if (input[i].w != new_shape.w || input[i].h != new_shape.h || + input[i].c != new_shape.c || input[i].d != new_shape.d) { + return absl::InvalidArgumentError( + "Width, Height, Depth and Channels must be the same when " + "concatenating " + "by batch axis"); + } + new_shape.b += input[i].b; + } + break; default: return absl::InvalidArgumentError("Invalid axis"); } @@ -704,5 +742,12 @@ BHWC CalculateOutputShape(const BHWC& input, const TransposeAttributes& attr) { input.get(attr.perm.w), input.get(attr.perm.c)); } +BHWDC CalculateOutputShape(const BHWDC& input, + const Transpose3DAttributes& attr) { + return BHWDC(input.get(attr.perm.b), input.get(attr.perm.h), + input.get(attr.perm.w), input.get(attr.perm.d), + input.get(attr.perm.c)); +} + } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/common/operations.h b/tensorflow/lite/delegates/gpu/common/operations.h index d0268eee585..9d714d9bc55 100644 --- a/tensorflow/lite/delegates/gpu/common/operations.h +++ b/tensorflow/lite/delegates/gpu/common/operations.h @@ -399,6 +399,9 @@ struct Resize3DAttributes { // If true, the centers of the 8 corner pixels of the input and output tensors // are aligned, preserving the values at the corner pixels. Defaults to false. bool align_corners = false; + // half_pixel_centers assumes pixels are of half the actual dimensions, and + // yields more accurate resizes. Only applicable to BILINEAR sampling. + bool half_pixel_centers = false; }; float CalculateResizeScale(int32_t input_size, int32_t output_size, @@ -460,6 +463,20 @@ struct SliceAttributes { // input. BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr); +// Simple slicing without advanced support for shrinking, reverse slicing etc. +struct Slice3DAttributes { + // Specifies start and end dimensions for slicing. + BHWDC starts; + BHWDC ends; + + // Stride should be >= 1. + BHWDC strides; +}; + +// @return shape of a tensor after Slice3D operation is applied to the given +// input. +BHWDC CalculateOutputShape(const BHWDC& input, const Slice3DAttributes& attr); + struct AddAttributes { TensorOrScalar param; }; @@ -485,6 +502,10 @@ struct ReshapeAttributes { BHWC new_shape; }; +struct Reshape3DAttributes { + BHWDC new_shape; +}; + struct TransposeAttributes { // A permutation of the dimensions of input tensor BHWC perm; @@ -494,6 +515,16 @@ struct TransposeAttributes { // the given input. BHWC CalculateOutputShape(const BHWC& input, const TransposeAttributes& attr); +struct Transpose3DAttributes { + // A permutation of the dimensions of input tensor + BHWDC perm; +}; + +// @return shape of a tensor after Transpose3D operation is applied to +// the given input. +BHWDC CalculateOutputShape(const BHWDC& input, + const Transpose3DAttributes& attr); + struct SpaceToDepthAttributes { int block_size; }; diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/BUILD b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/BUILD index b0c5b7526f8..b5ceff30d1e 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/BUILD +++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/BUILD @@ -24,10 +24,10 @@ cc_library( hdrs = ["utils.h"], deps = [ "//tensorflow/lite:framework", - "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/BUILD b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/BUILD index f2a6fa10b1e..ae746cdb08d 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/BUILD +++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/generators/BUILD @@ -11,7 +11,7 @@ cc_library( ], deps = [ ":add", - "//tensorflow/lite:framework", + "//tensorflow/lite/delegates/gpu/common/testing/feature_parity:utils", ], ) diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.cc b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.cc index 8f6e3cc64bf..bdcbf7ed62e 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.cc +++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.cc @@ -116,7 +116,7 @@ absl::optional<std::string> CoordinateToString(TfLiteIntArray* shape, return result; } -// Builds intepreter for a model, allocates tensors. +// Builds interpreter for a model, allocates tensors. absl::Status BuildInterpreter(const Model* model, std::unique_ptr<Interpreter>* interpreter) { TfLiteStatus status = diff --git a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h index 68c4a1a0d1e..7c34978fb55 100644 --- a/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h +++ b/tensorflow/lite/delegates/gpu/common/testing/feature_parity/utils.h @@ -115,7 +115,7 @@ class TensorEqMatcher { return false; } - // 4. Proceed to data comparison. Iterate throught elements as they lay + // 4. Proceed to data comparison. Iterate through elements as they lay // flat. If some pair of elements don't match, deduct the coordinate // basing on the dimensions, then return. absl::Span<float> lhs_span(lhs.data.f, lhs.bytes / sizeof(float)); @@ -163,7 +163,7 @@ class TensorEqMatcher { const TfLiteTensor rhs_; }; -// Builds intepreter for a model, allocates tensors. +// Builds interpreter for a model, allocates tensors. absl::Status BuildInterpreter(const Model* model, std::unique_ptr<Interpreter>* interpreter); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc index 4efb98a6847..b279e49e40c 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc @@ -48,8 +48,9 @@ class MergeConvolutionWithAdd : public SequenceTransformation { } AddAttributes add_attr = absl::any_cast<AddAttributes>(add_node.operation.attributes); - if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param) && - !absl::get_if<float>(&add_attr.param)) { + if (!absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>( + add_attr.param) && + !absl::holds_alternative<float>(add_attr.param)) { return {TransformStatus::DECLINED, "This fuse applicable only for broadcast or scalar addition."}; } @@ -104,8 +105,9 @@ class MergeAddWithConvolution : public SequenceTransformation { } AddAttributes add_attr = absl::any_cast<AddAttributes>(add_node.operation.attributes); - if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param) && - !absl::get_if<float>(&add_attr.param)) { + if (!absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>( + add_attr.param) && + !absl::holds_alternative<float>(add_attr.param)) { return {TransformStatus::DECLINED, "This fuse applicable only for broadcast or scalar addition."}; } diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc index 749382c3417..f4ace3c0d41 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc @@ -45,8 +45,9 @@ class MergeConvolutionWithMul : public SequenceTransformation { MultiplyAttributes mul_attr = absl::any_cast<MultiplyAttributes>(mul_node.operation.attributes); - if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param) && - !absl::get_if<float>(&mul_attr.param)) { + if (!absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>( + mul_attr.param) && + !absl::holds_alternative<float>(mul_attr.param)) { return { TransformStatus::DECLINED, "This fuse applicable only for broadcast or scalar multiplication."}; @@ -108,9 +109,9 @@ class MergeMulWithConvolution : public SequenceTransformation { MultiplyAttributes mul_attr = absl::any_cast<MultiplyAttributes>(mul_node.operation.attributes); - if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>( - &mul_attr.param) && - !absl::get_if<float>(&mul_attr.param)) { + if (!absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>( + mul_attr.param) && + !absl::holds_alternative<float>(mul_attr.param)) { return { TransformStatus::DECLINED, "This fuse applicable only for broadcast or scalar multiplication."}; diff --git a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc index 23e99bc3305..2f1621eb34b 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.cc @@ -146,10 +146,11 @@ class MergePaddingWithAddOperation : public NodeTransformation { AddAttributes add_attr = absl::any_cast<AddAttributes>(add_node->operation.attributes); - const auto add_broadcast = - absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param); - const float* add_scalar = absl::get_if<float>(&add_attr.param); - if (add_broadcast || add_scalar) { + const bool is_add_broadcast = + absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>( + add_attr.param); + const bool is_add_scalar = absl::holds_alternative<float>(add_attr.param); + if (is_add_broadcast || is_add_scalar) { return {TransformStatus::SKIPPED, "Cannot remove padding when this broadcast/scalar ADD"}; } diff --git a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc index e80b244b34f..b4cdd87109a 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/remove_noop.cc @@ -77,9 +77,9 @@ std::unique_ptr<SequenceTransformation> NewRemoveSingleInputAdd() { } auto& attr = absl::any_cast<const AddAttributes&>(node->operation.attributes); - return absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param) == - nullptr && - absl::get_if<float>(&attr.param) == nullptr; + return !absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>( + attr.param) && + !absl::holds_alternative<float>(attr.param); }); } diff --git a/tensorflow/lite/delegates/gpu/gl/compiled_model.fbs b/tensorflow/lite/delegates/gpu/gl/compiled_model.fbs index f25f9026629..6887b665ee4 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiled_model.fbs +++ b/tensorflow/lite/delegates/gpu/gl/compiled_model.fbs @@ -156,7 +156,7 @@ table CompiledModel { table Parameters { // indicated flow engine version that compiled this model. If engine version - // does not match compiled model, then a model need to be recompiled. + // does not match compiled model, then a model need to be recompiled. // version:uint32; // not implemented // Could potentially be used to track environment when a model was compiled diff --git a/tensorflow/lite/delegates/gpu/gl/gl_buffer.cc b/tensorflow/lite/delegates/gpu/gl/gl_buffer.cc index 1de49676219..344e494690a 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_buffer.cc +++ b/tensorflow/lite/delegates/gpu/gl/gl_buffer.cc @@ -145,6 +145,19 @@ absl::Status CreatePersistentBuffer(size_t size, return absl::OkStatus(); } +namespace gl_buffer_internal { + +BufferMapper::BufferMapper(GLenum target, size_t offset, size_t bytes, + GLbitfield access) + : target_(target), + data_(glMapBufferRange(target_, offset, bytes, access)) {} + +BufferMapper::~BufferMapper() { + TFLITE_GPU_CALL_GL(glUnmapBuffer, target_).IgnoreError(); +} + +}; // namespace gl_buffer_internal + } // namespace gl } // namespace gpu } // namespace tflite diff --git a/tensorflow/lite/delegates/gpu/gl/gl_buffer.h b/tensorflow/lite/delegates/gpu/gl/gl_buffer.h index 3225679ec5a..1877fb1f144 100644 --- a/tensorflow/lite/delegates/gpu/gl/gl_buffer.h +++ b/tensorflow/lite/delegates/gpu/gl/gl_buffer.h @@ -229,11 +229,9 @@ class BufferBinder { // RAII for mapping and unmapping a buffer. class BufferMapper { public: - BufferMapper(GLenum target, size_t offset, size_t bytes, GLbitfield access) - : target_(target), - data_(glMapBufferRange(target_, offset, bytes, access)) {} + BufferMapper(GLenum target, size_t offset, size_t bytes, GLbitfield access); - ~BufferMapper() { TFLITE_GPU_CALL_GL(glUnmapBuffer, target_).IgnoreError(); } + ~BufferMapper(); void* data() { return data_; } diff --git a/tensorflow/lite/delegates/gpu/gl/object.h b/tensorflow/lite/delegates/gpu/gl/object.h index 3463d0678b6..0c2a2326356 100644 --- a/tensorflow/lite/delegates/gpu/gl/object.h +++ b/tensorflow/lite/delegates/gpu/gl/object.h @@ -70,7 +70,7 @@ struct Object { // @return true if object is a reference. inline bool IsRef(const Object& object) { - return !absl::get_if<ObjectData>(&object.object); + return !absl::holds_alternative<ObjectData>(object.object); } inline ObjectRef GetRef(const Object& object) { diff --git a/tensorflow/lite/delegates/gpu/gl/runtime.cc b/tensorflow/lite/delegates/gpu/gl/runtime.cc index 2a48b59c8d9..b7e01a33570 100644 --- a/tensorflow/lite/delegates/gpu/gl/runtime.cc +++ b/tensorflow/lite/delegates/gpu/gl/runtime.cc @@ -483,7 +483,7 @@ absl::Status ApplyTexturesAssignment( Object* object = global_ref_to_object_ptr[global_ref]; if (usage_rec_id == kNotAssigned || object == nullptr || object->object_type != ObjectType::TEXTURE || - !absl::get_if<ObjectSizeT>(&object->size)) { + !absl::holds_alternative<ObjectSizeT>(object->size)) { // Skip objects with other data type, non-textures and textures with wrong // number of dimensions. continue; diff --git a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java index 895f12f0233..78cab0d2cbf 100644 --- a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java +++ b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java @@ -17,18 +17,19 @@ package org.tensorflow.lite.gpu; import java.io.Closeable; import org.tensorflow.lite.Delegate; +import org.tensorflow.lite.annotations.UsedByReflection; /** * {@link Delegate} for GPU inference. * - * <p>Note: When calling {@code Interpreter.modifyGraphWithDelegate()}/ - * {@code Interpreter.Options.addDelegate()} and {@code Interpreter.run()}, the caller must have an - * {@code EGLContext} in the <b>current thread</b> and {@code Interpreter.run()} must be called from - * the same {@code EGLContext}. If an {@code EGLContext} does not exist, the delegate will - * internally create one, but then the developer must ensure that {@code Interpreter.run()} is - * always called from the same thread in which {@code Interpreter.modifyGraphWithDelegate()} was - * called. + * <p>Note: When calling {@code Interpreter.modifyGraphWithDelegate()}/ {@code + * Interpreter.Options.addDelegate()} and {@code Interpreter.run()}, the caller must have an {@code + * EGLContext} in the <b>current thread</b> and {@code Interpreter.run()} must be called from the + * same {@code EGLContext}. If an {@code EGLContext} does not exist, the delegate will internally + * create one, but then the developer must ensure that {@code Interpreter.run()} is always called + * from the same thread in which {@code Interpreter.modifyGraphWithDelegate()} was called. */ +@UsedByReflection("TFLiteSupport/model/GpuDelegateProxy") public class GpuDelegate implements Delegate, Closeable { private static final long INVALID_DELEGATE_HANDLE = 0; @@ -98,6 +99,7 @@ public class GpuDelegate implements Delegate, Closeable { options.inferencePreference); } + @UsedByReflection("TFLiteSupport/model/GpuDelegateProxy") public GpuDelegate() { this(new Options()); } diff --git a/tensorflow/lite/delegates/interpreter_utils.cc b/tensorflow/lite/delegates/interpreter_utils.cc new file mode 100644 index 00000000000..89955b23361 --- /dev/null +++ b/tensorflow/lite/delegates/interpreter_utils.cc @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/interpreter_utils.h" + +namespace tflite { +namespace delegates { +TfLiteStatus InterpreterUtils::InvokeWithCPUFallback(Interpreter* interpreter) { + TfLiteStatus status = interpreter->Invoke(); + if (status == kTfLiteOk || interpreter->IsCancelled() || + !interpreter->HasDelegates()) { + return status; + } + // Retry without delegation. + // TODO(b/138706191): retry only if error is due to delegation. + TF_LITE_REPORT_ERROR( + interpreter->error_reporter(), + "Invoke() failed in the presence of delegation. Retrying without."); + + // Copy input data to a buffer. + // Input data is safe since Subgraph::PrepareOpsAndTensors() passes + // preserve_inputs=true to ArenaPlanner. + std::vector<char> buf; + size_t input_size = 0; + + for (auto i : interpreter->inputs()) { + TF_LITE_ENSURE_STATUS(interpreter->EnsureTensorDataIsReadable(i)); + TfLiteTensor* t = interpreter->tensor(i); + input_size += t->bytes; + } + buf.reserve(input_size); + for (auto i : interpreter->inputs()) { + TfLiteTensor* t = interpreter->tensor(i); + buf.insert(buf.end(), t->data.raw, t->data.raw + t->bytes); + } + + TF_LITE_ENSURE_STATUS(interpreter->RemoveAllDelegates()); + + // Copy inputs from buffer. + auto bufp = buf.begin(); + for (auto i : interpreter->inputs()) { + TfLiteTensor* t = interpreter->tensor(i); + std::copy(bufp, bufp + t->bytes, t->data.raw); + bufp += t->bytes; + } + + // Invoke again. + TF_LITE_ENSURE_STATUS(interpreter->Invoke()); + return kTfLiteDelegateError; +} + +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/delegates/interpreter_utils.h b/tensorflow/lite/delegates/interpreter_utils.h new file mode 100644 index 00000000000..f736c2db1f4 --- /dev/null +++ b/tensorflow/lite/delegates/interpreter_utils.h @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_INTERPRETER_UTILS_H_ +#define TENSORFLOW_LITE_DELEGATES_INTERPRETER_UTILS_H_ + +#include "tensorflow/lite/interpreter.h" + +// Utility functions and classes for using delegates. + +namespace tflite { +namespace delegates { +#if !TFLITE_EXPERIMENTAL_RUNTIME_EAGER +class InterpreterUtils { + public: + /// Invokes an interpreter with automatic fallback from delegation to CPU. + /// + /// If using the delegate fails, the delegate is automatically undone and an + /// attempt made to return the interpreter to an invokable state. + /// + /// Allowing the fallback is suitable only if both of the following hold: + /// - The caller is known not to cache pointers to tensor data across Invoke() + /// calls. + /// - The model is not stateful (no variables, no LSTMs) or the state isn't + /// needed between batches. + /// + /// Returns one of the following three status codes: + /// 1. kTfLiteOk: Success. Output is valid. + /// 2. kTfLiteDelegateError: Delegate error but fallback succeeded. Output is + /// valid. + /// NOTE: This undoes all delegates previously applied to the Interpreter. + /// 3. kTfLiteError: Unexpected/runtime failure. Output is invalid. + /// WARNING: This is an experimental API and subject to change. + static TfLiteStatus InvokeWithCPUFallback(Interpreter* interpreter); +}; +#endif // !TFLITE_EXPERIMENTAL_RUNTIME_EAGER +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_INTERPRETER_UTILS_H_ diff --git a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc index cc9e049123e..46a6a720d1e 100644 --- a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc +++ b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc @@ -56,6 +56,7 @@ FloatActivationsOpTest/PRelu,29 LogisticOpTest/LogisticOpTest/Sigmoid(.+nt8)?/\d+ LogisticOpTest/LogisticOpTest/Sigmoid/\d+ TanhOpTest/TanhOpTest/Tanh(.+nt8)?/\d+,29 +FloatActivationsOpTest/Elu,30 FloatActivationsOpTest/HardSwish QuantizedActivationsOpTest/HardSwish QuantizedActivationsOpTest/HardSwishBias @@ -301,14 +302,14 @@ VariedShapeSpec/ReshapeOpTest/WithStretchDimension/1 # resize_bilinear_test // align_corners & half_pixel_centers are not implemented in NNAPI before API 30 -ResizeBilinearOpTest/ResizeBilinearOpTest.+HalfPixelCenters.*,30 +ResizeBilinearOpTest/ResizeBilinearOpTest.+HalfPixelCenters.*/0,30 // Only models with constant size tensor are accelerated ResizeBilinearOpTest/ResizeBilinearOpTest/.+/0,29 # resize_nearest_neighbor_test // align_corners & half_pixel_centers are not implemented in NNAPI before API 30 -ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest.+AlignCorners.*,30 -ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest.+HalfPixelCenters.*,30 +ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest.+AlignCorners.*/0,30 +ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest.+HalfPixelCenters.*/0,30 // Only models with constant size tensor are accelerated ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest/.+/0,29 diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc index ff6ad0dc0d9..fd6703bd46a 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -18,6 +18,7 @@ limitations under the License. #include <cstdarg> #include <cstddef> #include <cstdint> +#include <cstdio> #include <cstring> #include <functional> #include <initializer_list> @@ -659,8 +660,10 @@ class NNAPIOpBuilder { // Lower hardswish according to the following equation: // hard_swish[x] = x (ReLU6(x + 3)) / 6 == x * (Relu_N1_to_1(x/3) * 3 + 3) / 6 // = 0.5x * Relu_N1_to_1(x/3) + 0.5x - TfLiteStatus AddHardSwish(int lite_input_index, int lite_output_index, - bool need_int8_conversion, int lite_node_index) { + TfLiteStatus TransformHardSwishIntoSupportedOps(int lite_input_index, + int lite_output_index, + bool need_int8_conversion, + int lite_node_index) { const TfLiteTensor& tensor = context_->tensors[lite_input_index]; float input_scale = tensor.params.scale; int input_zero_point = tensor.params.zero_point; @@ -1623,7 +1626,7 @@ bool NNAPIDelegateKernel::Validate( } } break; case kTfLiteBuiltinResizeBilinear: { - ExpectMaxOpVersion(version, 2, &val_ctx); + ExpectMaxOpVersion(version, 3, &val_ctx); const auto& input = context->tensors[node->inputs->data[0]]; const auto output_dims = context->tensors[node->outputs->data[0]].dims; Expect(input.dims->size == 4, @@ -1663,7 +1666,7 @@ bool NNAPIDelegateKernel::Validate( } } break; case kTfLiteBuiltinResizeNearestNeighbor: { - ExpectMaxOpVersion(version, 2, &val_ctx); + ExpectMaxOpVersion(version, 3, &val_ctx); ExpectMinAndroidSdkVersion(android_sdk_version, kMinSdkVersionForNNAPI12, &val_ctx); ExpectIsFloatOrQuant8Operator(context, node, &val_ctx); @@ -2334,6 +2337,11 @@ bool NNAPIDelegateKernel::Validate( NNAPIValidationFailureType::kUnsupportedInputType, "NNAPI only supports floating point input.", &val_ctx); } break; + case kTfLiteBuiltinElu: { + ExpectOpVersion(version, 1, &val_ctx); + ExpectMinAndroidSdkVersion(android_sdk_version, kMinSdkVersionForNNAPI13, + &val_ctx); + } break; default: // All other operators are not mapped. AddValidationFailure(NNAPIValidationFailureType::kUnsupportedOperator, @@ -2419,6 +2427,9 @@ TfLiteStatus NNAPIDelegateKernel::Map( mapping_args.builder->AddScalarInt32Operand(builtin->activation); *nn_op_type = ANEURALNETWORKS_FULLY_CONNECTED; } break; + case kTfLiteBuiltinHardSwish: { + *nn_op_type = ANEURALNETWORKS_HARD_SWISH; + } break; case kTfLiteBuiltinSoftmax: { auto builtin = reinterpret_cast<TfLiteSoftmaxParams*>( mapping_args.node->builtin_data); @@ -3111,6 +3122,10 @@ TfLiteStatus NNAPIDelegateKernel::Map( mapping_args.builder->AddScalarBoolOperand(builtin->keep_dims); *nn_op_type = ANEURALNETWORKS_REDUCE_SUM; } break; + case kTfLiteBuiltinElu: { + mapping_args.builder->AddScalarFloat32Operand(1.0); + *nn_op_type = ANEURALNETWORKS_ELU; + } break; default: // All other operators are not mapped. return kTfLiteError; @@ -3246,6 +3261,22 @@ TfLiteStatus NNAPIDelegateKernel::Prepare(TfLiteContext* context, RETURN_TFLITE_ERROR_IF_NN_ERROR(context, set_caching_result, "configuring NNAPI caching", nnapi_errno); } + // Set compilation timeout if applicable. + if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI13) { + if (delegate_options.max_compilation_timeout_duration_ns > 0) { + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context, + nnapi_->ANeuralNetworksCompilation_setTimeout( + compilation, + delegate_options.max_compilation_timeout_duration_ns), + "setting compilation timeout", nnapi_errno); + } + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context, + nnapi_->ANeuralNetworksCompilation_setPriority( + compilation, delegate_options.execution_priority), + "setting compilation priority", nnapi_errno); + } const int finish_result = nnapi_->ANeuralNetworksCompilation_finish(compilation); if (finish_result != ANEURALNETWORKS_NO_ERROR) { @@ -3312,6 +3343,27 @@ TfLiteStatus NNAPIDelegateKernel::Invoke(TfLiteContext* context, std::unique_ptr<ANeuralNetworksExecution, NNFreeExecution> execution_unique_ptr(execution, NNFreeExecution(nnapi_)); + // Set compilation timeout if applicable. + const auto delegate_options = + StatefulNnApiDelegate::GetOptions(node->delegate); + if (nnapi_->android_sdk_version >= kMinSdkVersionForNNAPI13) { + if (delegate_options.max_execution_timeout_duration_ns > 0) { + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context, + nnapi_->ANeuralNetworksExecution_setTimeout( + execution, delegate_options.max_execution_timeout_duration_ns), + "setting execution timeout", nnapi_errno); + } + if (delegate_options.max_execution_loop_timeout_duration_ns > 0) { + RETURN_TFLITE_ERROR_IF_NN_ERROR( + context, + nnapi_->ANeuralNetworksExecution_setLoopTimeout( + execution, + delegate_options.max_execution_loop_timeout_duration_ns), + "setting execution loop timeout", nnapi_errno); + } + } + // Set the input tensor buffers. Note: we access tflite tensors using // absolute indices but NN api indices inputs by relative indices. int relative_input_index = 0; @@ -3588,10 +3640,14 @@ TfLiteStatus NNAPIDelegateKernel::AddOpsAndTensors(TfLiteContext* context, input_tensor_flags |= NN_TENSOR_FLAG_SCALAR_AS_TENSOR; } - // h_swish will be lowered into supported NNAPI operations. - if (reg->builtin_code == kTfLiteBuiltinHardSwish) { - builder.AddHardSwish(node->inputs->data[0], node->outputs->data[0], - need_int8_conversion, node_index); + // On SDK level less than 30, h_swish will be lowered into supported NNAPI + // operations. Since SDK level 30, h_swish is supported as a single + // operation. + if (reg->builtin_code == kTfLiteBuiltinHardSwish && + nnapi_->android_sdk_version < kMinSdkVersionForNNAPI13) { + builder.TransformHardSwishIntoSupportedOps( + node->inputs->data[0], node->outputs->data[0], need_int8_conversion, + node_index); continue; } // Map inputs to NN API tensor indices. @@ -3972,6 +4028,8 @@ TfLiteStatus NNAPIDelegateKernel::BuildGraph( using ::tflite::delegate::nnapi::NNAPIDelegateKernel; +StatefulNnApiDelegate::Data::Data(const NnApi* nnapi) : nnapi(nnapi) {} + StatefulNnApiDelegate::Data::~Data() { std::for_each(std::begin(delegate_state_cache), std::end(delegate_state_cache), @@ -4009,9 +4067,7 @@ StatefulNnApiDelegate::StatefulNnApiDelegate(Options options) StatefulNnApiDelegate::StatefulNnApiDelegate(const NnApi* nnapi, Options options) - : TfLiteDelegate(TfLiteDelegateCreate()), - delegate_data_(Data{.execution_preference = options.execution_preference, - .nnapi = nnapi}) { + : TfLiteDelegate(TfLiteDelegateCreate()), delegate_data_(nnapi) { if (options.accelerator_name) { delegate_data_.accelerator_name = options.accelerator_name; } @@ -4021,6 +4077,7 @@ StatefulNnApiDelegate::StatefulNnApiDelegate(const NnApi* nnapi, if (options.model_token) { delegate_data_.model_token = options.model_token; } + delegate_data_.execution_preference = options.execution_preference; delegate_data_.disallow_nnapi_cpu = options.disallow_nnapi_cpu; delegate_data_.max_number_delegated_partitions = options.max_number_delegated_partitions; diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h index 1bd9fb5c49f..7ef02bc5107 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.h +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/nnapi/NeuralNetworksTypes.h" #include "tensorflow/lite/nnapi/nnapi_implementation.h" typedef struct ANeuralNetworksMemory ANeuralNetworksMemory; @@ -90,8 +91,32 @@ class StatefulNnApiDelegate : public TfLiteDelegate { // of number of nodes and selecting them until the limit is reached. int max_number_delegated_partitions = 3; - // allow fp32 compuation to be run in fp16 + // allow fp32 compuation to be run in fp16. bool allow_fp16 = false; + + // Specifies the relative priority for executions of the model. + // Available values are {ANEURALNETWORKS_PRIORITY_LOW, + // ANEURALNETWORKS_PRIORITY_MEDIUM, ANEURALNETWORKS_PRIORITY_HIGH, + // ANEURALNETWORKS_PRIORITY_DEFAULT}. + int execution_priority = ANEURALNETWORKS_PRIORITY_DEFAULT; + + // Specifies the maximum expected duration in nanosecond for compiling the + // model. If the device is not able to complete the compilation within the + // specified duration, the compilation may be aborted. If set to 0, the + // timeout duration is considered infinite. + uint64_t max_compilation_timeout_duration_ns = 0; + + // Specifies the maximum expected duration in nanosecond for executing the + // model. If the device is not able to complete the execution within the + // specified duration, the execution may be aborted. If set to 0, the + // timeout duration is considered infinite. + uint64_t max_execution_timeout_duration_ns = 0; + + // Specifies the maximum expected duration in nanosecond for WHILE loops in + // the execution. If a WHILE loop condition model does not output false + // within the specified duration, the execution will be aborted. If set to + // 0, the default timeout for loops will be used. + uint64_t max_execution_loop_timeout_duration_ns = 0; }; // Uses default options. @@ -156,8 +181,6 @@ class StatefulNnApiDelegate : public TfLiteDelegate { private: // Encapsulates all delegate data. struct Data { - // Preferred Power/perf trade-off. - Options::ExecutionPreference execution_preference; // Pointer to NNAPI implementation to be used by this delegate as // set when building the StatefulNnApiDelegate instance. // Will generally be the NnApiInstance() singleton but can be overridden @@ -165,6 +188,8 @@ class StatefulNnApiDelegate : public TfLiteDelegate { // The ownership of the nnapi instance is left to the caller of // the StatefulNnApiDelegate constructor. const NnApi* nnapi; + // Preferred Power/perf trade-off. + Options::ExecutionPreference execution_preference; // Selected NNAPI accelerator name. std::string accelerator_name; // The cache dir for NNAPI model. @@ -177,7 +202,7 @@ class StatefulNnApiDelegate : public TfLiteDelegate { std::vector<MemoryRegistration> tensor_memory_map; // Contains a non zero value if any NNAPI method call // operation returned a non zero result code. - int nnapi_errno; + int nnapi_errno = ANEURALNETWORKS_NO_ERROR; // Cache of kernels already built in StatefulNnApiDelegate::DoPrepare // when trying to understand if all nodes are supported by the target // accelerators. @@ -187,9 +212,21 @@ class StatefulNnApiDelegate : public TfLiteDelegate { // Maximum number of NNAPI partition to delegate. Zero or negative means // no limit. Copied from StatefulNnApiDelegate::Options int max_number_delegated_partitions; - // allow fp32 computation to be run in fp32 - bool allow_fp16 = false; + // allow fp32 computation to be run in fp16. + bool allow_fp16; + // Specifies the relative priority for executions of the model. + int execution_priority = ANEURALNETWORKS_PRIORITY_DEFAULT; + // Specifies the maximum expected duration in nanosecond for compiling the + // model. + uint64_t max_compilation_timeout_duration_ns = 0; + // Specifies the maximum expected duration in nanosecond for executing the + // model. + uint64_t max_execution_timeout_duration_ns = 0; + // Specifies the maximum expected duration in nanosecond for WHILE loops in + // the execution + uint64_t max_execution_loop_timeout_duration_ns = 0; + explicit Data(const NnApi* nnapi); ~Data(); // Caches an initialised NNAPIDelegateKernel. diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_disabled.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_disabled.cc index 3c23054ea25..2bc7ae58449 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_disabled.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_disabled.cc @@ -27,7 +27,8 @@ StatefulNnApiDelegate::StatefulNnApiDelegate(Options /* options */) : StatefulNnApiDelegate() {} StatefulNnApiDelegate::StatefulNnApiDelegate() - : TfLiteDelegate(TfLiteDelegateCreate()) { + : TfLiteDelegate(TfLiteDelegateCreate()), + delegate_data_(/*nnapi=*/nullptr) { Prepare = DoPrepare; } @@ -46,6 +47,8 @@ int StatefulNnApiDelegate::GetNnApiErrno() const { return 0; } using ::tflite::delegate::nnapi::NNAPIDelegateKernel; +StatefulNnApiDelegate::Data::Data(const NnApi* nnapi) : nnapi(nnapi) {} + StatefulNnApiDelegate::Data::~Data() {} void StatefulNnApiDelegate::Data::CacheDelegateKernel( diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h index 668fdf5b5f6..af93d9650c9 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h @@ -31,6 +31,7 @@ namespace nnapi { constexpr int32_t kMinSdkVersionForNNAPI = 27; constexpr int32_t kMinSdkVersionForNNAPI11 = 28; constexpr int32_t kMinSdkVersionForNNAPI12 = 29; +constexpr int32_t kMinSdkVersionForNNAPI13 = 30; // Track tensor indices to NN API tensor indices mapping. class OperandMapping { diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc index ea9111c4567..acfa0c77d30 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -304,6 +304,23 @@ TEST(NNAPIDelegate, StatefulDelegateWithCompilationCaching) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); } +// Sanity check for the state-ful NNAPI delegate with QoS hints. +TEST(NNAPIDelegate, StatefulDelegateWithQoS) { + StatefulNnApiDelegate::Options options; + options.execution_priority = ANEURALNETWORKS_PRIORITY_HIGH; + options.max_compilation_timeout_duration_ns = UINT64_MAX; + options.max_execution_timeout_duration_ns = UINT64_MAX; + options.max_execution_loop_timeout_duration_ns = UINT64_MAX; + + FloatAddOpModel m(options, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); +} + // Sanity check for the state-ful NNAPI delegate using TfLiteBufferHandle. TEST(NNAPIDelegate, StatefulDelegateWithBufferHandles) { // Skip the test if Android specific functions could not be found. diff --git a/tensorflow/lite/delegates/utils.cc b/tensorflow/lite/delegates/utils.cc index fba8bec39a5..f9cf9380a31 100644 --- a/tensorflow/lite/delegates/utils.cc +++ b/tensorflow/lite/delegates/utils.cc @@ -18,6 +18,7 @@ limitations under the License. #include <algorithm> #include <vector> +#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/context_util.h" namespace tflite { @@ -136,5 +137,167 @@ TfLiteStatus GraphPartitionHelper::PrepareSupportedNodes( return kTfLiteOk; } +TfLiteStatus FP16GraphPartitionHelper::Partition( + std::set<std::string>* unsupported_nodes_info) { + const auto status = GraphPartitionHelper::Partition(unsupported_nodes_info); + // Clean up those partitions that have a single dequant op. NoteThose + // removed dequant ops have to be reserved in the graph and should not be + // delegated. + RemoveSingleDequantNodePartitions(); + return status; +} + +std::vector<int> FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitions( + int n) { + // We first get partitions to reduce the number of nodes to be checked in + // deciding which dequant ops could actually be replaced. And then we + // remap input-tensor to dequant nodes' inputs and remove those + // to-be-reserved dequant nodes. + auto first_nps = GetFirstNLargestPartitions(n); + std::vector<int> ops_to_replace; + for (const auto p : first_nps) { + auto nodes = p->nodes_to_replace; + ops_to_replace.insert(ops_to_replace.end(), nodes->data, + nodes->data + nodes->size); + } + RemapInputTensors(ops_to_replace); + RemoveReservedDequantsFromNodes(&ops_to_replace); + return ops_to_replace; +} + +bool FP16GraphPartitionHelper::IsNodeSupported( + TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration, + int node_id, std::string* unsupported_details) { + // If we need to handle dequant nodes, we have to remap input tensors of + // this node if some of them come from a dequant node before testing if + // the node is supported. + std::vector<int> orig_inputs; + if (RecordAndRemapInputTensors(registration->builtin_code, node_id, node, + &orig_inputs)) { + // We have a dequant op here. Note that we retrun an Ok status because a + // dequant node is first added as supported. Later, this dequant node + // will be removed if it has to be preserved in the graph which happens + // when its immediate downstream nodes cannot be supported. + return true; + } + const auto status = GraphPartitionHelper::IsNodeSupported( + context, node, registration, node_id, unsupported_details); + RestoreToOrigInputTensors(node, orig_inputs); + return status; +} + +bool FP16GraphPartitionHelper::RecordAndRemapInputTensors( + int32_t op_code, int node_id, TfLiteNode* node, + std::vector<int>* orig_inputs) { + orig_inputs->clear(); + // Record the dequant node. + if (op_code == kTfLiteBuiltinDequantize && + context_->tensors[node->inputs->data[0]].type == + TfLiteType::kTfLiteFloat16) { + dequant_nodes_[node->outputs->data[0]] = node->inputs->data[0]; + return true; + } + // For a dequantize op, there's no need to remap its input tensors. + if (dequant_nodes_.empty()) return false; + RemapInputTensors(node, orig_inputs); + return false; +} + +void FP16GraphPartitionHelper::RestoreToOrigInputTensors( + TfLiteNode* node, const std::vector<int>& orig_inputs) { + if (node->inputs->size != orig_inputs.size()) return; + for (int j = 0; j < node->inputs->size; ++j) { + node->inputs->data[j] = orig_inputs[j]; + } +} + +void FP16GraphPartitionHelper::RemapInputTensors( + const std::vector<int>& nodes) const { + for (int node_id : nodes) { + TfLiteNode* node; + TfLiteRegistration* registration; + TfLiteStatus status = context_->GetNodeAndRegistration( + context_, node_id, &node, ®istration); + if (status != kTfLiteOk) { + TF_LITE_KERNEL_LOG(context_, + "Couldn't get node and registration info for op: %d\n", + node_id); + } + RemapInputTensors(node, nullptr /* orig_inputs*/); + } +} + +void FP16GraphPartitionHelper::RemoveSingleDequantNodePartitions() { + auto it = partitions_.begin(); + while (it != partitions_.end()) { + auto p = *it; + if (p->nodes_to_replace->size != 1) { + ++it; + continue; + } + int node_id = p->nodes_to_replace->data[0]; + TfLiteNode* node = nullptr; + TfLiteRegistration* registration = nullptr; + + TfLiteStatus status = context_->GetNodeAndRegistration( + context_, node_id, &node, ®istration); + if (status != kTfLiteOk) { + TF_LITE_KERNEL_LOG(context_, + "Couldn't get node and registration info for op: %d\n", + node_id); + } + if (registration->builtin_code != kTfLiteBuiltinDequantize || + context_->tensors[node->inputs->data[0]].type != + TfLiteType::kTfLiteFloat16) { + ++it; + continue; + } + // Note such dequant nodes have to be preserved in the graph as dequant + // ops are not actually supported in the GPU delegate. + dequant_nodes_to_save_.insert(node_id); + it = partitions_.erase(it); + } +} + +void FP16GraphPartitionHelper::RemoveReservedDequantsFromNodes( + std::vector<int>* nodes) { + if (dequant_nodes_to_save_.empty()) return; + auto it = nodes->begin(); + while (it != nodes->end()) { + if (dequant_nodes_to_save_.find(*it) == dequant_nodes_to_save_.end()) { + ++it; + continue; + } + it = nodes->erase(it); + } +} + +void FP16GraphPartitionHelper::RemapInputTensors( + TfLiteNode* node, std::vector<int>* orig_inputs) const { + TfLiteIntArray* inputs = node->inputs; + auto inputs_view = TfLiteIntArrayView(inputs); + // Prepopulate 'orig_inputs' first and clear it if there's no input from a + // dequant op. + if (orig_inputs) { + orig_inputs->clear(); + orig_inputs->reserve(inputs->size); + for (auto tid : inputs_view) { + orig_inputs->push_back(tid); + } + } + // Fix this node's inputs (i.e. prune out the preceding dequantize node) in + // order to test if it is supported. + bool is_remapped = false; + for (int j = 0; j < inputs->size; ++j) { + const int input_tid = inputs->data[j]; + const auto it = dequant_nodes_.find(input_tid); + if (it != dequant_nodes_.end()) { + inputs->data[j] = it->second; + is_remapped = true; + } + } + if (!is_remapped && orig_inputs) orig_inputs->clear(); +} + } // namespace delegates } // namespace tflite diff --git a/tensorflow/lite/delegates/utils.h b/tensorflow/lite/delegates/utils.h index d6d22c4efa2..11ad9990426 100644 --- a/tensorflow/lite/delegates/utils.h +++ b/tensorflow/lite/delegates/utils.h @@ -16,10 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_UTILS_H_ #define TENSORFLOW_LITE_DELEGATES_UTILS_H_ +// Utility functions and classes for implementing delegates. + #include <functional> #include <limits> #include <set> #include <string> +#include <unordered_map> +#include <utility> #include <vector> #include "tensorflow/lite/c/common.h" @@ -109,6 +113,70 @@ class GraphPartitionHelper { // Contains an array of supported node indices. TfLiteIntArray* supported_nodes_ = nullptr; // owns the memory }; + +// While partitioning the graph, this claims DEQUANTIZE nodes (FP16->FP32) in +// addition to supported nodes for the delegate, when the DEQUANTIZE node's +// output is an input to the kernel that supports FP16 input. +// Noth that you have to use `GetNodesOfFirstNLargestPartitions` instead of +// superclass' `GetFirstNLargestPartitions` to do actual remapping of FP16 +// inputs. +class FP16GraphPartitionHelper : public GraphPartitionHelper { + public: + FP16GraphPartitionHelper(TfLiteContext* context, + IsNodeSupportedFn is_node_supported_fn) + : GraphPartitionHelper(context, std::move(is_node_supported_fn)) {} + + TfLiteStatus Partition( + std::set<std::string>* unsupported_nodes_info) override; + + // Returns a list of node indices of all nodes from the first n largest + // partitions. If there are fewer paritions than n, all nodes will be + // returned. The partition is ranked according to the number of nodes. + // TODO(b/156707497): Add this to superclass besides + // GetFirstNLargestPartitions (one that returns partitions instead of nodes) + std::vector<int> GetNodesOfFirstNLargestPartitions(int n); + + protected: + bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node, + TfLiteRegistration* registration, int node_id, + std::string* unsupported_details) override; + + private: + // Record 'node' if it is a dequant op (i.e. a fp16 one here) and return true. + // When it's not a dequant op, remap its inputs to the inputs of the preceding + // dequant if there's a one and returns false. 'orig_inputs' records original + // input tensor ids of this node if any input is remapped. + bool RecordAndRemapInputTensors(int32_t op_code, int node_id, + TfLiteNode* node, + std::vector<int>* orig_inputs); + + // Restore inputs of 'node' to 'orig_inputs' only if two sizes match. + void RestoreToOrigInputTensors(TfLiteNode* node, + const std::vector<int>& orig_inputs); + + // Remap input tensors of every node in 'nodes' (i.e. node indices) if some of + // them are from dequant ops. + void RemapInputTensors(const std::vector<int>& nodes) const; + + void RemoveSingleDequantNodePartitions(); + + void RemoveReservedDequantsFromNodes(std::vector<int>* nodes); + + // Remap input tensors of a single 'node' if some of come from a dequant op. + // If 'orig_inputs' isn't nullptr, it records original input tensor ids of + // this node if any input is remapped. + void RemapInputTensors(TfLiteNode* node, std::vector<int>* orig_inputs) const; + + // A map recording dequantize nodes's input/output tensors of this selected + // graph. The key is the output tensor id, and the value is the input tensor + // id. + std::unordered_map<int, int> dequant_nodes_; + + // A set of dequant nodes as in node indices that have to be preserved in the + // graph. + std::set<int> dequant_nodes_to_save_; +}; + } // namespace delegates } // namespace tflite diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index e8e6c061160..1cdba72b615 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -53,6 +53,7 @@ cc_library( deps = [ "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", @@ -68,6 +69,7 @@ cc_library( deps = [ "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", @@ -91,6 +93,22 @@ cc_library( ], ) +cc_library( + name = "pad_tester", + testonly = 1, + srcs = ["pad_tester.cc"], + hdrs = ["pad_tester.h"], + deps = [ + "//tensorflow/lite:framework", + "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + cc_library( name = "pool_2d_tester", testonly = 1, @@ -99,6 +117,7 @@ cc_library( deps = [ "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", @@ -114,6 +133,7 @@ cc_library( deps = [ "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", @@ -129,6 +149,7 @@ cc_library( deps = [ "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", @@ -293,6 +314,21 @@ cc_test( ], ) +cc_test( + name = "pad_test", + srcs = ["pad_test.cc"], + linkopts = select({ + "//tensorflow:emscripten": EMSCRIPTEN_LINKOPTS, + "//conditions:default": [], + }), + deps = [ + ":pad_tester", + ":test_main", + ":xnnpack_delegate_test_mode", + "@com_google_googletest//:gtest", + ], +) + cc_test( name = "relu_test", srcs = ["relu_test.cc"], diff --git a/tensorflow/lite/delegates/xnnpack/README.md b/tensorflow/lite/delegates/xnnpack/README.md index e0ef6f0899c..98a08a4f647 100644 --- a/tensorflow/lite/delegates/xnnpack/README.md +++ b/tensorflow/lite/delegates/xnnpack/README.md @@ -1,15 +1,48 @@ # XNNPACK backend for TensorFlow Lite XNNPACK is a highly optimized library of floating-point neural network -inference operators for ARM, WebAssembly, and x86 platforms. This document -describes how to use the XNNPACK library as a backend for TensorFlow Lite. +inference operators for ARM, x86, and WebAssembly architectures in Android, iOS, +Windows, Linux, macOS, and Emscripten environments. This document describes how +to use the XNNPACK library as an inference engine for TensorFlow Lite. -## Enabling XNNPACK backend in TensorFlow Lite models +## Using XNNPACK engine with TensorFlow Lite interpreter XNNPACK integrates with TensorFlow Lite interpreter through the delegation -mechanism. To leverage XNNPACK library for acceleration, the users need to -create an XNNPACK delegate with the `TfLiteXNNPackDelegateCreate` function, -and call `Interpreter::ModifyGraphWithDelegate` to delegate supported parts of +mechanism. There are three methods to enable XNNPACK engine in TensorFlow Lite. + +### Enable XNNPACK via Bazel build flags (recommended) + +When building TensorFlow Lite with Bazel, add +`--define tflite_with_xnnpack=true`, and the TensorFlow Lite interpreter will +use XNNPACK engine by default. + +The exact command depends on the target platform, e.g. for Android AAR you'd use + +``` +bazel build -c opt --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ + --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --define tflite_with_xnnpack=true \ + //tensorflow/lite/java:tensorflow-lite +``` + +### Enable XNNPACK via additional dependency + +Another way to enable XNNPACK is to build and link the +`//tensorflow/lite:tflite_with_xnnpack` target into your application alongside +the TensorFlow Lite framework. + +This method works on platforms which support POSIX-style weak symbols (Android, +iOS, Linux, Mac, but **NOT** Windows). + +### Enable XNNPACK via low-level delegate API (not recommended) + +While it is possible to use low-level delegate API to enable XNNPACK, this +method is **NOT RECOMMENDED** unless you need to use TensorFlow Lite both with +and without XNNPACK (e.g. for benchmarking). + +With low-level delegate API users create an XNNPACK delegate with the +`TfLiteXNNPackDelegateCreate` function, and then call +`Interpreter::ModifyGraphWithDelegate` to delegate supported parts of the model to the XNNPACK delegate. The users must destroy the delegate with `TfLiteXNNPackDelegateDelete` **after** releasing the TensorFlow Lite interpreter. The snippet below illustrates the typical usage: @@ -59,8 +92,6 @@ Below is the list of current operators and limitations: * Only addition with two inputs is supported. * Fused `NONE`, `RELU`, `RELU_N1_TO_1`, and `RELU6` activations are supported, but fused `TANH` and `SIGN_BIT` activations are not. -* Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and - output are not supported. ### `AVERAGE_POOL_2D` @@ -68,8 +99,6 @@ Below is the list of current operators and limitations: * 1x1 pooling is not supported. * Fused `NONE`, `RELU`, `RELU_N1_TO_1`, and `RELU6` activations are supported, but fused `TANH` and `SIGN_BIT` activations are not. -* Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and - output are not supported. ### `CONV_2D` @@ -78,8 +107,6 @@ Below is the list of current operators and limitations: * Both filter and bias must be static (use `kTfLiteMmapRo` allocation type). * Fused `NONE`, `RELU`, `RELU_N1_TO_1`, and `RELU6` activations are supported, but fused `TANH` and `SIGN_BIT` activations are not. -* Dynamically allocated (with `kTfLiteDynamic` allocation type) input and output - are not supported. ### `DEPTHWISE_CONV_2D` @@ -88,8 +115,6 @@ Below is the list of current operators and limitations: * Both filter and bias must be static (use `kTfLiteMmapRo` allocation type). * Fused `NONE`, `RELU`, `RELU_N1_TO_1`, and `RELU6` activations are supported, but fused `TANH` and `SIGN_BIT` activations are not. -* Dynamically allocated (with `kTfLiteDynamic` allocation type) input and output - are not supported. ### `FULLY_CONNECTED` @@ -98,20 +123,14 @@ Below is the list of current operators and limitations: * Both filter and bias must be static (use `kTfLiteMmapRo` allocation type). * Fused `NONE`, `RELU`, `RELU_N1_TO_1`, and `RELU6` activations are supported, but fused `TANH` and `SIGN_BIT` activations are not. -* Dynamically allocated (with `kTfLiteDynamic` allocation type) input and output - are not supported. ### `HARD_SWISH` * Inputs and outputs must be in 32-bit floating-point format. -* Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and - output are not supported. ### `LOGISTIC` * Inputs and outputs must be in 32-bit floating-point format. -* Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and - output are not supported. ### `MAX_POOL_2D` @@ -119,16 +138,19 @@ Below is the list of current operators and limitations: * 1x1 pooling is not supported. * Fused `NONE`, `RELU`, `RELU_N1_TO_1`, and `RELU6` activations are supported, but fused `TANH` and `SIGN_BIT` activations are not. -* Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and - output are not supported. ### `MUL` * Inputs and outputs must be in 32-bit floating-point format. * Fused `NONE`, `RELU`, `RELU_N1_TO_1`, and `RELU6` activations are supported, but fused `TANH` and `SIGN_BIT` activations are not. -* Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and - output are not supported. + +### `PAD` + +* The first input and the output must be in 32-bit floating-point format. +* The second input (the input with the padding specification) must be static + (use `kTfLiteMmapRo` allocation type). +* The numbers of padding elements must be non-negative. ### `PRELU` @@ -136,36 +158,28 @@ Below is the list of current operators and limitations: * Slope must be static (use `kTfLiteMmapRo` allocation type). * Slope must be either a 1D tensor, or have all its non-channel dimensions equal 1. -* Dynamically allocated (with `kTfLiteDynamic` allocation type) input and output - are not supported. ### `RELU` * Inputs and outputs must be in 32-bit floating-point format. -* Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and - output are not supported. ### `RELU6` * Inputs and outputs must be in 32-bit floating-point format. -* Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and - output are not supported. ### `RELU_N1_TO_1` * Inputs and outputs must be in 32-bit floating-point format. -* Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and - output are not supported. ### `SOFTMAX` * Inputs and outputs must be in 32-bit floating-point format. * Only `beta = 1.0` is supported. -* Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and - output are not supported. ### Other limitations +* Dynamically allocated (with `kTfLiteDynamic` allocation type) inputs and + outputs are not supported. * Resizing model inputs (via `Interpreter::ResizeInputTensor`) is supported, but cause a complete reinitialization of the delegate instance, which has considerable overhead. diff --git a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h index 6d9a8b6caa9..15c99c3148d 100644 --- a/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h +++ b/tensorflow/lite/delegates/xnnpack/binary_elementwise_tester.h @@ -17,17 +17,11 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_XNNPACK_BINARY_ELEMENTWISE_TESTER_H_ #include <cstdint> -#include <functional> -#include <random> #include <vector> #include <gtest/gtest.h> -#include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/model.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/version.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h index ec8e4cea429..16dc5920229 100644 --- a/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h +++ b/tensorflow/lite/delegates/xnnpack/depthwise_conv_2d_tester.h @@ -17,17 +17,11 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_XNNPACK_DEPTHWISE_CONV_2D_TESTER_H_ #include <cstdint> -#include <functional> -#include <random> #include <vector> #include <gtest/gtest.h> -#include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/model.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/version.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h b/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h index 1c8e3d5d60c..cf1d5513d46 100644 --- a/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h +++ b/tensorflow/lite/delegates/xnnpack/fully_connected_tester.h @@ -17,8 +17,6 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_XNNPACK_FULLY_CONNECTED_TESTER_H_ #include <cstdint> -#include <functional> -#include <random> #include <vector> #include <gtest/gtest.h> diff --git a/tensorflow/lite/delegates/xnnpack/pad_test.cc b/tensorflow/lite/delegates/xnnpack/pad_test.cc new file mode 100644 index 00000000000..c93ff8ab661 --- /dev/null +++ b/tensorflow/lite/delegates/xnnpack/pad_test.cc @@ -0,0 +1,279 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cstdint> +#include <functional> +#include <memory> +#include <random> + +#include <gtest/gtest.h> +#include "tensorflow/lite/delegates/xnnpack/pad_tester.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" + +namespace tflite { +namespace xnnpack { + +TEST(Pad, Full4D) { + std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)> + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto pad_rng = + std::bind(std::uniform_int_distribution<int32_t>(1, 3), std::ref(rng)); + auto shape_rng = + std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng)); + + PadTester() + .InputPrePaddings({pad_rng(), pad_rng(), pad_rng(), pad_rng()}) + .InputPostPaddings({pad_rng(), pad_rng(), pad_rng(), pad_rng()}) + .InputShape({shape_rng(), shape_rng(), shape_rng(), shape_rng()}) + .Test(xnnpack_delegate.get()); +} + +TEST(Pad, Batch4D) { + std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)> + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto pad_rng = + std::bind(std::uniform_int_distribution<int32_t>(1, 3), std::ref(rng)); + auto shape_rng = + std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng)); + + PadTester() + .InputPrePaddings({pad_rng(), 0, 0, 0}) + .InputPostPaddings({pad_rng(), 0, 0, 0}) + .InputShape({shape_rng(), shape_rng(), shape_rng(), shape_rng()}) + .Test(xnnpack_delegate.get()); +} + +TEST(Pad, HeightAndWidth4D) { + std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)> + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto pad_rng = + std::bind(std::uniform_int_distribution<int32_t>(1, 3), std::ref(rng)); + auto shape_rng = + std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng)); + + PadTester() + .InputPrePaddings({0, pad_rng(), pad_rng(), 0}) + .InputPostPaddings({0, pad_rng(), pad_rng(), 0}) + .InputShape({shape_rng(), shape_rng(), shape_rng(), shape_rng()}) + .Test(xnnpack_delegate.get()); +} + +TEST(Pad, Channels4D) { + std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)> + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto pad_rng = + std::bind(std::uniform_int_distribution<int32_t>(1, 3), std::ref(rng)); + auto shape_rng = + std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng)); + + PadTester() + .InputPrePaddings({0, 0, 0, pad_rng()}) + .InputPostPaddings({0, 0, 0, pad_rng()}) + .InputShape({shape_rng(), shape_rng(), shape_rng(), shape_rng()}) + .Test(xnnpack_delegate.get()); +} + +TEST(Pad, Full3D) { + std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)> + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto pad_rng = + std::bind(std::uniform_int_distribution<int32_t>(1, 3), std::ref(rng)); + auto shape_rng = + std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng)); + + PadTester() + .InputPrePaddings({pad_rng(), pad_rng(), pad_rng()}) + .InputPostPaddings({pad_rng(), pad_rng(), pad_rng()}) + .InputShape({shape_rng(), shape_rng(), shape_rng()}) + .Test(xnnpack_delegate.get()); +} + +TEST(Pad, Batch3D) { + std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)> + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto pad_rng = + std::bind(std::uniform_int_distribution<int32_t>(1, 3), std::ref(rng)); + auto shape_rng = + std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng)); + + PadTester() + .InputPrePaddings({pad_rng(), 0, 0}) + .InputPostPaddings({pad_rng(), 0, 0}) + .InputShape({shape_rng(), shape_rng(), shape_rng()}) + .Test(xnnpack_delegate.get()); +} + +TEST(Pad, Width3D) { + std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)> + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto pad_rng = + std::bind(std::uniform_int_distribution<int32_t>(1, 3), std::ref(rng)); + auto shape_rng = + std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng)); + + PadTester() + .InputPrePaddings({0, pad_rng(), 0}) + .InputPostPaddings({0, pad_rng(), 0}) + .InputShape({shape_rng(), shape_rng(), shape_rng()}) + .Test(xnnpack_delegate.get()); +} + +TEST(Pad, Channels3D) { + std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)> + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto pad_rng = + std::bind(std::uniform_int_distribution<int32_t>(1, 3), std::ref(rng)); + auto shape_rng = + std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng)); + + PadTester() + .InputPrePaddings({0, 0, pad_rng()}) + .InputPostPaddings({0, 0, pad_rng()}) + .InputShape({shape_rng(), shape_rng(), shape_rng()}) + .Test(xnnpack_delegate.get()); +} + +TEST(Pad, Full2D) { + std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)> + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto pad_rng = + std::bind(std::uniform_int_distribution<int32_t>(1, 3), std::ref(rng)); + auto shape_rng = + std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng)); + + PadTester() + .InputPrePaddings({pad_rng(), pad_rng()}) + .InputPostPaddings({pad_rng(), pad_rng()}) + .InputShape({shape_rng(), shape_rng()}) + .Test(xnnpack_delegate.get()); +} + +TEST(Pad, Batch2D) { + std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)> + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto pad_rng = + std::bind(std::uniform_int_distribution<int32_t>(1, 3), std::ref(rng)); + auto shape_rng = + std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng)); + + PadTester() + .InputPrePaddings({pad_rng(), 0}) + .InputPostPaddings({pad_rng(), 0}) + .InputShape({shape_rng(), shape_rng()}) + .Test(xnnpack_delegate.get()); +} + +TEST(Pad, Channels2D) { + std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)> + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto pad_rng = + std::bind(std::uniform_int_distribution<int32_t>(1, 3), std::ref(rng)); + auto shape_rng = + std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng)); + + PadTester() + .InputPrePaddings({0, pad_rng()}) + .InputPostPaddings({0, pad_rng()}) + .InputShape({shape_rng(), shape_rng()}) + .Test(xnnpack_delegate.get()); +} + +TEST(Pad, 1D) { + std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)> + xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto pad_rng = + std::bind(std::uniform_int_distribution<int32_t>(1, 3), std::ref(rng)); + auto shape_rng = + std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng)); + + PadTester() + .InputPrePaddings({pad_rng(), pad_rng()}) + .InputPostPaddings({pad_rng(), pad_rng()}) + .InputShape({shape_rng(), shape_rng()}) + .Test(xnnpack_delegate.get()); +} + +TEST(Pad, MultiThreading) { + TfLiteXNNPackDelegateOptions delegate_options = + TfLiteXNNPackDelegateOptionsDefault(); + delegate_options.num_threads = 2; + std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)> + xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options), + TfLiteXNNPackDelegateDelete); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto pad_rng = + std::bind(std::uniform_int_distribution<int32_t>(1, 3), std::ref(rng)); + auto shape_rng = + std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng)); + + PadTester() + .InputPrePaddings({0, 0, 0, pad_rng()}) + .InputPostPaddings({0, 0, 0, pad_rng()}) + .InputShape({shape_rng(), shape_rng(), shape_rng(), shape_rng()}) + .Test(xnnpack_delegate.get()); +} + +} // namespace xnnpack +} // namespace tflite diff --git a/tensorflow/lite/delegates/xnnpack/pad_tester.cc b/tensorflow/lite/delegates/xnnpack/pad_tester.cc new file mode 100644 index 00000000000..e364b880124 --- /dev/null +++ b/tensorflow/lite/delegates/xnnpack/pad_tester.cc @@ -0,0 +1,187 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/delegates/xnnpack/pad_tester.h" + +#include <array> +#include <cstdint> +#include <functional> +#include <numeric> +#include <random> +#include <vector> + +#include <gtest/gtest.h> +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/version.h" + +namespace tflite { +namespace xnnpack { + +std::vector<int32_t> PadTester::OutputShape() const { + std::vector<int32_t> output_shape; + output_shape.reserve(InputShape().size()); + for (size_t i = 0; i < InputShape().size(); i++) { + int32_t output_dim = InputShape()[i]; + if (i < InputPrePaddings().size()) { + output_dim += InputPrePaddings()[i]; + } + if (i < InputPostPaddings().size()) { + output_dim += InputPostPaddings()[i]; + } + output_shape.push_back(output_dim); + } + return output_shape; +} + +void PadTester::Test(TfLiteDelegate* delegate) const { + ASSERT_EQ(InputPrePaddings().size(), InputPostPaddings().size()); + ASSERT_LE(InputPrePaddings().size(), InputShape().size()); + + std::random_device random_device; + auto rng = std::mt19937(random_device()); + auto input_rng = + std::bind(std::uniform_real_distribution<float>(), std::ref(rng)); + + std::vector<char> buffer = CreateTfLiteModel(); + const Model* model = GetModel(buffer.data()); + + std::unique_ptr<Interpreter> delegate_interpreter; + ASSERT_EQ( + InterpreterBuilder(model, ::tflite::ops::builtin::BuiltinOpResolver())( + &delegate_interpreter), + kTfLiteOk); + std::unique_ptr<Interpreter> default_interpreter; + ASSERT_EQ( + InterpreterBuilder(model, ::tflite::ops::builtin::BuiltinOpResolver())( + &default_interpreter), + kTfLiteOk); + + ASSERT_TRUE(delegate_interpreter); + ASSERT_TRUE(default_interpreter); + + ASSERT_EQ(delegate_interpreter->inputs().size(), 1); + ASSERT_EQ(default_interpreter->inputs().size(), 1); + + ASSERT_EQ(delegate_interpreter->outputs().size(), 1); + ASSERT_EQ(default_interpreter->outputs().size(), 1); + + ASSERT_EQ(delegate_interpreter->AllocateTensors(), kTfLiteOk); + ASSERT_EQ(default_interpreter->AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(delegate_interpreter->ModifyGraphWithDelegate(delegate), kTfLiteOk); + + float* default_input_data = default_interpreter->typed_tensor<float>( + default_interpreter->inputs()[0]); + std::generate(default_input_data, + default_input_data + ComputeSize(InputShape()), + std::ref(input_rng)); + + float* delegate_input_data = delegate_interpreter->typed_tensor<float>( + delegate_interpreter->inputs()[0]); + std::copy(default_input_data, default_input_data + ComputeSize(InputShape()), + delegate_input_data); + + ASSERT_EQ(default_interpreter->Invoke(), kTfLiteOk); + ASSERT_EQ(delegate_interpreter->Invoke(), kTfLiteOk); + + float* default_output_data = default_interpreter->typed_tensor<float>( + default_interpreter->outputs()[0]); + float* delegate_output_data = delegate_interpreter->typed_tensor<float>( + delegate_interpreter->outputs()[0]); + + for (size_t i = 0; i < ComputeSize(OutputShape()); i++) { + ASSERT_EQ(default_output_data[i], delegate_output_data[i]); + } +} + +std::vector<char> PadTester::CreateTfLiteModel() const { + flatbuffers::FlatBufferBuilder builder; + flatbuffers::Offset<OperatorCode> operator_code = + CreateOperatorCode(builder, BuiltinOperator_PAD); + + std::vector<int32_t> paddings(InputPrePaddings().size() + + InputPostPaddings().size()); + for (size_t i = 0; i < InputPrePaddings().size(); i++) { + paddings[i * 2] = InputPrePaddings()[i]; + paddings[i * 2 + 1] = InputPostPaddings()[i]; + } + const std::array<flatbuffers::Offset<Buffer>, 2> buffers{{ + CreateBuffer(builder, builder.CreateVector({})), + CreateBuffer(builder, + builder.CreateVector( + reinterpret_cast<const uint8_t*>(paddings.data()), + sizeof(float) * paddings.size())), + }}; + + const std::vector<int32_t> output_shape = OutputShape(); + const std::array<int32_t, 2> paddings_shape{ + {static_cast<int32_t>(InputPrePaddings().size()), 2}}; + const std::array<flatbuffers::Offset<Tensor>, 3> tensors{{ + CreateTensor(builder, + builder.CreateVector<int32_t>(InputShape().data(), + InputShape().size()), + TensorType_FLOAT32), + CreateTensor(builder, + builder.CreateVector<int32_t>(paddings_shape.data(), + paddings_shape.size()), + TensorType_INT32, /*buffer=*/1), + CreateTensor(builder, + builder.CreateVector<int32_t>(output_shape.data(), + output_shape.size()), + TensorType_FLOAT32), + }}; + + const std::array<int32_t, 2> op_inputs{{0, 1}}; + const std::array<int32_t, 1> op_outputs{{2}}; + flatbuffers::Offset<Operator> op = CreateOperator( + builder, /*opcode_index=*/0, + builder.CreateVector<int32_t>(op_inputs.data(), op_inputs.size()), + builder.CreateVector<int32_t>(op_outputs.data(), op_outputs.size())); + + const std::array<int32_t, 1> subgraph_inputs{{0}}; + const std::array<int32_t, 1> subgraph_outputs{{2}}; + flatbuffers::Offset<SubGraph> subgraph = CreateSubGraph( + builder, builder.CreateVector(tensors.data(), tensors.size()), + builder.CreateVector<int32_t>(subgraph_inputs.data(), + subgraph_inputs.size()), + builder.CreateVector<int32_t>(subgraph_outputs.data(), + subgraph_outputs.size()), + builder.CreateVector(&op, 1)); + + flatbuffers::Offset<flatbuffers::String> description = + builder.CreateString("Pad model"); + + flatbuffers::Offset<Model> model_buffer = CreateModel( + builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1), + builder.CreateVector(&subgraph, 1), description, + builder.CreateVector(buffers.data(), buffers.size())); + + builder.Finish(model_buffer); + + return std::vector<char>(builder.GetBufferPointer(), + builder.GetBufferPointer() + builder.GetSize()); +} + +int32_t PadTester::ComputeSize(const std::vector<int32_t>& shape) { + return std::accumulate(shape.cbegin(), shape.cend(), 1, + std::multiplies<int32_t>()); +} + +} // namespace xnnpack +} // namespace tflite diff --git a/tensorflow/lite/delegates/xnnpack/pad_tester.h b/tensorflow/lite/delegates/xnnpack/pad_tester.h new file mode 100644 index 00000000000..a6951fdf156 --- /dev/null +++ b/tensorflow/lite/delegates/xnnpack/pad_tester.h @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_PAD_TESTER_H_ +#define TENSORFLOW_LITE_DELEGATES_XNNPACK_PAD_TESTER_H_ + +#include <cstdint> +#include <vector> + +#include <gtest/gtest.h> +#include "tensorflow/lite/c/common.h" + +namespace tflite { +namespace xnnpack { + +class PadTester { + public: + PadTester() = default; + PadTester(const PadTester&) = delete; + PadTester& operator=(const PadTester&) = delete; + + inline PadTester& InputShape(std::initializer_list<int32_t> shape) { + for (auto it = shape.begin(); it != shape.end(); ++it) { + EXPECT_GT(*it, 0); + } + input_shape_ = std::vector<int32_t>(shape.begin(), shape.end()); + return *this; + } + + inline const std::vector<int32_t>& InputShape() const { return input_shape_; } + + inline PadTester& InputPrePaddings(std::initializer_list<int32_t> paddings) { + for (auto it = paddings.begin(); it != paddings.end(); ++it) { + EXPECT_GE(*it, 0); + } + input_pre_paddings_ = + std::vector<int32_t>(paddings.begin(), paddings.end()); + return *this; + } + + inline const std::vector<int32_t> InputPrePaddings() const { + return input_pre_paddings_; + } + + inline PadTester& InputPostPaddings(std::initializer_list<int32_t> paddings) { + for (auto it = paddings.begin(); it != paddings.end(); ++it) { + EXPECT_GE(*it, 0); + } + input_post_paddings_ = + std::vector<int32_t>(paddings.begin(), paddings.end()); + return *this; + } + + inline const std::vector<int32_t> InputPostPaddings() const { + return input_post_paddings_; + } + + std::vector<int32_t> OutputShape() const; + + void Test(TfLiteDelegate* delegate) const; + + private: + std::vector<char> CreateTfLiteModel() const; + + static int32_t ComputeSize(const std::vector<int32_t>& shape); + + std::vector<int32_t> input_shape_; + std::vector<int32_t> input_pre_paddings_; + std::vector<int32_t> input_post_paddings_; +}; + +} // namespace xnnpack +} // namespace tflite + +#endif // TENSORFLOW_LITE_DELEGATES_XNNPACK_PAD_TESTER_H_ diff --git a/tensorflow/lite/delegates/xnnpack/pool_2d_tester.h b/tensorflow/lite/delegates/xnnpack/pool_2d_tester.h index 3125e9231f6..a84be10ad45 100644 --- a/tensorflow/lite/delegates/xnnpack/pool_2d_tester.h +++ b/tensorflow/lite/delegates/xnnpack/pool_2d_tester.h @@ -17,17 +17,11 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_XNNPACK_POOL_2D_TESTER_H_ #include <cstdint> -#include <functional> -#include <random> #include <vector> #include <gtest/gtest.h> -#include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/model.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/version.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/softmax_tester.h b/tensorflow/lite/delegates/xnnpack/softmax_tester.h index 9f930a6f21e..674dc9a443e 100644 --- a/tensorflow/lite/delegates/xnnpack/softmax_tester.h +++ b/tensorflow/lite/delegates/xnnpack/softmax_tester.h @@ -17,17 +17,11 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_XNNPACK_SOFTMAX_TESTER_H_ #include <cstdint> -#include <functional> -#include <random> #include <vector> #include <gtest/gtest.h> -#include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/model.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/version.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.h b/tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.h index 88508ccd1c1..e3c210fd6b3 100644 --- a/tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.h +++ b/tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.h @@ -17,17 +17,11 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_XNNPACK_UNARY_ELEMENTWISE_TESTER_H_ #include <cstdint> -#include <functional> -#include <random> #include <vector> #include <gtest/gtest.h> -#include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/model.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/version.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 6d9b4dac8f8..2beaa16255d 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #include <algorithm> +#include <array> #include <cstdint> #include <cstring> #include <limits> @@ -120,9 +121,22 @@ class Subgraph { return nullptr; } - for (int k = 0; k < node->inputs->size; k++) { - const int t = node->inputs->data[k]; - tensors[t] = t; + switch (registration->builtin_code) { + case kTfLiteBuiltinPad: + // Ignore the second input (static padding), because it is + // represented as parameters of the XNNPACK operator rather than + // extra input. + { + const int t = node->inputs->data[0]; + tensors[t] = t; + } + break; + default: + // All other operators: process all inputs + for (int k = 0; k < node->inputs->size; k++) { + const int t = node->inputs->data[k]; + tensors[t] = t; + } } for (int k = 0; k < node->outputs->size; k++) { const int t = node->outputs->data[k]; @@ -532,10 +546,11 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus CheckTensorFloatType(TfLiteContext* context, - const TfLiteTensor& tensor, - int tensor_index, int node_index) { - if (tensor.type != kTfLiteFloat32) { + static TfLiteStatus CheckTensorType(TfLiteContext* context, + const TfLiteTensor& tensor, + TfLiteType expected_type, + int tensor_index, int node_index) { + if (tensor.type != expected_type) { TF_LITE_MAYBE_KERNEL_LOG( context, "unsupported type %s in tensor #%d in node #%d", TfLiteTypeGetName(tensor.type), tensor_index, node_index); @@ -544,28 +559,64 @@ class Subgraph { return kTfLiteOk; } + static TfLiteStatus CheckTensorFloatType(TfLiteContext* context, + const TfLiteTensor& tensor, + int tensor_index, int node_index) { + return CheckTensorType(context, tensor, kTfLiteFloat32, tensor_index, + node_index); + } + static TfLiteStatus CheckTensorShape(TfLiteContext* context, const TfLiteTensor& tensor, - int expected_num_dims, + int min_num_dims, int max_num_dims, int tensor_index) { - if (tensor.dims->size != expected_num_dims) { - TF_LITE_MAYBE_KERNEL_LOG( - context, - "unexpected number of shape dimensions (%d != %d) in tensor #%d", - tensor.dims->size, expected_num_dims, tensor_index); - return kTfLiteError; + if (min_num_dims == max_num_dims) { + if (tensor.dims->size != min_num_dims) { + TF_LITE_MAYBE_KERNEL_LOG( + context, + "unsupported number of shape dimensions (%d) in tensor #%d: " + "%d dimensions expected", + tensor.dims->size, tensor_index, min_num_dims); + return kTfLiteError; + } + } else { + if (tensor.dims->size < min_num_dims) { + TF_LITE_MAYBE_KERNEL_LOG( + context, + "unsupported number of shape dimensions (%d) in tensor #%d: " + "at least %d dimensions expected", + tensor.dims->size, tensor_index, min_num_dims); + return kTfLiteError; + } + if (tensor.dims->size > max_num_dims) { + TF_LITE_MAYBE_KERNEL_LOG( + context, + "unsupported number of shape dimensions (%d) in tensor #%d: " + "at most %d dimensions expected", + tensor.dims->size, tensor_index, max_num_dims); + return kTfLiteError; + } } for (int i = 0; i < tensor.dims->size; i++) { if (tensor.dims->data[i] <= 0) { TF_LITE_MAYBE_KERNEL_LOG(context, - "invalid dimension #%d (%d) in tensor #%d", i, - tensor.dims->data[i], tensor_index); + "invalid num of elements (%d) in " + "dimension #%d in tensor #%d", + tensor.dims->data[i], i, tensor_index); return kTfLiteError; } } return kTfLiteOk; } + static TfLiteStatus CheckTensorShape(TfLiteContext* context, + const TfLiteTensor& tensor, + int expected_num_dims, + int tensor_index) { + return CheckTensorShape(context, tensor, expected_num_dims, + expected_num_dims, tensor_index); + } + static TfLiteStatus CheckSlopeTensorShape(TfLiteContext* context, const TfLiteTensor& tensor, int tensor_index, int node_index) { @@ -592,6 +643,39 @@ class Subgraph { return kTfLiteOk; } + static TfLiteStatus CheckPaddingsTensorShape(TfLiteContext* context, + const TfLiteTensor& tensor, + int expected_rows, + int tensor_index, + int node_index) { + if (tensor.dims->size != 2) { + TF_LITE_MAYBE_KERNEL_LOG(context, + "unexpected number of shape dimensions (%d) in " + "padding tensor #%d in node #%d: " + "expected a 2D tensor", + tensor.dims->size, tensor_index, node_index); + return kTfLiteError; + } + if (tensor.dims->data[0] != expected_rows) { + TF_LITE_MAYBE_KERNEL_LOG(context, + "unexpected number of rows (%d) in " + "padding tensor #%d in node #%d: " + "%d rows expected", + tensor.dims->size, tensor_index, node_index, + expected_rows); + return kTfLiteError; + } + if (tensor.dims->data[1] != 2) { + TF_LITE_MAYBE_KERNEL_LOG(context, + "unexpected number of columns (%d) in " + "padding tensor #%d in node #%d: " + "2 columns expected", + tensor.dims->size, tensor_index, node_index); + return kTfLiteError; + } + return kTfLiteOk; + } + static TfLiteStatus CheckTensorNonDynamicAllocation( TfLiteContext* context, const TfLiteTensor& tensor, int tensor_index, int node_index) { @@ -693,6 +777,9 @@ class Subgraph { return VisitMulNode(subgraph, logging_context, node_index, node, context->tensors, mul_params, xnnpack_tensors); } + case kTfLiteBuiltinPad: + return VisitPadNode(subgraph, logging_context, node_index, node, + context->tensors, xnnpack_tensors); case kTfLiteBuiltinPrelu: return VisitPreluNode(subgraph, logging_context, node_index, node, context->tensors, xnnpack_tensors); @@ -1565,6 +1652,86 @@ class Subgraph { return kTfLiteOk; } + static TfLiteStatus VisitPadNode( + xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index, + TfLiteNode* node, const TfLiteTensor* tensors, + const std::vector<uint32_t>& xnnpack_tensors) { + TF_LITE_ENSURE_STATUS( + CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index)); + + const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; + TF_LITE_ENSURE_STATUS(CheckTensorFloatType( + logging_context, input_tensor, node->inputs->data[0], node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 1, + XNN_MAX_TENSOR_DIMS, + node->inputs->data[0])); + TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( + logging_context, input_tensor, node->inputs->data[0], node_index)); + + const TfLiteTensor& paddings_tensor = tensors[node->inputs->data[1]]; + TF_LITE_ENSURE_STATUS(CheckTensorType(logging_context, paddings_tensor, + kTfLiteInt32, node->inputs->data[1], + node_index)); + TF_LITE_ENSURE_STATUS(CheckPaddingsTensorShape( + logging_context, paddings_tensor, input_tensor.dims->size, + node->inputs->data[1], node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( + logging_context, paddings_tensor, node->inputs->data[1], node_index)); + + const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; + TF_LITE_ENSURE_STATUS(CheckTensorFloatType( + logging_context, output_tensor, node->outputs->data[0], node_index)); + TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 1, + XNN_MAX_TENSOR_DIMS, + node->outputs->data[0])); + TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( + logging_context, output_tensor, node->outputs->data[0], node_index)); + + const int32_t* paddings_data = + reinterpret_cast<const int32_t*>(paddings_tensor.data.data); + for (int i = 0; i < paddings_tensor.dims->size; i++) { + const int32_t pre_padding = paddings_data[i * 2 + 0]; + if (pre_padding < 0) { + TF_LITE_MAYBE_KERNEL_LOG( + logging_context, + "invalid pre-padding %d for dimension #%d in node %d", pre_padding, + i, node_index); + return kTfLiteError; + } + + const int32_t post_padding = paddings_data[i * 2 + 1]; + if (post_padding < 0) { + TF_LITE_MAYBE_KERNEL_LOG( + logging_context, + "invalid post-padding %d for dimension #%d in node %d", pre_padding, + i, node_index); + return kTfLiteError; + } + } + + if (subgraph != nullptr) { + std::array<size_t, XNN_MAX_TENSOR_DIMS> pre_paddings{}; + std::array<size_t, XNN_MAX_TENSOR_DIMS> post_paddings{}; + for (int i = 0; i < paddings_tensor.dims->data[0]; i++) { + pre_paddings[i] = static_cast<size_t>(paddings_data[i * 2 + 0]); + post_paddings[i] = static_cast<size_t>(paddings_data[i * 2 + 1]); + } + + const xnn_status status = xnn_define_static_constant_pad( + subgraph, pre_paddings.data(), post_paddings.data(), + /*padding_value=*/0.0f, + /*input_id=*/xnnpack_tensors[node->inputs->data[0]], + /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0); + if (status != xnn_status_success) { + TF_LITE_KERNEL_LOG(logging_context, "failed to delegate PAD node #%d", + node_index); + return kTfLiteError; + } + } + + return kTfLiteOk; + } + static TfLiteStatus VisitPreluNode( xnn_subgraph_t subgraph, TfLiteContext* logging_context, int node_index, TfLiteNode* node, const TfLiteTensor* tensors, diff --git a/tensorflow/lite/examples/label_image/README.md b/tensorflow/lite/examples/label_image/README.md index 09e9e77b86a..b283e169359 100644 --- a/tensorflow/lite/examples/label_image/README.md +++ b/tensorflow/lite/examples/label_image/README.md @@ -111,27 +111,41 @@ uniform 0.0379589: 907 Windsor tie 0.00735866: 466 bulletproof vest 0.00605307: To run a model with the Hexagon Delegate, assuming we have followed the [Hexagon Delegate Guide](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/hexagon_delegate.md) -and installed Hexagon libraries in `/data/local/tmp`. Run it `adb shell -"/data/local/tmp/label_image \ -m +and installed Hexagon libraries in `/data/local/tmp`. Run it wth (`-j 1`) `adb +shell \ "/data/local/tmp/label_image \ -m /data/local/tmp/mobilenet_v1_1.0_224_quant.tflite \ -i /data/local/tmp/grace_hopper.bmp \ -l /data/local/tmp/labels.txt -j 1"` then you should see something like the followings: ``` Loaded model /data/local/tmp/mobilenet_v1_1.0_224_quant.tflite resolved reporter INFO: -Initialized TensorFlow Lite runtime. INFO: Created TensorFlow Lite delegate for -Hexagon. INFO: Hexagon delegate: 31 nodes delegated out of 31 nodes. +Initialized TensorFlow Lite runtime. loaded libcdsprpc.so INFO: Created +TensorFlow Lite delegate for Hexagon. INFO: Hexagon delegate: 31 nodes delegated +out of 31 nodes with 1 partitions. -remote_handle_control available and used Applied Hexagon delegate.invoked -average time: 8.307 ms 0.729412: 653 military uniform 0.0980392: 907 Windsor tie -0.0313726: 466 bulletproof vest 0.0313726: 458 bow tie 0.0117647: 700 panpipe -``` +Applied Hexagon delegate.invoked average time: 4.231 ms 0.639216: 458 bow tie +0.329412: 653 military uniform 0.00784314: 835 suit 0.00784314: 611 jersey +0.00392157: 514 cornet ``` -Run the model with the XNNPACK delegate (`-x 1`), `adb shell +Run the model with the XNNPACK delegate (`-x 1`), `adb shell \ "/data/local/tmp/label_image \ -m /data/local/tmp/mobilenet_v1_1.0_224.tflite \ -i /data/local/tmp/grace_hopper.bmp \ -l /data/local/tmp/labels.txt -x 1"` then you should see something like the followings: `Loaded model /data/local/tmp/mobilenet_v1_1.0_224.tflite resolved reporter INFO: Initialized -TensorFlow Lite runtime. Applied XNNPACK delegate.invoked average time: 11.0237 -ms 0.90707: 653 military uniform 0.0372418: 907 Windsor tie 0.0073376: 466 -bulletproof vest 0.00592856: 458 bow tie 0.00414093: 514 cornet` +TensorFlow Lite runtime. Applied XNNPACK delegate.invoked average time: 17.33 ms +0.90707: 653 military uniform 0.0372418: 907 Windsor tie 0.0073376: 466 +bulletproof vest 0.00592857: 458 bow tie 0.00414093: 514 cornet` + +With `-h` or any other unsupported flags, `label_image` will list supported +options `sargo:/data/local/tmp $ ./label_image -h ./label_image: invalid +option -- h label_image --accelerated, -a: [0|1], use Android NNAPI or not +--old_accelerated, -d: [0|1], use old Android NNAPI delegate or not +--allow_fp16, -f: [0|1], allow running fp32 models with fp16 or not --count, -c: +loop interpreter->Invoke() for certain times --gl_backend, -g: [0|1]: use GL GPU +Delegate on Android --hexagon_delegate, -j: [0|1]: use Hexagon Delegate on +Android --input_mean, -b: input mean --input_std, -s: input standard deviation +--image, -i: image_name.bmp --labels, -l: labels for the model --tflite_model, +-m: model_name.tflite --profiling, -p: [0|1], profiling or not --num_results, +-r: number of results to show --threads, -t: number of threads --verbose, -v: +[0|1] print more information --warmup_runs, -w: number of warmup runs +--xnnpack_delegate, -x [0:1]: xnnpack delegate` See the `label_image.cc` source code for other command line options. diff --git a/tensorflow/lite/examples/label_image/label_image.cc b/tensorflow/lite/examples/label_image/label_image.cc index ec744d70381..364ac325967 100644 --- a/tensorflow/lite/examples/label_image/label_image.cc +++ b/tensorflow/lite/examples/label_image/label_image.cc @@ -362,8 +362,8 @@ void display_usage() { << "--old_accelerated, -d: [0|1], use old Android NNAPI delegate or not\n" << "--allow_fp16, -f: [0|1], allow running fp32 models with fp16 or not\n" << "--count, -c: loop interpreter->Invoke() for certain times\n" - << "--gl_backend, -g: use GL GPU Delegate on Android\n" - << "--hexagon_delegate: use Hexagon Delegate on Android\n" + << "--gl_backend, -g: [0|1]: use GL GPU Delegate on Android\n" + << "--hexagon_delegate, -j: [0|1]: use Hexagon Delegate on Android\n" << "--input_mean, -b: input mean\n" << "--input_std, -s: input standard deviation\n" << "--image, -i: image_name.bmp\n" @@ -374,7 +374,7 @@ void display_usage() { << "--threads, -t: number of threads\n" << "--verbose, -v: [0|1] print more information\n" << "--warmup_runs, -w: number of warmup runs\n" - << "--xnnpack_delegate, -x: xnnpack delegate\n" + << "--xnnpack_delegate, -x [0:1]: xnnpack delegate\n" << "\n"; } diff --git a/tensorflow/lite/experimental/delegates/coreml/BUILD b/tensorflow/lite/experimental/delegates/coreml/BUILD index c04aba65aa0..193f2e0223b 100644 --- a/tensorflow/lite/experimental/delegates/coreml/BUILD +++ b/tensorflow/lite/experimental/delegates/coreml/BUILD @@ -56,6 +56,7 @@ objc_library( "//tensorflow/lite:kernel_api", "//tensorflow/lite:minimal_logging", "//tensorflow/lite/c:common", + "//tensorflow/lite/delegates:utils", "//tensorflow/lite/experimental/delegates/coreml/builders:op_builder", ], ) diff --git a/tensorflow/lite/experimental/delegates/coreml/builders/op_builder.cc b/tensorflow/lite/experimental/delegates/coreml/builders/op_builder.cc index 2581b58f1e4..46634d6970a 100644 --- a/tensorflow/lite/experimental/delegates/coreml/builders/op_builder.cc +++ b/tensorflow/lite/experimental/delegates/coreml/builders/op_builder.cc @@ -95,6 +95,7 @@ CoreML::Specification::Model* GraphBuilder::BuildModel() { CoreML::Specification::EXACT_ARRAY_MAPPING); } else { fprintf(stderr, "Unsupported Core ML version: %d\n", coreml_version_); + delete model; return nullptr; } auto* neural_network = model->mutable_neuralnetwork(); diff --git a/tensorflow/lite/experimental/delegates/hexagon/README.md b/tensorflow/lite/experimental/delegates/hexagon/README.md index 6e627c17cd2..106ddce038b 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/README.md +++ b/tensorflow/lite/experimental/delegates/hexagon/README.md @@ -86,6 +86,7 @@ are verified in `IsNodeSupportedByHexagon`: * MirrorPad * Mul (without any activation) (b/129276536) * Neg +* Pack * Pad: Only supports 0 padding (b/139277813) * Quantize (8-bit inputs & outputs only) * Relu @@ -95,6 +96,7 @@ are verified in `IsNodeSupportedByHexagon`: * Constraints: - Requested size <= 65 (b/143105433) * Resize Nearest Neighbor +* Slice * SoftMax * SpaceToDepth * Split diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD b/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD index e24adc2537c..feadd096c54 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/BUILD @@ -23,6 +23,7 @@ cc_library( "mirror_pad_builder.cc", "neg_op_builder.cc", "op_builder.cc", + "pack_builder.cc", "pad_builder.cc", "pool_2d_builder.cc", "quantize_builder.cc", @@ -30,6 +31,7 @@ cc_library( "reshape_builder.cc", "resize_bilinear_builder.cc", "resize_nearest_neighbor_builder.cc", + "slice_builder.cc", "softmax_builder.cc", "space_to_depth_builder.cc", "split_builder.cc", @@ -51,6 +53,7 @@ cc_library( "mirror_pad_builder.h", "neg_op_builder.h", "op_builder.h", + "pack_builder.h", "pad_builder.h", "pool_2d_builder.h", "quantize_builder.h", @@ -58,6 +61,7 @@ cc_library( "reshape_builder.h", "resize_bilinear_builder.h", "resize_nearest_neighbor_builder.h", + "slice_builder.h", "softmax_builder.h", "space_to_depth_builder.h", "split_builder.h", @@ -78,6 +82,7 @@ cc_library( "//tensorflow/lite/kernels:kernel_util", "//tensorflow/lite/kernels:padding", "//tensorflow/lite/kernels/internal:optimized_base", + "//tensorflow/lite/kernels/internal:tensor", "@hexagon_nn//:hexagon_nn_ops", ], ) diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/matmul_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/matmul_builder.cc index c53e62d27a7..894f98269ce 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/matmul_builder.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/matmul_builder.cc @@ -18,7 +18,9 @@ limitations under the License. #include <limits> +#include "hexagon/hexagon_nn_ops.h" #include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/experimental/delegates/hexagon/hexagon_nn/hexagon_nn.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -27,9 +29,124 @@ namespace tflite { namespace delegates { namespace hexagon { namespace { +void GetDims(int* batch_size, int* height_size, int* width_size, + int* depth_size, const TfLiteIntArray* dims) { + int* dim[] = {batch_size, height_size, width_size, depth_size}; + for (int i = 0; i < 4; ++i) *(dim[i]) = 1; + for (int i = 4 - dims->size; i < 4; ++i) { + *dim[i] = dims->data[i - (4 - dims->size)]; + } +} constexpr uint8_t k8BitSignFlipConstant = 0x80; +TfLiteStatus AddFullyConnectedHelper(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + const OpBuilder::TensorID weights_id, + const OpBuilder::TensorID weights_min_id, + const OpBuilder::TensorID weights_max_id, + GraphBuilder* graph_builder, + TfLiteContext* context, + OpBuilder* matmul_op, + OpBuilder::TensorID* node_output) { + static int scalar_shape[] = {1, 1, 1, 1}; + // Data tensor. + int data_tensor_id = inputs->data[0]; + const auto& data_tensor = context->tensors[data_tensor_id]; + float data_min, data_max; + TF_LITE_ENSURE_STATUS(OpBuilder::ComputeMinAndMaxQuantValues( + data_tensor, &data_min, &data_max)); + auto* data_min_const = graph_builder->AddConstNodeWithData( + scalar_shape, reinterpret_cast<char*>(&data_min), sizeof(data_min)); + auto* data_max_const = graph_builder->AddConstNodeWithData( + scalar_shape, reinterpret_cast<char*>(&data_max), sizeof(data_max)); + + // Data and weight tensors in required order. + matmul_op->AddInput(graph_builder->GetHexagonTensorId(data_tensor_id)); + matmul_op->AddInput(weights_id); + matmul_op->AddInput(OpBuilder::TensorID(data_min_const->GetID(), 0)); + matmul_op->AddInput(OpBuilder::TensorID(data_max_const->GetID(), 0)); + matmul_op->AddInput(weights_min_id); + matmul_op->AddInput(weights_max_id); + + // Outputs for the MatMul node, which are in int32 format. + // Output shape should still be the same. + int output_batch_size, output_height_size, output_width_size, + output_depth_size; + GetDims(&output_batch_size, &output_height_size, &output_width_size, + &output_depth_size, context->tensors[outputs->data[0]].dims); + const auto& matmul_out = + matmul_op->AddOutput(sizeof(int32_t), 4, + {output_batch_size, output_height_size, + output_width_size, output_depth_size}); + const auto& matmul_out_min = + matmul_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + const auto& matmul_out_max = + matmul_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + + // Bias tensor. + int bias_tensor_id = inputs->data[2]; + OpBuilder::TensorID matmul_and_bias_out = matmul_out, + matmul_and_bias_out_min = matmul_out_min, + matmul_and_bias_out_max = matmul_out_max; + if (bias_tensor_id != -1) { + const auto& bias_tensor = context->tensors[bias_tensor_id]; + auto* const_bias_node = + graph_builder->AddConstNodeWithData(bias_tensor_id, bias_tensor); + float bias_min, bias_max; + graph_builder->AddTensorWithID(bias_tensor_id, const_bias_node->GetID(), 0); + OpBuilder::ComputeMinAndMaxQuantValues(bias_tensor, &bias_min, &bias_max); + auto* bias_min_const = graph_builder->AddConstNodeWithData( + scalar_shape, reinterpret_cast<char*>(&bias_min), sizeof(bias_min)); + auto* bias_max_const = graph_builder->AddConstNodeWithData( + scalar_shape, reinterpret_cast<char*>(&bias_max), sizeof(bias_max)); + + // MatMul + Bias. + auto* bias_add_op = graph_builder->AddNode(matmul_op->GetTFLiteNodeID()); + bias_add_op->SetOpType(OP_QuantizedBiasAdd_32p32to32); + bias_add_op->AddInput(matmul_out); + bias_add_op->AddInput(graph_builder->GetHexagonTensorId(bias_tensor_id)); + bias_add_op->AddInput(matmul_out_min); + bias_add_op->AddInput(matmul_out_max); + bias_add_op->AddInput(OpBuilder::TensorID(bias_min_const->GetID(), 0)); + bias_add_op->AddInput(OpBuilder::TensorID(bias_max_const->GetID(), 0)); + matmul_and_bias_out = + bias_add_op->AddOutput(sizeof(int32_t), 4, + {output_batch_size, output_height_size, + output_width_size, output_depth_size}); + matmul_and_bias_out_min = + bias_add_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + matmul_and_bias_out_max = + bias_add_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + } + + float output_min, output_max; + // Quantize 32-bit result into 8-bit format using output tensor min/max. + OpBuilder::ComputeMinAndMaxQuantValues(context->tensors[outputs->data[0]], + &output_min, &output_max); + auto* output_min_const = graph_builder->AddConstNodeWithData( + scalar_shape, reinterpret_cast<char*>(&output_min), sizeof(output_min)); + auto* output_max_const = graph_builder->AddConstNodeWithData( + scalar_shape, reinterpret_cast<char*>(&output_max), sizeof(output_max)); + auto* quantize_biasadd_op = + graph_builder->AddNode(matmul_op->GetTFLiteNodeID()); + quantize_biasadd_op->SetOpType(OP_Requantize_32to8); + quantize_biasadd_op->AddInput(matmul_and_bias_out); + quantize_biasadd_op->AddInput(matmul_and_bias_out_min); + quantize_biasadd_op->AddInput(matmul_and_bias_out_max); + quantize_biasadd_op->AddInput( + OpBuilder::TensorID(output_min_const->GetID(), 0)); + quantize_biasadd_op->AddInput( + OpBuilder::TensorID(output_max_const->GetID(), 0)); + *node_output = + quantize_biasadd_op->AddOutput(sizeof(uint8_t), 4, + {output_batch_size, output_height_size, + output_width_size, output_depth_size}); + quantize_biasadd_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + quantize_biasadd_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + return kTfLiteOk; +} + } // namespace // The TFLite 'Fully-connected' quantized op corresponds to the following @@ -38,27 +155,14 @@ constexpr uint8_t k8BitSignFlipConstant = 0x80; // MatMul out (int32), Bias (int32) => QuantizedBiasAdd => BiasAdd out (int32) // BiasAdd out (int32) => Requantize_32to8 => Output (8-bit) // TODO(b/129276536): Add activation support. -TfLiteStatus MatMulOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, - const TfLiteIntArray* outputs, - TfLiteContext* context) { +TfLiteStatus MatMulWithConstWeightsOpBuilder::PopulateSubGraph( + const TfLiteIntArray* inputs, const TfLiteIntArray* outputs, + TfLiteContext* context) { static int quant_bound_shape[] = {1, 1, 1, 1}; - // Data tensor. - int data_tensor_id = inputs->data[0]; - const auto& data_tensor = context->tensors[data_tensor_id]; - TF_LITE_ENSURE_STATUS( - ComputeMinAndMaxQuantValues(data_tensor, &data_min_, &data_max_)); - auto* data_min_const = graph_builder_->AddConstNodeWithData( - quant_bound_shape, reinterpret_cast<char*>(&data_min_), - sizeof(data_min_)); - auto* data_max_const = graph_builder_->AddConstNodeWithData( - quant_bound_shape, reinterpret_cast<char*>(&data_max_), - sizeof(data_max_)); - // Weights vector. int weights_tensor_id = inputs->data[1]; const auto& weights_tensor = context->tensors[weights_tensor_id]; - // TODO(srjoglekar): Abstract out. if (weights_tensor.allocation_type != kTfLiteMmapRo) { context->ReportError( context, "Weights tensor doesn't have correct allocation type: %s", @@ -107,84 +211,74 @@ TfLiteStatus MatMulOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, quant_bound_shape, reinterpret_cast<char*>(&weights_max_), sizeof(weights_max_)); - // Data and weight tensors in required order. - AddInput(graph_builder_->GetHexagonTensorId(data_tensor_id)); + return AddFullyConnectedHelper( + inputs, outputs, graph_builder_->GetHexagonTensorId(weights_tensor_id), + TensorID(weights_min_const->GetID(), 0), + TensorID(weights_max_const->GetID(), 0), graph_builder_, context, this, + &node_output_); +} + +TfLiteStatus MatMulWithConstWeightsOpBuilder::RegisterOutputs( + const TfLiteIntArray* outputs, TfLiteContext* context) { + // Should be only 1 output. + graph_builder_->AddTensorWithID(outputs->data[0], node_output_.first, + node_output_.second); + return kTfLiteOk; +} + +TfLiteStatus MatMulOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + TfLiteContext* context) { + static int scalar_shape[] = {1, 1, 1, 1}; + const int weights_tensor_id = inputs->data[1]; + const auto& weights_tensor = context->tensors[weights_tensor_id]; + int batch_size, height_size, width_size, depth_size; + GetDims(&batch_size, &height_size, &width_size, &depth_size, + weights_tensor.dims); + weights_shape_ = {batch_size, height_size, depth_size, width_size}; + // Permutation for transposing. + int permutation[] = {0, 1, 3, 2}; + const int permutation_shape[] = {1, 1, 1, 4}; + auto permutation_node = graph_builder_->AddConstNodeWithData( + permutation_shape, reinterpret_cast<char*>(permutation), + 4 * sizeof(permutation[0])); AddInput(graph_builder_->GetHexagonTensorId(weights_tensor_id)); - AddInput(TensorID(data_min_const->GetID(), 0)); - AddInput(TensorID(data_max_const->GetID(), 0)); + AddInput(TensorID(permutation_node->GetID(), 0)); + + ComputeMinAndMaxQuantValues(weights_tensor, &weights_min_, &weights_max_); + auto* weights_min_const = graph_builder_->AddConstNodeWithData( + scalar_shape, reinterpret_cast<char*>(&weights_min_), + sizeof(weights_min_)); + auto* weights_max_const = graph_builder_->AddConstNodeWithData( + scalar_shape, reinterpret_cast<char*>(&weights_max_), + sizeof(weights_max_)); AddInput(TensorID(weights_min_const->GetID(), 0)); AddInput(TensorID(weights_max_const->GetID(), 0)); - // Outputs for the MatMul node, which are in int32 format. - // Output shape should still be the same. - int output_batch_size, output_height_size, output_width_size, - output_depth_size; - GetDims(&output_batch_size, &output_height_size, &output_width_size, - &output_depth_size, context->tensors[outputs->data[0]].dims); - const auto& matmul_out = AddOutput(sizeof(int32_t), 4, - {output_batch_size, output_height_size, - output_width_size, output_depth_size}); - const auto& matmul_out_min = AddOutput(sizeof(float), 4, {1, 1, 1, 1}); - const auto& matmul_out_max = AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + auto transposed_weights = AddOutput(sizeof(uint8_t), 4, weights_shape_); + auto transposed_weights_min = AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + auto transposed_weights_max = AddOutput(sizeof(float), 4, {1, 1, 1, 1}); - // Bias tensor. - int bias_tensor_id = inputs->data[2]; - const auto& bias_tensor = context->tensors[bias_tensor_id]; - auto* const_bias_node = - graph_builder_->AddConstNodeWithData(bias_tensor_id, bias_tensor); - graph_builder_->AddTensorWithID(bias_tensor_id, const_bias_node->GetID(), 0); - ComputeMinAndMaxQuantValues(bias_tensor, &bias_min_, &bias_max_); - auto* bias_min_const = graph_builder_->AddConstNodeWithData( - quant_bound_shape, reinterpret_cast<char*>(&bias_min_), - sizeof(bias_min_)); - auto* bias_max_const = graph_builder_->AddConstNodeWithData( - quant_bound_shape, reinterpret_cast<char*>(&bias_max_), - sizeof(bias_max_)); - - // MatMul + Bias. - auto* bias_add_op = graph_builder_->AddNode(GetTFLiteNodeID()); - bias_add_op->SetOpType(OP_QuantizedBiasAdd_32p32to32); - bias_add_op->AddInput(matmul_out); - bias_add_op->AddInput(graph_builder_->GetHexagonTensorId(bias_tensor_id)); - bias_add_op->AddInput(matmul_out_min); - bias_add_op->AddInput(matmul_out_max); - bias_add_op->AddInput(TensorID(bias_min_const->GetID(), 0)); - bias_add_op->AddInput(TensorID(bias_max_const->GetID(), 0)); - const auto& bias_add_out = - bias_add_op->AddOutput(sizeof(int32_t), 4, - {output_batch_size, output_height_size, - output_width_size, output_depth_size}); - const auto& bias_add_out_min = - bias_add_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); - const auto& bias_add_out_max = - bias_add_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); - - // Quantize 32-bit result into 8-bit format using output tensor min/max. - ComputeMinAndMaxQuantValues(context->tensors[outputs->data[0]], &output_min_, - &output_max_); - auto* output_min_const = graph_builder_->AddConstNodeWithData( - quant_bound_shape, reinterpret_cast<char*>(&output_min_), - sizeof(output_min_)); - auto* output_max_const = graph_builder_->AddConstNodeWithData( - quant_bound_shape, reinterpret_cast<char*>(&output_max_), - sizeof(output_max_)); - auto* quantize_biasadd_op = graph_builder_->AddNode(GetTFLiteNodeID()); - quantize_biasadd_op->SetOpType(OP_Requantize_32to8); - quantize_biasadd_op->AddInput(bias_add_out); - quantize_biasadd_op->AddInput(bias_add_out_min); - quantize_biasadd_op->AddInput(bias_add_out_max); - quantize_biasadd_op->AddInput(TensorID(output_min_const->GetID(), 0)); - quantize_biasadd_op->AddInput(TensorID(output_max_const->GetID(), 0)); - node_output_ = - quantize_biasadd_op->AddOutput(sizeof(uint8_t), 4, - {output_batch_size, output_height_size, - output_width_size, output_depth_size}); - quantize_biasadd_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); - quantize_biasadd_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + auto* matmul_op = graph_builder_->AddNode(GetTFLiteNodeID()); + matmul_op->SetOpType(OP_QuantizedMatMul_8x8to32); + AddFullyConnected(inputs, outputs, transposed_weights, transposed_weights_min, + transposed_weights_max, context, matmul_op); return kTfLiteOk; } +TfLiteStatus MatMulOpBuilder::AddFullyConnected(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + const TensorID weights_id, + const TensorID weights_min_id, + const TensorID weights_max_id, + TfLiteContext* context, + OpBuilder* matmul_op) { + return AddFullyConnectedHelper(inputs, outputs, weights_id, weights_min_id, + weights_max_id, graph_builder_, context, + matmul_op, &node_output_); +} + TfLiteStatus MatMulOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs, TfLiteContext* context) { // Should be only 1 output. @@ -193,9 +287,12 @@ TfLiteStatus MatMulOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs, return kTfLiteOk; } -MatMulOpBuilder::~MatMulOpBuilder() {} +OpBuilder* CreateMatMulWithConstWeightsOpBuilder(GraphBuilder* graph_builder, + int op_type) { + return new MatMulWithConstWeightsOpBuilder(graph_builder, op_type); +} -OpBuilder* CreateMatMulBuilder(GraphBuilder* graph_builder, int op_type) { +OpBuilder* CreateMatMulOpBuilder(GraphBuilder* graph_builder, int op_type) { return new MatMulOpBuilder(graph_builder, op_type); } diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/matmul_builder.h b/tensorflow/lite/experimental/delegates/hexagon/builders/matmul_builder.h index 212ea7be7a3..89f3c1273d7 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/matmul_builder.h +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/matmul_builder.h @@ -23,6 +23,28 @@ namespace tflite { namespace delegates { namespace hexagon { +// Builder for FullyConnected op in Hexagon with weights as const. +class MatMulWithConstWeightsOpBuilder : public OpBuilder { + public: + explicit MatMulWithConstWeightsOpBuilder(GraphBuilder* graph_builder, + int op_type) + : OpBuilder(graph_builder, op_type) {} + TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + private: + TensorID node_output_; + std::vector<int> weights_shape_, bias_shape_; + std::vector<float> transposed_weights_; + float data_min_, data_max_, weights_min_, weights_max_, bias_min_, bias_max_, + output_min_, output_max_; +}; + +// Builder for FullyConnected op in Hexagon with non const weights. class MatMulOpBuilder : public OpBuilder { public: explicit MatMulOpBuilder(GraphBuilder* graph_builder, int op_type) @@ -34,9 +56,15 @@ class MatMulOpBuilder : public OpBuilder { TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, TfLiteContext* context) override; - ~MatMulOpBuilder() override; - private: + // Adds Fully connected op related ops to the graph. + TfLiteStatus AddFullyConnected(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + const TensorID weights_id, + const TensorID weights_min_id, + const TensorID weights_max_id, + TfLiteContext* context, OpBuilder* matmul_op); + TensorID node_output_; std::vector<int> weights_shape_, bias_shape_; std::vector<float> transposed_weights_; diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc index 230a292b6fe..d851f8cf824 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.cc @@ -23,7 +23,8 @@ namespace tflite { namespace delegates { namespace hexagon { -OpBuilder* GraphBuilder::CreateOpBuilderFromTfLiteOp(int op_type) { +OpBuilder* GraphBuilder::CreateOpBuilderFromTfLiteOp(int op_type, + TfLiteNode* node) { switch (op_type) { case kTfLiteBuiltinAdd: return CreateArithmeticBuilder(this, OP_QuantizedAdd_8p8to8); @@ -45,8 +46,14 @@ OpBuilder* GraphBuilder::CreateOpBuilderFromTfLiteOp(int op_type) { return CreatePadBuilder(this, OP_QuantizedPad_8); case kTfLiteBuiltinMirrorPad: return CreateMirrorPadBuilder(this, OP_MirrorPad_8); - case kTfLiteBuiltinFullyConnected: - return CreateMatMulBuilder(this, OP_QuantizedMatMul_8x8to32); + case kTfLiteBuiltinFullyConnected: { + const auto& weights_tensor = context_->tensors[node->inputs->data[1]]; + if (weights_tensor.allocation_type == kTfLiteMmapRo) + return CreateMatMulWithConstWeightsOpBuilder( + this, OP_QuantizedMatMul_8x8to32); + else + return CreateMatMulOpBuilder(this, OP_Transpose_8); + } case kTfLiteBuiltinAveragePool2d: return CreatePool2DBuilder(this, OP_QuantizedAvgPool_8); case kTfLiteBuiltinMaxPool2d: @@ -97,6 +104,10 @@ OpBuilder* GraphBuilder::CreateOpBuilderFromTfLiteOp(int op_type) { return CreateMinMaxBuilder(this, OP_QuantizedMinimum_8); case kTfLiteBuiltinMaximum: return CreateMinMaxBuilder(this, OP_QuantizedMaximum_8); + case kTfLiteBuiltinSlice: + return CreateSliceOpBuilder(this, OP_QuantizedSlice_8); + case kTfLiteBuiltinPack: + return CreatePackBuilder(this, OP_QuantizedPack_8); default: context_->ReportError(context_, "Op not supported: %d", op_type); return nullptr; @@ -267,7 +278,7 @@ OpBuilder* GraphBuilder::AddNode(int tflite_node_index) { OpBuilder* GraphBuilder::AddNodeFromTfLiteOp(int op_type, TfLiteNode* node, int tflite_node_index) { - OpBuilder* op = CreateOpBuilderFromTfLiteOp(op_type); + OpBuilder* op = CreateOpBuilderFromTfLiteOp(op_type, node); builders_.emplace_back(op); op->SetNodeId(builders_.size()); op->SetTFLiteNodeId(tflite_node_index); diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h b/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h index 267fc818ca1..743323c8bd3 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h @@ -197,7 +197,7 @@ class GraphBuilder { // Same as above but takes shape of the tensor that will holds the data. OpBuilder* AddConstNodeWithData(const int shape[], char* data, int data_size); - OpBuilder* CreateOpBuilderFromTfLiteOp(int op_type); + OpBuilder* CreateOpBuilderFromTfLiteOp(int op_type, TfLiteNode* node); // Construct Input node with 'input_tensors' as output. TfLiteStatus AddInputTensors(const TfLiteIntArray* input_tensors, diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h b/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h index 515d0edb929..33b56e91f0a 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/op_factory.h @@ -26,7 +26,8 @@ class OpBuilder; OpBuilder* CreateArgMinMaxOpBuilder(GraphBuilder* graph_builder, int op_type); OpBuilder* CreateActivationBuilder(GraphBuilder* graph_builder, int op_type); OpBuilder* CreateArithmeticBuilder(GraphBuilder* graph_builder, int op_type); -OpBuilder* CreateMatMulBuilder(GraphBuilder* graph_builder, int op_type); +OpBuilder* CreateMatMulWithConstWeightsOpBuilder(GraphBuilder* graph_builder, + int op_type); OpBuilder* CreateConcatBuilder(GraphBuilder* graph_builder, int op_type); OpBuilder* CreateConv2DBuilder(GraphBuilder* graph_builder, int op_type); OpBuilder* CreateTransposeConv2DBuilder(GraphBuilder* graph_builder, @@ -55,6 +56,9 @@ OpBuilder* CreateQuantizeBuilder(GraphBuilder* graph_builder, int op_type); OpBuilder* CreateHardSwishBuilder(GraphBuilder* graph_builder, int op_type); OpBuilder* CreateCastBuilder(GraphBuilder* graph_builder, int op_type); OpBuilder* CreateMinMaxBuilder(GraphBuilder* graph_builder, int op_type); +OpBuilder* CreateSliceOpBuilder(GraphBuilder* graph_builder, int op_type); +OpBuilder* CreatePackBuilder(GraphBuilder* graph_builder, int op_type); +OpBuilder* CreateMatMulOpBuilder(GraphBuilder* graph_builder, int op_type); } // namespace hexagon } // namespace delegates diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/pack_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/pack_builder.cc new file mode 100644 index 00000000000..1d99f3bbb8d --- /dev/null +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/pack_builder.cc @@ -0,0 +1,134 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/hexagon/builders/pack_builder.h" + +#include <stdint.h> + +#include <limits> + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace delegates { +namespace hexagon { +namespace { + +int GetAxis(int axis, const TfLiteIntArray* inputs, TfLiteContext* context) { + auto& input_tensor = context->tensors[inputs->data[0]]; + // Handle -ve axis. + if (axis < 0) { + axis += input_tensor.dims->size + 1; + } + // We need to adjust the axis to be as if the inputs are of rank 4, since + // we represent tensors in Hexagon of rank 4. + return (4 - input_tensor.dims->size) + axis - 1; +} + +} // namespace +TfLiteStatus PackOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + TfLiteContext* context) { + static int scalar_shape[] = {1, 1, 1, 1}; + auto* params = reinterpret_cast<TfLitePackParams*>(builtin_data_); + int axis = GetAxis(params->axis, inputs, context); + // Add axis + auto* axis_node = graph_builder_->AddConstNodeWithData( + scalar_shape, reinterpret_cast<char*>(&axis), sizeof(axis)); + AddInput(TensorID(axis_node->GetID(), 0)); + + // Add all input tensors. + minima_.reserve(inputs->size); + maxima_.reserve(inputs->size); + int tensor_id = -1; + float data_min, data_max; + for (int i = 0; i < inputs->size; ++i) { + tensor_id = inputs->data[i]; + auto& input_tensor = context->tensors[tensor_id]; + AddInput(graph_builder_->GetHexagonTensorId(tensor_id)); + TF_LITE_ENSURE_STATUS( + ComputeMinAndMaxQuantValues(input_tensor, &data_min, &data_max)); + minima_.push_back(data_min); + maxima_.push_back(data_max); + } + + // Minima tensors. + for (int i = 0; i < minima_.size(); ++i) { + auto* data_min_const = graph_builder_->AddConstNodeWithData( + scalar_shape, reinterpret_cast<char*>(&minima_[i]), sizeof(minima_[i])); + AddInput(TensorID(data_min_const->GetID(), 0)); + } + + // Maxima tensors. + for (int i = 0; i < maxima_.size(); ++i) { + auto* data_max_const = graph_builder_->AddConstNodeWithData( + scalar_shape, reinterpret_cast<char*>(&maxima_[i]), sizeof(maxima_[i])); + AddInput(TensorID(data_max_const->GetID(), 0)); + } + + // Hexagon outputs for this node. + int output_batch_size, output_height_size, output_width_size, + output_depth_size; + GetDims(&output_batch_size, &output_height_size, &output_width_size, + &output_depth_size, context->tensors[outputs->data[0]].dims); + + TensorID pack_out = AddOutput(sizeof(uint8_t), 4, + {output_batch_size, output_height_size, + output_width_size, output_depth_size}); + + // Output min/max for requantization. + float output_min, output_max; + TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues( + context->tensors[outputs->data[0]], &output_min, &output_max)); + auto* output_min_const = graph_builder_->AddConstNodeWithData( + scalar_shape, reinterpret_cast<char*>(&output_min), sizeof(output_min)); + auto* output_max_const = graph_builder_->AddConstNodeWithData( + scalar_shape, reinterpret_cast<char*>(&output_max), sizeof(output_max)); + + const auto& pack_out_min = AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + const auto& pack_out_max = AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + + // Requantize output to the expected min/max. + auto* requantize_op = graph_builder_->AddNode(GetTFLiteNodeID()); + requantize_op->SetOpType(OP_Requantize_8to8); + requantize_op->AddInput(pack_out); + requantize_op->AddInput(pack_out_min); + requantize_op->AddInput(pack_out_max); + requantize_op->AddInput(TensorID(output_min_const->GetID(), 0)); + requantize_op->AddInput(TensorID(output_max_const->GetID(), 0)); + node_output_ = + requantize_op->AddOutput(sizeof(uint8_t), 4, + {output_batch_size, output_height_size, + output_width_size, output_depth_size}); + requantize_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + requantize_op->AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + return kTfLiteOk; +} + +TfLiteStatus PackOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) { + // Should be only 1 output. + graph_builder_->AddTensorWithID(outputs->data[0], node_output_.first, + node_output_.second); + return kTfLiteOk; +} + +OpBuilder* CreatePackBuilder(GraphBuilder* graph_builder, int op_type) { + return new PackOpBuilder(graph_builder, op_type); +} + +} // namespace hexagon +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/pack_builder.h b/tensorflow/lite/experimental/delegates/hexagon/builders/pack_builder.h new file mode 100644 index 00000000000..a372c519c01 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/pack_builder.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_PACK_BUILDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_PACK_BUILDER_H_ +#include <vector> + +#include "tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h" + +namespace tflite { +namespace delegates { +namespace hexagon { + +class PackOpBuilder : public OpBuilder { + public: + explicit PackOpBuilder(GraphBuilder* graph_builder, int op_type) + : OpBuilder(graph_builder, op_type) {} + TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + private: + TensorID node_output_; + // Min/max for all inputs. + std::vector<float> minima_, maxima_; +}; + +} // namespace hexagon +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_PACK_BUILDER_H_ diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/slice_builder.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/slice_builder.cc new file mode 100644 index 00000000000..cc282343f0c --- /dev/null +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/slice_builder.cc @@ -0,0 +1,106 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/delegates/hexagon/builders/slice_builder.h" + +#include <vector> + +#include "tensorflow/lite/kernels/internal/tensor.h" + +namespace tflite { +namespace delegates { +namespace hexagon { +namespace { +template <typename T> +void GetBeginAndSizeVectors(int dimensions, const TfLiteTensor* begin, + const TfLiteTensor* size, std::vector<int>* begins, + std::vector<int>* sizes) { + for (int i = 0; i < dimensions; ++i) { + begins->push_back(GetTensorData<T>(begin)[i]); + sizes->push_back(GetTensorData<T>(size)[i]); + } +} +} // namespace + +TfLiteStatus SliceOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + TfLiteContext* context) { + static int quant_bound_shape[] = {1, 1, 1, 1}; + + // Input data tensor. + const int tensor_id = inputs->data[0]; + const auto& input_tensor = context->tensors[tensor_id]; + AddInput(graph_builder_->GetHexagonTensorId(tensor_id)); + // Start / Size + const auto& begin_tensor = context->tensors[inputs->data[1]]; + const auto& size_tensor = context->tensors[inputs->data[2]]; + std::vector<int32_t> begins, sizes; + if (begin_tensor.type == kTfLiteInt32) { + GetBeginAndSizeVectors<int32_t>(input_tensor.dims->size, &begin_tensor, + &size_tensor, &begins, &sizes); + } else if (begin_tensor.type == kTfLiteInt64) { + GetBeginAndSizeVectors<int64_t>(input_tensor.dims->size, &begin_tensor, + &size_tensor, &begins, &sizes); + } else { + return kTfLiteError; + } + const int32_t begins_shape[] = {1, 1, 1, static_cast<int32_t>(begins.size())}; + auto begins_node = graph_builder_->AddConstNodeWithData( + begins_shape, reinterpret_cast<char*>(begins.data()), + sizeof(int32_t) * begins.size()); + auto sizes_node = graph_builder_->AddConstNodeWithData( + begins_shape, reinterpret_cast<char*>(sizes.data()), + sizeof(int32_t) * begins.size()); + AddInput(TensorID(begins_node->GetID(), 0)); + AddInput(TensorID(sizes_node->GetID(), 0)); + + // Input min/max + TF_LITE_ENSURE_STATUS( + ComputeMinAndMaxQuantValues(input_tensor, &input_min_, &input_max_)); + auto* input_min_const = graph_builder_->AddConstNodeWithData( + quant_bound_shape, reinterpret_cast<char*>(&input_min_), + sizeof(input_min_)); + auto* input_max_const = graph_builder_->AddConstNodeWithData( + quant_bound_shape, reinterpret_cast<char*>(&input_max_), + sizeof(input_max_)); + AddInput(TensorID(input_min_const->GetID(), 0)); + AddInput(TensorID(input_max_const->GetID(), 0)); + + // Outputs + int output_batch_size, output_height_size, output_width_size, + output_depth_size; + GetDims(&output_batch_size, &output_height_size, &output_width_size, + &output_depth_size, context->tensors[outputs->data[0]].dims); + node_output_ = AddOutput(sizeof(uint8_t), 4, + {output_batch_size, output_height_size, + output_width_size, output_depth_size}); + AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + AddOutput(sizeof(float), 4, {1, 1, 1, 1}); + return kTfLiteOk; +} + +TfLiteStatus SliceOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) { + // Should be only 1 output. + graph_builder_->AddTensorWithID(outputs->data[0], node_output_.first, + node_output_.second); + return kTfLiteOk; +} + +OpBuilder* CreateSliceOpBuilder(GraphBuilder* graph_builder, int op_type) { + return new SliceOpBuilder(graph_builder, op_type); +} +} // namespace hexagon +} // namespace delegates +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/slice_builder.h b/tensorflow/lite/experimental/delegates/hexagon/builders/slice_builder.h new file mode 100644 index 00000000000..0ee06630dba --- /dev/null +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/slice_builder.h @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_SLICE_BUILDER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_SLICE_BUILDER_H_ + +#include "tensorflow/lite/experimental/delegates/hexagon/builders/op_builder.h" + +namespace tflite { +namespace delegates { +namespace hexagon { + +class SliceOpBuilder : public OpBuilder { + public: + explicit SliceOpBuilder(GraphBuilder* graph_builder, int op_type) + : OpBuilder(graph_builder, op_type) {} + + TfLiteStatus PopulateSubGraph(const TfLiteIntArray* inputs, + const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, + TfLiteContext* context) override; + + private: + TensorID node_output_; + float input_min_, input_max_; +}; + +} // namespace hexagon +} // namespace delegates +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_HEXAGON_BUILDERS_SLICE_BUILDER_H_ diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD index a5cdc0411ca..0627d5b202d 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/BUILD @@ -34,12 +34,14 @@ hexagon_op_tests( "mirror_pad_test.cc", "mul_test.cc", "neg_test.cc", + "pack_test.cc", "pad_test.cc", "pool_test.cc", "quantize_test.cc", "reduce_test.cc", "reshape_test.cc", "resize_test.cc", + "slice_test.cc", "softmax_test.cc", "space_to_depth_test.cc", "split_test.cc", diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/matmul_test.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/matmul_test.cc index a16e22888dd..ff2c71946e7 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/matmul_test.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/matmul_test.cc @@ -22,7 +22,8 @@ using testing::ElementsAreArray; class FullyConnectedOpModel : public SingleOpModelWithHexagon { public: FullyConnectedOpModel(int units, int batches, const TensorData& input, - const TensorData& output) + const TensorData& output, bool optional_bias = false, + bool const_weights = true) : batches_(batches), units_(units) { int total_input_size = 1; for (size_t i = 0; i < input.shape.size(); ++i) { @@ -34,9 +35,13 @@ class FullyConnectedOpModel : public SingleOpModelWithHexagon { weights_ = AddInput({input.type, {units_, input_size_}, input.min, input.max}); - auto bias_scale = GetScale(input_) * GetScale(weights_); - TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale}; - bias_ = AddInput(bias); + if (optional_bias) { + bias_ = AddNullInput(); + } else { + auto bias_scale = GetScale(input_) * GetScale(weights_); + TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } output_ = AddOutput(output); @@ -46,15 +51,18 @@ class FullyConnectedOpModel : public SingleOpModelWithHexagon { FullyConnectedOptionsWeightsFormat_DEFAULT, /*keep_num_dims=*/false) .Union()); - - BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); + BuildInterpreter({GetShape(input_), GetShape(weights_)}); // Weights & bias tensors need to be constant. // We don't use AddConstInput to allow setting filter values later. - auto* weights_tensor = interpreter_->tensor(weights_); - weights_tensor->allocation_type = kTfLiteMmapRo; - auto* bias_tensor = interpreter_->tensor(bias_); - bias_tensor->allocation_type = kTfLiteMmapRo; + if (const_weights) { + auto* weights_tensor = interpreter_->tensor(weights_); + weights_tensor->allocation_type = kTfLiteMmapRo; + } + if (!optional_bias) { + auto* bias_tensor = interpreter_->tensor(bias_); + bias_tensor->allocation_type = kTfLiteMmapRo; + } } void SetBias(const std::vector<float>& data) { @@ -146,4 +154,113 @@ TEST(QuantizedFullyConnectedOpTest, TestQuantizedUint8) { ElementsAre(151, 152, 153, 185, 186, 187)); } +TEST(QuantizedFullyConnectedOpTest, TestQuantizedUint8_NoBias) { + FullyConnectedOpModel m( + /*units=*/3, /*batches*/ 2, + /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64}, + /*output=*/{TensorType_UINT8, {}, -127, 128}, /*optional_bias*/ true); + + m.SetWeights<uint8_t>({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }); + + m.SetInput<uint8_t>({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + auto reference_output = m.GetDequantizedOutput<uint8_t>(); + + m.ApplyDelegateAndInvoke(); + + EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(), + ElementsAreArray(ArrayFloatNear(reference_output))); +} + +TEST(QuantizedFullyConnectedOpTest, TestQuantizedInt8_NoBias) { + FullyConnectedOpModel m(/*units=*/3, /*batches*/ 2, + /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64}, + /*output=*/{TensorType_INT8, {}, -127, 128}, + /*optional_bias*/ true); + + m.SetWeights<int8_t>({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }); + + m.SetInput<int8_t>({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + auto reference_output = m.GetDequantizedOutput<int8_t>(); + + m.ApplyDelegateAndInvoke(); + + EXPECT_THAT(m.GetDequantizedOutput<int8_t>(), + ElementsAreArray(ArrayFloatNear(reference_output))); +} + +TEST(QuantizedFullyConnectedOpTest, TestQuantizedInt8_NonConstWeights) { + FullyConnectedOpModel m(/*units=*/3, /*batches*/ 2, + /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64}, + /*output=*/{TensorType_INT8, {}, -127, 128}, + /*optional_bias=*/false, /*const_weights=*/false); + + m.SetWeights<int8_t>({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }); + m.SetBias({1, 2, 3}); + + m.SetInput<int8_t>({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + auto reference_output = m.GetDequantizedOutput<int8_t>(); + + m.ApplyDelegateAndInvoke(); + + EXPECT_THAT(m.GetDequantizedOutput<int8_t>(), + ElementsAreArray(ArrayFloatNear(reference_output))); +} + +TEST(QuantizedFullyConnectedOpTest, TestQuantizedUint8_NonConstWeights) { + FullyConnectedOpModel m( + /*units=*/3, /*batches*/ 2, + /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64}, + /*output=*/{TensorType_UINT8, {}, -127, 128}, /*optional_bias=*/false, + /*const_weights=*/false); + + m.SetWeights<uint8_t>({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + }); + m.SetBias({1, 2, 3}); + + m.SetInput<uint8_t>({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.ApplyDelegateAndInvoke(); + + EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(), + ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // + 58, 59, 60, // + }))); + EXPECT_THAT(m.GetOutput<uint8_t>(), + ElementsAre(151, 152, 153, 185, 186, 187)); +} + } // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/pack_test.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/pack_test.cc new file mode 100644 index 00000000000..6f030575a01 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/pack_test.cc @@ -0,0 +1,125 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <gtest/gtest.h> +#include "tensorflow/lite/experimental/delegates/hexagon/builders/tests/hexagon_delegate_op_model.h" + +namespace tflite { +using testing::ElementsAreArray; + +class PackOpModel : public SingleOpModelWithHexagon { + public: + PackOpModel(const TensorData& input_template, int axis, int values_count) { + std::vector<std::vector<int>> all_input_shapes; + for (int i = 0; i < values_count; ++i) { + all_input_shapes.push_back(input_template.shape); + AddInput(input_template); + } + output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min, + input_template.max}); + SetBuiltinOp(BuiltinOperator_PACK, BuiltinOptions_PackOptions, + CreatePackOptions(builder_, values_count, axis).Union()); + BuildInterpreter(all_input_shapes); + } + + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + template <typename integer_type> + void SetInput(int index, std::initializer_list<float> data) { + QuantizeAndPopulate<integer_type>(index, data); + } + + template <typename integer_type> + std::vector<float> GetDequantizedOutput() { + return Dequantize<integer_type>(ExtractVector<integer_type>(output_), + GetScale(output_), GetZeroPoint(output_)); + } + + private: + int output_; +}; + +template <typename InputType> +struct PackOpTest : public ::testing::Test { + using TypeToTest = InputType; + TensorType TENSOR_TYPE = + (std::is_same<InputType, int16_t>::value + ? TensorType_INT16 + : (std::is_same<InputType, uint8_t>::value ? TensorType_UINT8 + : TensorType_INT8)); +}; + +using TestTypes = testing::Types<int8_t, uint8_t>; +TYPED_TEST_CASE(PackOpTest, TestTypes); + +TYPED_TEST(PackOpTest, ThreeInputs) { + PackOpModel model({TestFixture::TENSOR_TYPE, {2}, -10, 10}, 0, 3); + model.SetInput<typename TestFixture::TypeToTest>(0, {1, 4}); + model.SetInput<typename TestFixture::TypeToTest>(1, {2, 5}); + model.SetInput<typename TestFixture::TypeToTest>(2, {3, 6}); + model.Invoke(); + auto ref_output_shape = model.GetOutputShape(); + auto ref_output = + model.GetDequantizedOutput<typename TestFixture::TypeToTest>(); + model.ApplyDelegateAndInvoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(ref_output_shape)); + EXPECT_THAT(model.GetDequantizedOutput<typename TestFixture::TypeToTest>(), + ElementsAreArray(ArrayFloatNear(ref_output))); +} + +TYPED_TEST(PackOpTest, ThreeInputsDifferentAxis) { + PackOpModel model({TestFixture::TENSOR_TYPE, {2}, -10, 10}, 1, 3); + model.SetInput<typename TestFixture::TypeToTest>(0, {1, 4}); + model.SetInput<typename TestFixture::TypeToTest>(1, {2, 5}); + model.SetInput<typename TestFixture::TypeToTest>(2, {3, 6}); + model.Invoke(); + auto ref_output_shape = model.GetOutputShape(); + auto ref_output = + model.GetDequantizedOutput<typename TestFixture::TypeToTest>(); + model.ApplyDelegateAndInvoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(ref_output_shape)); + EXPECT_THAT(model.GetDequantizedOutput<typename TestFixture::TypeToTest>(), + ElementsAreArray(ArrayFloatNear(ref_output))); +} + +TYPED_TEST(PackOpTest, ThreeInputsNegativeAxis) { + PackOpModel model({TestFixture::TENSOR_TYPE, {2}, -10, 10}, -1, 3); + model.SetInput<typename TestFixture::TypeToTest>(0, {1, 4}); + model.SetInput<typename TestFixture::TypeToTest>(1, {2, 5}); + model.SetInput<typename TestFixture::TypeToTest>(2, {3, 6}); + model.Invoke(); + auto ref_output_shape = model.GetOutputShape(); + auto ref_output = + model.GetDequantizedOutput<typename TestFixture::TypeToTest>(); + model.ApplyDelegateAndInvoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(ref_output_shape)); + EXPECT_THAT(model.GetDequantizedOutput<typename TestFixture::TypeToTest>(), + ElementsAreArray(ArrayFloatNear(ref_output))); +} + +TYPED_TEST(PackOpTest, MultilDimensions) { + PackOpModel model({TestFixture::TENSOR_TYPE, {2, 3}, -10, 20}, 1, 2); + model.SetInput<typename TestFixture::TypeToTest>(0, {1, 2, 3, 4, 5, 6}); + model.SetInput<typename TestFixture::TypeToTest>(1, {7, 8, 9, 10, 11, 12}); + model.Invoke(); + auto ref_output_shape = model.GetOutputShape(); + auto ref_output = + model.GetDequantizedOutput<typename TestFixture::TypeToTest>(); + model.ApplyDelegateAndInvoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(ref_output_shape)); + EXPECT_THAT(model.GetDequantizedOutput<typename TestFixture::TypeToTest>(), + ElementsAreArray(ArrayFloatNear(ref_output))); +} + +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/hexagon/builders/tests/slice_test.cc b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/slice_test.cc new file mode 100644 index 00000000000..d3bcfb6a6c2 --- /dev/null +++ b/tensorflow/lite/experimental/delegates/hexagon/builders/tests/slice_test.cc @@ -0,0 +1,163 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <gtest/gtest.h> +#include "tensorflow/lite/experimental/delegates/hexagon/builders/tests/hexagon_delegate_op_model.h" + +namespace tflite { +using testing::ElementsAreArray; + +template <typename index_type> +class SliceOpModel : public SingleOpModelWithHexagon { + public: + SliceOpModel(const TensorData& input, const TensorData& output, + const TensorData& begin, const TensorData& size, + std::initializer_list<index_type> begin_data, + std::initializer_list<index_type> size_data) { + input_ = AddInput(input); + begin_ = AddConstInput(begin, begin_data); + size_ = AddConstInput(size, size_data); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_SLICE, BuiltinOptions_SliceOptions, + CreateSliceOptions(builder_).Union()); + BuildInterpreter({GetShape(input_), GetShape(begin_), GetShape(size_)}); + } + + template <typename T> + void SetInput(std::initializer_list<float> data) { + QuantizeAndPopulate<T>(input_, data); + } + + template <typename T> + std::vector<float> GetDequantizedOutput() { + return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_), + GetZeroPoint(output_)); + } + + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int begin_; + int size_; + int output_; +}; + +TEST(SliceOpTest, Input_1D_Uint8) { + SliceOpModel<int32_t> m(/*input=*/{TensorType_UINT8, {4}, -10, 10}, + /*output=*/{TensorType_UINT8, {2}, -10, 10}, + {TensorType_INT32, {1}}, {TensorType_INT32, {1}}, {1}, + {2}); + m.SetInput<uint8_t>({1, 2, 3, 4}); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(), + ElementsAreArray(ArrayFloatNear({2, 3}, 0.1))); +} + +TEST(SliceOpTest, Input_2D_Uint8) { + SliceOpModel<int32_t> m( + /*input=*/{TensorType_UINT8, {2, 3}, -10, 10}, + /*output=*/{TensorType_UINT8, {1, 2}, -10, 10}, {TensorType_INT32, {2}}, + {TensorType_INT32, {2}}, {1, 0}, {1, 2}); + m.SetInput<uint8_t>({1, 2, 3, 4, 5, 6}); + m.Invoke(); + auto reference_output = m.GetDequantizedOutput<uint8_t>(); + auto reference_output_shape = m.GetOutputShape(); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(reference_output_shape)); + EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(), + ElementsAreArray(ArrayFloatNear(reference_output, 0.1))); +} + +TEST(SliceOpTest, SizeInt64_Uint8) { + SliceOpModel<int64_t> m(/*input=*/{TensorType_UINT8, {4, 1, 1, 1}, -10, 10}, + /*output=*/{TensorType_UINT8, {3, 1, 1, 1}, -10, 10}, + {TensorType_INT64, {4}}, {TensorType_INT64, {4}}, + {1, 0, 0, 0}, {3, 1, 1, 1}); + m.SetInput<uint8_t>({1, 2, 3, 4}); + m.Invoke(); + auto reference_output = m.GetDequantizedOutput<uint8_t>(); + auto reference_output_shape = m.GetOutputShape(); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(reference_output_shape)); + EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(), + ElementsAreArray(ArrayFloatNear(reference_output, 0.1))); +} + +TEST(SliceOpTest, SizeMinus1) { + SliceOpModel<int64_t> m( + /*input=*/{TensorType_UINT8, {3, 2, 3, 1}, -10, 10}, + /*output=*/{TensorType_UINT8, {2, 1, 3, 1}, -10, 10}, + {TensorType_INT64, {4}}, {TensorType_INT64, {4}}, {1, 0, 0, 0}, + {2, 1, -1, 1}); + m.SetInput<uint8_t>({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + m.Invoke(); + auto reference_output = m.GetDequantizedOutput<uint8_t>(); + auto reference_output_shape = m.GetOutputShape(); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(reference_output_shape)); + EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(), + ElementsAreArray(ArrayFloatNear(reference_output, 0.1))); +} + +TEST(SliceOpTest, BeginNonZeroSizeMinus1Axis1) { + SliceOpModel<int64_t> m( + /*input=*/{TensorType_UINT8, {3, 3, 2, 1}, -10, 10}, + /*output=*/{TensorType_UINT8, {2, 2, 1, 1}, -10, 10}, + {TensorType_INT64, {4}}, {TensorType_INT64, {4}}, {1, 1, 0, 0}, + {2, -1, 1, 1}); + m.SetInput<uint8_t>({1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}); + m.Invoke(); + auto reference_output = m.GetDequantizedOutput<uint8_t>(); + auto reference_output_shape = m.GetOutputShape(); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(reference_output_shape)); + EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(), + ElementsAreArray(ArrayFloatNear(reference_output, 0.1))); +} + +TEST(SliceOpTest, BeginNonZeroSizeMinus1Axis2) { + SliceOpModel<int64_t> m( + /*input=*/{TensorType_UINT8, {3, 2, 3, 1}, -10, 10}, + /*output=*/{TensorType_UINT8, {2, 1, 2, 1}, -10, 10}, + {TensorType_INT64, {4}}, {TensorType_INT64, {4}}, {1, 0, 1, 0}, + {2, 1, -1, 1}); + m.SetInput<uint8_t>({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + m.Invoke(); + auto reference_output = m.GetDequantizedOutput<uint8_t>(); + auto reference_output_shape = m.GetOutputShape(); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(reference_output_shape)); + EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(), + ElementsAreArray(ArrayFloatNear(reference_output, 0.1))); +} + +TEST(SliceOpTest, BeginNonZeroSizeMinus1Axis2_Int8) { + SliceOpModel<int64_t> m( + /*input=*/{TensorType_INT8, {3, 2, 3, 1}, -10, 10}, + /*output=*/{TensorType_INT8, {2, 1, 2, 1}, -10, 10}, + {TensorType_INT64, {4}}, {TensorType_INT64, {4}}, {1, 0, 1, 0}, + {2, 1, -1, 1}); + m.SetInput<int8_t>({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + m.Invoke(); + auto reference_output = m.GetDequantizedOutput<int8_t>(); + auto reference_output_shape = m.GetOutputShape(); + m.ApplyDelegateAndInvoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray(reference_output_shape)); + EXPECT_THAT(m.GetDequantizedOutput<int8_t>(), + ElementsAreArray(ArrayFloatNear(reference_output, 0.1))); +} + +} // namespace tflite diff --git a/tensorflow/lite/experimental/delegates/hexagon/utils.cc b/tensorflow/lite/experimental/delegates/hexagon/utils.cc index 8aff13549b8..80f82749e80 100644 --- a/tensorflow/lite/experimental/delegates/hexagon/utils.cc +++ b/tensorflow/lite/experimental/delegates/hexagon/utils.cc @@ -42,6 +42,8 @@ bool InputsWithCorrectTypes( const std::vector<std::vector<TfLiteType>>& per_input_possible_types) { if (node->inputs->size != per_input_possible_types.size()) return false; for (int i = 0; i < per_input_possible_types.size(); ++i) { + // Skip optional tensor. + if (node->inputs->data[i] == -1) continue; bool type_found = false; for (auto possible_type : per_input_possible_types[i]) { if (TensorTypeMatch(node->inputs->data[i], context, possible_type)) { @@ -85,11 +87,13 @@ bool CheckOpVersion(const TfLiteRegistration* registration) { case kTfLiteBuiltinMinimum: case kTfLiteBuiltinMirrorPad: case kTfLiteBuiltinMul: + case kTfLiteBuiltinPack: case kTfLiteBuiltinPad: case kTfLiteBuiltinQuantize: case kTfLiteBuiltinRelu6: case kTfLiteBuiltinResizeBilinear: case kTfLiteBuiltinResizeNearestNeighbor: + case kTfLiteBuiltinSlice: case kTfLiteBuiltinSoftmax: case kTfLiteBuiltinSpaceToDepth: case kTfLiteBuiltinSplit: @@ -116,6 +120,9 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration, int tensor_id; for (int i = 0; i < node->inputs->size; ++i) { tensor_id = node->inputs->data[i]; + // Skip optional tensors. Builders should handle optional tensors + // not available. + if (tensor_id == -1) continue; const auto& tensor = context->tensors[tensor_id]; if (tensor.dims->size > 4) return false; } @@ -191,20 +198,21 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration, if (!InputsWithCorrectTypes(node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}, - {kTfLiteInt32}})) + {kTfLiteInt32, kTfLiteNoType}})) { return false; + } - const auto& weights_tensor = context->tensors[node->inputs->data[1]]; - const auto& bias_tensor = context->tensors[node->inputs->data[2]]; - const bool weights_and_bias_const = - weights_tensor.allocation_type == kTfLiteMmapRo && - bias_tensor.allocation_type == kTfLiteMmapRo; + bool bias_const_or_no_bias = true; + if (node->inputs->data[2] != -1) { + const auto& bias_tensor = context->tensors[node->inputs->data[2]]; + bias_const_or_no_bias = bias_tensor.allocation_type == kTfLiteMmapRo; + } const TfLiteFullyConnectedParams* matmul_params = reinterpret_cast<const TfLiteFullyConnectedParams*>( node->builtin_data); - return (weights_and_bias_const && - IsActivationReluOrNone(matmul_params->activation) && + return (bias_const_or_no_bias && + matmul_params->activation == kTfLiteActNone && matmul_params->keep_num_dims == false && matmul_params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault); @@ -335,7 +343,8 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration, return false; const auto& input_tensor = context->tensors[node->inputs->data[1]]; const bool is_four_dim_or_less = input_tensor.dims->size < 5; - // We need splitting axis to be constant, so Hexagon knows output shapes. + // We need splitting axis to be constant, so Hexagon knows output + // shapes. return is_four_dim_or_less && IsConstantTensor(GetInput(context, node, 0)); } @@ -378,6 +387,25 @@ bool IsNodeSupportedByHexagon(const TfLiteRegistration* registration, node, context, {{kTfLiteUInt8, kTfLiteInt8}, {kTfLiteUInt8, kTfLiteInt8}}); } + case kTfLiteBuiltinSlice: { + const auto& begins_tensor = context->tensors[node->inputs->data[1]]; + const auto& sizes_tensor = context->tensors[node->inputs->data[2]]; + if (!IsConstantTensor(&begins_tensor) || !IsConstantTensor(&sizes_tensor)) + return false; + return InputsWithCorrectTypes(node, context, + {{kTfLiteUInt8, kTfLiteInt8}, + {kTfLiteInt32, kTfLiteInt64}, + {kTfLiteInt32, kTfLiteInt64}}); + } + case kTfLiteBuiltinPack: { + // All tensors must be 8-bit. + for (int i = 0; i < node->inputs->size; ++i) { + if (!TensorTypeMatch(node->inputs->data[i], context, kTfLiteUInt8) && + !TensorTypeMatch(node->inputs->data[i], context, kTfLiteInt8)) + return false; + } + return true; + } default: return false; } diff --git a/tensorflow/lite/experimental/ios/BUILD.apple b/tensorflow/lite/experimental/ios/BUILD.apple index 5c954bc3de8..ddbfc0dec5b 100644 --- a/tensorflow/lite/experimental/ios/BUILD.apple +++ b/tensorflow/lite/experimental/ios/BUILD.apple @@ -11,17 +11,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -genrule( - name = "strip_coreml_include_hdr", - srcs = ["//tensorflow/lite/experimental/delegates/coreml:coreml_delegate.h"], - outs = ["coreml_delegate.h"], - cmd = """ - sed 's/#include \".*common.h"/#include \"common.h\"/' \ - "$(location //tensorflow/lite/experimental/delegates/coreml:coreml_delegate.h)" \ - > "$@" - """, -) - TFL_FRAMEWORK_HDRS = [ "//tensorflow/lite/delegates/gpu:metal_delegate.h", "//tensorflow/lite/c:c_api.h", @@ -57,6 +46,17 @@ ios_static_framework( ], ) +genrule( + name = "strip_coreml_include_hdr", + srcs = ["//tensorflow/lite/experimental/delegates/coreml:coreml_delegate.h"], + outs = ["coreml_delegate.h"], + cmd = """ + sed 's|#include ".*common.h"|#include "TensorFlowLiteC/common.h"|'\ + "$(location //tensorflow/lite/experimental/delegates/coreml:coreml_delegate.h)"\ + > "$@" + """, +) + # This target builds the Core ML delegate as a separate static framework, which # does not include the TensorFlow Lite runtime. As this target does not contain # TensorFlow Lite runtime, it is intended to be linked along with the @@ -78,15 +78,32 @@ ios_static_framework( ], ) +# This target builds the Metal delegate as a separate static framework, which +# does not include the TensorFlow Lite runtime. As this target does not contain +# TensorFlow Lite runtime, it is intended to be linked along with the +# TensorFlowLiteC framework above in a composable way. +# +# bazel build -c opt --config=ios_fat //tensorflow/lite/experimental/ios:TensorFlowLiteCMetal_framework +ios_static_framework( + name = "TensorFlowLiteCMetal_framework", + hdrs = [ + "//tensorflow/lite/delegates/gpu:metal_delegate.h", + ], + avoid_deps = [ + ":tensorflow_lite_c", + ], + bundle_name = "TensorFlowLiteCMetal", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + deps = [ + "//tensorflow/lite/delegates/gpu:metal_delegate", + ], +) + cc_library( name = "tensorflow_lite_c", hdrs = [ "//tensorflow/lite/c:c_api.h", "//tensorflow/lite/c:common.h", - "//tensorflow/lite/delegates/gpu:metal_delegate.h", - ], - linkopts = [ - "-Wl,-weak_framework,Metal", ], tags = [ "nobuilder", @@ -94,7 +111,6 @@ cc_library( ], deps = [ "//tensorflow/lite/c:c_api", - "//tensorflow/lite/delegates/gpu:metal_delegate", ], ) diff --git a/tensorflow/lite/experimental/ios/TensorFlowLiteC.podspec.template b/tensorflow/lite/experimental/ios/TensorFlowLiteC.podspec.template index d8a5ef8f2e1..3f0517e1fe6 100644 --- a/tensorflow/lite/experimental/ios/TensorFlowLiteC.podspec.template +++ b/tensorflow/lite/experimental/ios/TensorFlowLiteC.podspec.template @@ -31,4 +31,10 @@ Pod::Spec.new do |s| coreml.dependency 'TensorFlowLiteC/Core' coreml.vendored_frameworks = 'Frameworks/TensorFlowLiteCCoreML.framework' end + + s.subspec 'Metal' do |metal| + metal.weak_framework = 'Metal' + metal.dependency 'TensorFlowLiteC/Core' + metal.vendored_frameworks = 'Frameworks/TensorFlowLiteCMetal.framework' + end end diff --git a/tensorflow/lite/experimental/support/java/BUILD b/tensorflow/lite/experimental/support/java/BUILD index 43e984a0cb8..85f5da17193 100644 --- a/tensorflow/lite/experimental/support/java/BUILD +++ b/tensorflow/lite/experimental/support/java/BUILD @@ -9,7 +9,24 @@ package( licenses = ["notice"], # Apache 2.0 ) +# TODO(b/156482505): The NOGPU target is a temporary target. Internally, people +# may already depend on "tensorflow-lite-support" so we shouldn't remove GPU +# from its dependency. We will have CLs to help users migrate. After migration +# is done, the "NOGPU" target will be removed. +android_library( + name = "tensorflow-lite-support-nogpu", + srcs = glob(["src/java/org/tensorflow/lite/support/**/*.java"]), + javacopts = JAVACOPTS, + manifest = "AndroidManifest.xml", + deps = [ + "//tensorflow/lite/java:tensorflowlite", + "@org_checkerframework_qual", + ], +) + # TODO(138904786): Split Java part and Android part to make the support library usable by pure Java. +# For new users: Please use "tensorflow-lite-support-nogpu" if possible, and +# additionally depends on "tensorflowlite_gpu" if needed. android_library( name = "tensorflow-lite-support", srcs = glob(["src/java/org/tensorflow/lite/support/**/*.java"]), @@ -17,7 +34,7 @@ android_library( manifest = "AndroidManifest.xml", deps = [ "//tensorflow/lite/java:tensorflowlite", - "//tensorflow/lite/java:tensorflowlite_gpu", + "//tensorflow/lite/java:tensorflowlite_gpu", # unuseddeps: keep "@org_checkerframework_qual", ], ) diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java new file mode 100644 index 00000000000..9cfcf923ded --- /dev/null +++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/model/GpuDelegateProxy.java @@ -0,0 +1,69 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite.support.model; + +import android.util.Log; +import java.io.Closeable; +import java.io.IOException; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.tensorflow.lite.Delegate; + +/** + * Helper class to create and call necessary methods of {@code GpuDelegate} which is not a strict + * dependency. + */ +class GpuDelegateProxy implements Delegate, Closeable { + + private static final String TAG = "GpuDelegateProxy"; + + private final Delegate proxiedDelegate; + private final Closeable proxiedCloseable; + + @Nullable + public static GpuDelegateProxy maybeNewInstance() { + try { + Class<?> clazz = Class.forName("org.tensorflow.lite.gpu.GpuDelegate"); + Object instance = clazz.getDeclaredConstructor().newInstance(); + return new GpuDelegateProxy(instance); + } catch (ReflectiveOperationException e) { + Log.e(TAG, "Failed to create the GpuDelegate dynamically.", e); + return null; + } + } + + /** Calls {@code close()} method of the delegate. */ + @Override + public void close() { + try { + proxiedCloseable.close(); + } catch (IOException e) { + // Should not trigger, because GpuDelegate#close never throws. The catch is required because + // of Closeable#close. + Log.e(TAG, "Failed to close the GpuDelegate.", e); + } + } + + /** Calls {@code getNativeHandle()} method of the delegate. */ + @Override + public long getNativeHandle() { + return proxiedDelegate.getNativeHandle(); + } + + private GpuDelegateProxy(Object instance) { + this.proxiedCloseable = (Closeable) instance; + this.proxiedDelegate = (Delegate) instance; + } +} diff --git a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/model/Model.java b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/model/Model.java index c7f9e83f692..8062d68d7b9 100644 --- a/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/model/Model.java +++ b/tensorflow/lite/experimental/support/java/src/java/org/tensorflow/lite/support/model/Model.java @@ -22,7 +22,7 @@ import java.util.Map; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import org.tensorflow.lite.Interpreter; -import org.tensorflow.lite.gpu.GpuDelegate; +import org.tensorflow.lite.Tensor; import org.tensorflow.lite.support.common.FileUtil; import org.tensorflow.lite.support.common.SupportPreconditions; @@ -91,7 +91,7 @@ public class Model { /** The memory-mapped model data. */ private final MappedByteBuffer byteModel; - private final GpuDelegate gpuDelegate; + private final GpuDelegateProxy gpuDelegateProxy; /** * Builder for {@link Model}. @@ -181,24 +181,30 @@ public class Model { * @param modelPath The original path of the model. It can be fetched later by {@link * Model#getPath()}. * @param options The options for running the model. + * @throws IllegalArgumentException if {@code options.device} is {@link Device#GPU} but + * "tensorflow-lite-gpu" is not linked to the project. */ public static Model createModel( @NonNull MappedByteBuffer byteModel, @NonNull String modelPath, @NonNull Options options) { Interpreter.Options interpreterOptions = new Interpreter.Options(); - GpuDelegate gpuDelegate = options.device.equals(Device.GPU) ? new GpuDelegate() : null; + GpuDelegateProxy gpuDelegateProxy = null; switch (options.device) { case NNAPI: interpreterOptions.setUseNNAPI(true); break; case GPU: - interpreterOptions.addDelegate(gpuDelegate); + gpuDelegateProxy = GpuDelegateProxy.maybeNewInstance(); + SupportPreconditions.checkArgument( + gpuDelegateProxy != null, + "Cannot inference with GPU. Did you add \"tensorflow-lite-gpu\" as dependency?"); + interpreterOptions.addDelegate(gpuDelegateProxy); break; case CPU: break; } interpreterOptions.setNumThreads(options.numThreads); Interpreter interpreter = new Interpreter(byteModel, interpreterOptions); - return new Model(modelPath, byteModel, interpreter, gpuDelegate); + return new Model(modelPath, byteModel, interpreter, gpuDelegateProxy); } /** Returns the memory-mapped model data. */ @@ -213,6 +219,24 @@ public class Model { return modelPath; } + /** + * Gets the Tensor associated with the provdied input index. + * + * @throws IllegalStateException if the interpreter is closed. + */ + public Tensor getInputTensor(int inputIndex) { + return interpreter.getInputTensor(inputIndex); + } + + /** + * Gets the Tensor associated with the provdied output index. + * + * @throws IllegalStateException if the interpreter is closed. + */ + public Tensor getOutputTensor(int outputIndex) { + return interpreter.getOutputTensor(outputIndex); + } + /** * Returns the output shape. Useful if output shape is only determined when graph is created. * @@ -243,8 +267,8 @@ public class Model { if (interpreter != null) { interpreter.close(); } - if (gpuDelegate != null) { - gpuDelegate.close(); + if (gpuDelegateProxy != null) { + gpuDelegateProxy.close(); } } @@ -252,10 +276,10 @@ public class Model { @NonNull String modelPath, @NonNull MappedByteBuffer byteModel, @NonNull Interpreter interpreter, - @Nullable GpuDelegate gpuDelegate) { + @Nullable GpuDelegateProxy gpuDelegateProxy) { this.modelPath = modelPath; this.byteModel = byteModel; this.interpreter = interpreter; - this.gpuDelegate = gpuDelegate; + this.gpuDelegateProxy = gpuDelegateProxy; } } diff --git a/tensorflow/lite/experimental/support/metadata/BUILD b/tensorflow/lite/experimental/support/metadata/BUILD index d6417a1bfcf..4621c8c55d2 100644 --- a/tensorflow/lite/experimental/support/metadata/BUILD +++ b/tensorflow/lite/experimental/support/metadata/BUILD @@ -62,6 +62,7 @@ py_library( deps = [ ":metadata_schema_py", ":schema_py", + "//tensorflow/lite/experimental/support/metadata/cc/python:_pywrap_metadata_version", "//tensorflow/lite/experimental/support/metadata/flatbuffers_lib:_pywrap_flatbuffers", "//tensorflow/python:platform", "@flatbuffers//:runtime_py", diff --git a/tensorflow/lite/experimental/support/metadata/cc/BUILD b/tensorflow/lite/experimental/support/metadata/cc/BUILD new file mode 100644 index 00000000000..2b288abe368 --- /dev/null +++ b/tensorflow/lite/experimental/support/metadata/cc/BUILD @@ -0,0 +1,16 @@ +package( + default_visibility = ["//tensorflow/lite/experimental/support:users"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "metadata_version", + srcs = ["metadata_version.cc"], + hdrs = ["metadata_version.h"], + deps = [ + "//tensorflow/lite/c:common", + "//tensorflow/lite/experimental/support/metadata:metadata_schema_cc", + "//tensorflow/lite/tools:logging", + "@flatbuffers", + ], +) diff --git a/tensorflow/lite/experimental/support/metadata/cc/metadata_version.cc b/tensorflow/lite/experimental/support/metadata/cc/metadata_version.cc new file mode 100644 index 00000000000..4f43c1431a7 --- /dev/null +++ b/tensorflow/lite/experimental/support/metadata/cc/metadata_version.cc @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h" + +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" +#include "tensorflow/lite/tools/logging.h" + +namespace tflite { +namespace metadata { + +TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data, + size_t buffer_size, + std::string* min_version) { + flatbuffers::Verifier verifier = + flatbuffers::Verifier(buffer_data, buffer_size); + if (!tflite::VerifyModelMetadataBuffer(verifier)) { + TFLITE_LOG(ERROR) << "The model metadata is not a valid FlatBuffer buffer."; + return kTfLiteError; + } + + // Returns the version as the initial default one, "1.0.0", because it is the + // first version ever for metadata_schema.fbs. + // + // Later, when new fields are added to the schema, we'll update the logic of + // getting the minimum metadata parser version. To be more specific, we'll + // have a table that records the new fields and the versions of the schema + // they are added to. And the minimum metadata parser version will be the + // largest version number of all fields that has been added to a metadata + // flatbuffer. + // TODO(b/156539454): replace the hardcoded version with template + genrule. + static constexpr char kDefaultVersion[] = "1.0.0"; + *min_version = kDefaultVersion; + return kTfLiteOk; +} + +} // namespace metadata +} // namespace tflite diff --git a/tensorflow/lite/experimental/support/metadata/cc/metadata_version.h b/tensorflow/lite/experimental/support/metadata/cc/metadata_version.h new file mode 100644 index 00000000000..71e90788af4 --- /dev/null +++ b/tensorflow/lite/experimental/support/metadata/cc/metadata_version.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_ + +#include <string> + +#include "tensorflow/lite/c/common.h" + +namespace tflite { +namespace metadata { + +// Gets the minimum metadata parser version that can fully understand all fields +// in a given metadata flatbuffer. TFLite Metadata follows Semantic Versioning +// 2.0. Each release version has the form MAJOR.MINOR.PATCH. +TfLiteStatus GetMinimumMetadataParserVersion(const uint8_t* buffer_data, + size_t buffer_size, + std::string* min_version); + +} // namespace metadata +} // namespace tflite + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_SUPPORT_METADATA_CC_METADATA_VERSION_H_ diff --git a/tensorflow/lite/experimental/support/metadata/cc/python/BUILD b/tensorflow/lite/experimental/support/metadata/cc/python/BUILD new file mode 100644 index 00000000000..4128f0ac9d1 --- /dev/null +++ b/tensorflow/lite/experimental/support/metadata/cc/python/BUILD @@ -0,0 +1,22 @@ +load("//tensorflow:tensorflow.bzl", "pybind_extension") + +package( + default_visibility = [ + "//tensorflow/lite/experimental/support/metadata:__pkg__", + ], + licenses = ["notice"], # Apache 2.0 +) + +pybind_extension( + name = "_pywrap_metadata_version", + srcs = [ + "metadata_version.cc", + ], + features = ["-use_header_modules"], + module_name = "_pywrap_metadata_version", + deps = [ + "//tensorflow/lite/c:common", + "//tensorflow/lite/experimental/support/metadata/cc:metadata_version", + "@pybind11", + ], +) diff --git a/tensorflow/lite/experimental/support/metadata/cc/python/metadata_version.cc b/tensorflow/lite/experimental/support/metadata/cc/python/metadata_version.cc new file mode 100644 index 00000000000..7d1f9d1e122 --- /dev/null +++ b/tensorflow/lite/experimental/support/metadata/cc/python/metadata_version.cc @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h" + +#include "pybind11/pybind11.h" +#include "tensorflow/lite/c/common.h" + +namespace tflite { +namespace metadata { + +PYBIND11_MODULE(_pywrap_metadata_version, m) { + m.doc() = R"pbdoc( + _pywrap_metadata_version + A module that returns the minimum metadata parser version of a given + metadata flatbuffer. + )pbdoc"; + + // Using pybind11 type conversions to convert between Python and native + // C++ types. There are other options to provide access to native Python types + // in C++ and vice versa. See the pybind 11 instrcution [1] for more details. + // Type converstions is recommended by pybind11, though the main downside + // is that a copy of the data must be made on every Python to C++ transition: + // this is needed since the C++ and Python versions of the same type generally + // won’t have the same memory layout. + // + // [1]: https://pybind11.readthedocs.io/en/stable/advanced/cast/index.html + m.def("GetMinimumMetadataParserVersion", + [](const std::string& buffer_data) -> std::string { + std::string min_version; + if (GetMinimumMetadataParserVersion( + reinterpret_cast<const uint8_t*>(buffer_data.c_str()), + buffer_data.length(), &min_version) != kTfLiteOk) { + pybind11::value_error( + "Error occurred when getting the minimum metadata parser " + "version of the metadata flatbuffer."); + } + return min_version; + }); +} + +} // namespace metadata +} // namespace tflite diff --git a/tensorflow/lite/experimental/support/metadata/cc/test/BUILD b/tensorflow/lite/experimental/support/metadata/cc/test/BUILD new file mode 100644 index 00000000000..fd829124c73 --- /dev/null +++ b/tensorflow/lite/experimental/support/metadata/cc/test/BUILD @@ -0,0 +1,15 @@ +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_test( + name = "metadata_version_test", + srcs = ["metadata_version_test.cc"], + deps = [ + "//tensorflow/lite/experimental/support/metadata:metadata_schema_cc", + "//tensorflow/lite/experimental/support/metadata/cc:metadata_version", + "@com_google_googletest//:gtest_main", + "@flatbuffers", + ], +) diff --git a/tensorflow/lite/experimental/support/metadata/cc/test/metadata_version_test.cc b/tensorflow/lite/experimental/support/metadata/cc/test/metadata_version_test.cc new file mode 100644 index 00000000000..00d9c0902c6 --- /dev/null +++ b/tensorflow/lite/experimental/support/metadata/cc/test/metadata_version_test.cc @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/experimental/support/metadata/cc/metadata_version.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/experimental/support/metadata/metadata_schema_generated.h" + +namespace tflite { +namespace metadata { +namespace { + +using ::testing::MatchesRegex; + +TEST(MetadataVersionTest, + GetMinimumMetadataParserVersionSucceedsWithValidMetadata) { + // Creates a dummy metadata flatbuffer for test. + flatbuffers::FlatBufferBuilder builder(1024); + auto name = builder.CreateString("Foo"); + ModelMetadataBuilder metadata_builder(builder); + metadata_builder.add_name(name); + auto metadata = metadata_builder.Finish(); + FinishModelMetadataBuffer(builder, metadata); + + // Gets the mimimum metadata parser version. + std::string min_version; + EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), + builder.GetSize(), &min_version), + kTfLiteOk); + // Validates that the version is well-formed (x.y.z). + EXPECT_THAT(min_version, MatchesRegex("[0-9]*\\.[0-9]*\\.[0-9]")); +} + +TEST(MetadataVersionTest, + GetMinimumMetadataParserVersionSucceedsWithInvalidIdentifier) { + // Creates a dummy metadata flatbuffer without identifier. + flatbuffers::FlatBufferBuilder builder(1024); + ModelMetadataBuilder metadata_builder(builder); + auto metadata = metadata_builder.Finish(); + builder.Finish(metadata); + + // Gets the mimimum metadata parser version and triggers error. + std::string min_version; + EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), + builder.GetSize(), &min_version), + kTfLiteError); + EXPECT_TRUE(min_version.empty()); +} + +} // namespace +} // namespace metadata +} // namespace tflite diff --git a/tensorflow/lite/experimental/support/metadata/java/BUILD b/tensorflow/lite/experimental/support/metadata/java/BUILD index 82b6e9866a9..c208752ae24 100644 --- a/tensorflow/lite/experimental/support/metadata/java/BUILD +++ b/tensorflow/lite/experimental/support/metadata/java/BUILD @@ -16,7 +16,6 @@ android_library( deps = [ "//tensorflow/lite/experimental/support/metadata:metadata_schema_fbs_android", "//tensorflow/lite/experimental/support/metadata:schema_fbs_android", - "//tensorflow/lite/java:tensorflowlite_java", "@org_checkerframework_qual", ], ) @@ -32,7 +31,6 @@ java_library( deps = [ "//tensorflow/lite/experimental/support/metadata:metadata_schema_java", "//tensorflow/lite/experimental/support/metadata:schema_fbs_java", - "//tensorflow/lite/java:tensorflowlite_javalib", "@org_checkerframework_qual", ], ) diff --git a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java index 3ded50e5d95..be4d8caf577 100644 --- a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java +++ b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/MetadataExtractor.java @@ -22,8 +22,6 @@ import java.io.InputStream; import java.nio.ByteBuffer; import java.util.zip.ZipException; import org.checkerframework.checker.nullness.qual.Nullable; -import org.tensorflow.lite.DataType; -import org.tensorflow.lite.Tensor.QuantizationParams; import org.tensorflow.lite.schema.Tensor; import org.tensorflow.lite.support.metadata.schema.ModelMetadata; import org.tensorflow.lite.support.metadata.schema.TensorMetadata; @@ -111,6 +109,48 @@ public class MetadataExtractor { zipFile = createZipFile(buffer); } + /** + * Quantization parameters that corresponds to the table, {@code QuantizationParameters}, in the + * <a + * href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs">TFLite + * Model schema file.</a> + * + * <p>Since per-channel quantization does not apply to input and output tensors, {@code scale} and + * {@code zero_point} are both single values instead of arrays. + * + * <p>For tensor that are not quantized, the values of scale and zero_point are both 0. + * + * <p>Given a quantized value q, the corresponding float value f should be: <br> + * f = scale * (q - zero_point) <br> + */ + public static class QuantizationParams { + /** The scale value used in quantization. */ + private final float scale; + /** The zero point value used in quantization. */ + private final int zeroPoint; + + /** + * Creates a {@link QuantizationParams} with {@code scale} and {@code zero_point}. + * + * @param scale The scale value used in quantization. + * @param zeroPoint The zero point value used in quantization. + */ + public QuantizationParams(final float scale, final int zeroPoint) { + this.scale = scale; + this.zeroPoint = zeroPoint; + } + + /** Returns the scale value. */ + public float getScale() { + return scale; + } + + /** Returns the zero point value. */ + public int getZeroPoint() { + return zeroPoint; + } + } + /** Returns {@code true} if the model has metadata. Otherwise, returns {@code false}. */ public boolean hasMetadata() { return metadataInfo != null; @@ -166,11 +206,11 @@ public class MetadataExtractor { } /** - * Gets the {@link DataType} of the input tensor with {@code inputIndex}. + * Gets the {@link TensorType} of the input tensor with {@code inputIndex}. * * @param inputIndex the index of the desired input tensor */ - public DataType getInputTensorType(int inputIndex) { + public byte getInputTensorType(int inputIndex) { return modelInfo.getInputTensorType(inputIndex); } @@ -221,11 +261,11 @@ public class MetadataExtractor { } /** - * Gets the {@link DataType} of the output tensor with {@code outputIndex}. + * Gets the {@link TensorType} of the output tensor with {@code outputIndex}. * * @param outputIndex the index of the desired output tensor */ - public DataType getOutputTensorType(int outputIndex) { + public byte getOutputTensorType(int outputIndex) { return modelInfo.getOutputTensorType(outputIndex); } diff --git a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java index e2905d108d7..309a3dbe774 100644 --- a/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java +++ b/tensorflow/lite/experimental/support/metadata/java/src/java/org/tensorflow/lite/support/metadata/ModelInfo.java @@ -21,12 +21,8 @@ import static org.tensorflow.lite.support.metadata.Preconditions.checkNotNull; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; import org.checkerframework.checker.nullness.qual.Nullable; -import org.tensorflow.lite.DataType; -import org.tensorflow.lite.Tensor.QuantizationParams; import org.tensorflow.lite.schema.Buffer; import org.tensorflow.lite.schema.Metadata; import org.tensorflow.lite.schema.Model; @@ -34,6 +30,7 @@ import org.tensorflow.lite.schema.QuantizationParameters; import org.tensorflow.lite.schema.SubGraph; import org.tensorflow.lite.schema.Tensor; import org.tensorflow.lite.schema.TensorType; +import org.tensorflow.lite.support.metadata.MetadataExtractor.QuantizationParams; /** Extracts model information out of TFLite model FLatBuffer. */ final class ModelInfo { @@ -49,9 +46,6 @@ final class ModelInfo { /** Identifier of the TFLite model metadata in the Metadata array. */ static final String METADATA_FIELD_NAME = "TFLITE_METADATA"; - /** Maps from TensorType in TFlite FlatBuffer to {@link DataType} in Java. */ - private final Map<Byte, DataType> tensorTypeToDataTypeMap; - /** * Creates a {@link ModelInfo} with the model FlatBuffer, {@code buffer}. * @@ -74,7 +68,6 @@ final class ModelInfo { inputTensors = getInputTensors(model); outputTensors = getOutputTensors(model); - tensorTypeToDataTypeMap = createTensorTypeToDataTypeMap(); } /** @@ -106,13 +99,12 @@ final class ModelInfo { } /** - * Gets {@link DataType} of the input tensor with {@code inputIndex}. + * Gets the {@link TensorType} in byte of the input tensor with {@code inputIndex}. * * @param inputIndex The index of the desired intput tensor. */ - DataType getInputTensorType(int inputIndex) { - Tensor tensor = getInputTensor(inputIndex); - return getDataType(tensor.type()); + byte getInputTensorType(int inputIndex) { + return getInputTensor(inputIndex).type(); } /** Gets the metadata FlatBuffer from the model FlatBuffer. */ @@ -163,13 +155,12 @@ final class ModelInfo { } /** - * Gets {@link DataType} of the output tensor {@code outputIndex}. + * Gets the {@link TensorType} in byte of the output tensor {@code outputIndex}. * * @param outputIndex The index of the desired outtput tensor. */ - DataType getOutputTensorType(int outputIndex) { - Tensor tensor = getOutputTensor(outputIndex); - return getDataType(tensor.type()); + byte getOutputTensorType(int outputIndex) { + return getOutputTensor(outputIndex).type(); } /** @@ -233,29 +224,6 @@ final class ModelInfo { + " flatbuffer."); } - private static Map<Byte, DataType> createTensorTypeToDataTypeMap() { - Map<Byte, DataType> map = new HashMap<>(); - map.put(TensorType.FLOAT32, DataType.FLOAT32); - map.put(TensorType.INT32, DataType.INT32); - map.put(TensorType.UINT8, DataType.UINT8); - map.put(TensorType.INT64, DataType.INT64); - map.put(TensorType.STRING, DataType.STRING); - return Collections.unmodifiableMap(map); - } - - /** - * Transforms from TensorType in TFlite FlatBuffer to {@link DataType} in Java. - * - * @param tensorType The tensor type to be converted. - * @throws IllegalArgumentException if the tensor type is not supported. - */ - private DataType getDataType(byte tensorType) { - checkArgument( - tensorTypeToDataTypeMap.containsKey(tensorType), - String.format("Tensor type %d is not supported.", tensorType)); - return tensorTypeToDataTypeMap.get(tensorType); - } - /** * Gets the shape of a tensor. * diff --git a/tensorflow/lite/experimental/support/metadata/metadata.py b/tensorflow/lite/experimental/support/metadata/metadata.py index 25ca57bb4cc..b3d8d28806b 100644 --- a/tensorflow/lite/experimental/support/metadata/metadata.py +++ b/tensorflow/lite/experimental/support/metadata/metadata.py @@ -28,6 +28,7 @@ import zipfile from flatbuffers.python import flatbuffers from tensorflow.lite.experimental.support.metadata import metadata_schema_py_generated as _metadata_fb from tensorflow.lite.experimental.support.metadata import schema_py_generated as _schema_fb +from tensorflow.lite.experimental.support.metadata.cc.python import _pywrap_metadata_version from tensorflow.lite.experimental.support.metadata.flatbuffers_lib import _pywrap_flatbuffers from tensorflow.python.platform import resource_loader @@ -55,7 +56,7 @@ class MetadataPopulator(object): classifer model using Flatbuffers API. Attach the label file onto the ouput tensor (the tensor of probabilities) in the metadata. - Then, pack the metadata and lable file into the model as follows. + Then, pack the metadata and label file into the model as follows. ```python # Populating a metadata file (or a metadta buffer) and associated files to @@ -78,6 +79,9 @@ class MetadataPopulator(object): with open("updated_model.tflite", "wb") as f: f.write(updated_model_buf) ``` + + Note that existing metadata buffer (if applied) will be overridden by the new + metadata buffer. """ # As Zip API is used to concatenate associated files after tflite model file, # the populating operation is developed based on a model file. For in-memory @@ -218,12 +222,27 @@ class MetadataPopulator(object): Raises: ValueError: The metadata to be populated is empty. ValueError: The metadata does not have the expected flatbuffer identifer. + ValueError: Error occurs when getting the minimum metadata parser version. """ if not metadata_buf: raise ValueError("The metadata to be populated is empty.") _assert_metadata_buffer_identifier(metadata_buf) - self._metadata_buf = metadata_buf + + # Gets the minimum metadata parser version of the metadata_buf. + min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion( + bytes(metadata_buf)) + + # Inserts in the minimum metadata parser version into the metadata_buf. + metadata = _metadata_fb.ModelMetadataT.InitFromObj( + _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0)) + metadata.minParserVersion = min_version + + b = flatbuffers.Builder(0) + b.Finish(metadata.Pack(b), self.METADATA_FILE_IDENTIFIER) + metadata_buf_with_version = b.Output() + + self._metadata_buf = metadata_buf_with_version def load_metadata_file(self, metadata_file): """Loads the metadata file to be populated. @@ -325,6 +344,9 @@ class MetadataPopulator(object): Inserts metadata_buf into the metadata field of schema.Model. If the MetadataPopulator object is created using the method, with_model_file(model_file), the model file will be updated. + + Existing metadata buffer (if applied) will be overridden by the new metadata + buffer. """ with open(self._model_file, "rb") as f: diff --git a/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs b/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs index b8e529ad1c5..a2812e1b6e3 100644 --- a/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs +++ b/tensorflow/lite/experimental/support/metadata/metadata_schema.fbs @@ -317,12 +317,22 @@ table NormalizationOptions{ // mean and std are normalization parameters. Tensor values are normalized // on a per-channel basis, by the formula // (x - mean) / std. - // For example, a float MobileNet model will have - // mean = 127.5f and std = 127.5f. - // A quantized MobileNet model will have - // mean = 0.0f and std = 1.0f. // If there is only one value in mean or std, we'll propogate the value to // all channels. + // + // Quantized models share the same normalization parameters as their + // corresponding float models. For example, an image input tensor may have + // the normalization parameter of + // mean = 127.5f and std = 127.5f. + // The image value will be normalized from [0, 255] to [-1, 1]. + // Then, for quantized models, the image data should be further quantized + // according to the quantization parameters. In the case of uint8, the image + // data will be scaled back to [0, 255], while for int8, the image data will + // be scaled to [-128, 127]. + // + // Both the normalization parameters and quantization parameters can be + // retrieved through the metadata extractor library. + // TODO(b/156644598): add link for the metadata extractor library. // Per-channel mean of the possible values used in normalization. // diff --git a/tensorflow/lite/experimental/support/metadata/metadata_test.py b/tensorflow/lite/experimental/support/metadata/metadata_test.py index 81b3eef62f9..28395041746 100644 --- a/tensorflow/lite/experimental/support/metadata/metadata_test.py +++ b/tensorflow/lite/experimental/support/metadata/metadata_test.py @@ -43,6 +43,8 @@ class MetadataTest(test_util.TensorFlowTestCase): f.write(self._empty_model_buf) self._model_file = self._create_model_file_with_metadata_and_buf_fields() self._metadata_file = self._create_metadata_file() + self._metadata_file_with_version = self._create_metadata_file_with_version( + self._metadata_file, "1.0.0") self._file1 = self.create_tempfile("file1").full_path self._file2 = self.create_tempfile("file2").full_path self._file3 = self.create_tempfile("file3").full_path @@ -135,6 +137,25 @@ class MetadataTest(test_util.TensorFlowTestCase): b.Finish(model.Pack(b), identifier) return b.Output() + def _create_metadata_file_with_version(self, metadata_file, min_version): + # Creates a new metadata file with the specified min_version for testing + # purposes. + with open(metadata_file, "rb") as f: + metadata_buf = bytearray(f.read()) + + metadata = _metadata_fb.ModelMetadataT.InitFromObj( + _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0)) + metadata.minParserVersion = min_version + + b = flatbuffers.Builder(0) + b.Finish( + metadata.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + + metadata_file_with_version = self.create_tempfile().full_path + with open(metadata_file_with_version, "wb") as f: + f.write(b.Output()) + return metadata_file_with_version + class MetadataPopulatorTest(MetadataTest): @@ -245,7 +266,7 @@ class MetadataPopulatorTest(MetadataTest): buffer_data = model.Buffers(buffer_index) metadata_buf_np = buffer_data.DataAsNumpy() metadata_buf = metadata_buf_np.tobytes() - with open(self._metadata_file, "rb") as f: + with open(self._metadata_file_with_version, "rb") as f: expected_metadata_buf = bytearray(f.read()) self.assertEqual(metadata_buf, expected_metadata_buf) @@ -293,7 +314,7 @@ class MetadataPopulatorTest(MetadataTest): buffer_data = model.Buffers(buffer_index) metadata_buf_np = buffer_data.DataAsNumpy() metadata_buf = metadata_buf_np.tobytes() - with open(self._metadata_file, "rb") as f: + with open(self._metadata_file_with_version, "rb") as f: expected_metadata_buf = bytearray(f.read()) self.assertEqual(metadata_buf, expected_metadata_buf) diff --git a/tensorflow/lite/experimental/support/metadata/testdata/golden_json.json b/tensorflow/lite/experimental/support/metadata/testdata/golden_json.json index bc3001e685a..9ff5581fbff 100644 --- a/tensorflow/lite/experimental/support/metadata/testdata/golden_json.json +++ b/tensorflow/lite/experimental/support/metadata/testdata/golden_json.json @@ -17,5 +17,6 @@ { "name": "file1" } - ] + ], + "min_parser_version": "1.0.0" } diff --git a/tensorflow/lite/experimental/swift/BUILD.apple b/tensorflow/lite/experimental/swift/BUILD.apple index e671721dd1c..b5e502b90f0 100644 --- a/tensorflow/lite/experimental/swift/BUILD.apple +++ b/tensorflow/lite/experimental/swift/BUILD.apple @@ -2,7 +2,7 @@ load("//tensorflow/lite:special_rules.bzl", "ios_visibility_whitelist", "tflite_ios_lab_runner") load("//tensorflow/lite/experimental/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION") -load("@build_bazel_rules_apple//apple:ios.bzl", "ios_application", "ios_unit_test") +load("@build_bazel_rules_apple//apple:ios.bzl", "ios_application", "ios_static_framework", "ios_unit_test") load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library") package( @@ -34,11 +34,25 @@ swift_library( tags = TFL_DEFAULT_TAGS, visibility = ios_visibility_whitelist(), deps = [ + "//tensorflow/lite/delegates/gpu:metal_delegate", "//tensorflow/lite/experimental/delegates/coreml:coreml_delegate", "//tensorflow/lite/experimental/ios:tensorflow_lite_c", ], ) +# bazel build -c opt --config=ios_fat //tensorflow/lite/experimental/swift:TensorFlowLite_framework +ios_static_framework( + name = "TensorFlowLite_framework", + avoid_deps = [ + "//tensorflow/lite/experimental/ios:tensorflow_lite_c", + ], + bundle_name = "TensorFlowLite", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + deps = [ + ":TensorFlowLite", + ], +) + ios_unit_test( name = "Tests", size = "small", diff --git a/tensorflow/lite/experimental/swift/Sources/MetalDelegate.swift b/tensorflow/lite/experimental/swift/Sources/MetalDelegate.swift index 8fd15f303da..6cde2533f95 100644 --- a/tensorflow/lite/experimental/swift/Sources/MetalDelegate.swift +++ b/tensorflow/lite/experimental/swift/Sources/MetalDelegate.swift @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -import TensorFlowLiteC +import TensorFlowLiteCMetal /// A delegate that uses the `Metal` framework for performing TensorFlow Lite graph operations with /// GPU acceleration. diff --git a/tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec.template b/tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec.template index a925112f539..1e414f1959f 100644 --- a/tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec.template +++ b/tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec.template @@ -26,7 +26,16 @@ Pod::Spec.new do |s| s.subspec 'Core' do |core| core.dependency 'TensorFlowLiteC', "#{s.version}" core.source_files = swift_dir + 'Sources/*.swift' - core.exclude_files = swift_dir + 'Sources/CoreMLDelegate.swift' + core.exclude_files = swift_dir + 'Sources/{CoreML,Metal}Delegate.swift' + + core.test_spec 'Tests' do |ts| + ts.source_files = swift_dir + 'Tests/*.swift' + ts.exclude_files = swift_dir + 'Tests/MetalDelegateTests.swift' + ts.resources = [ + tfl_dir + 'testdata/add.bin', + tfl_dir + 'testdata/add_quantized.bin', + ] + end end s.subspec 'CoreML' do |coreml| @@ -35,11 +44,17 @@ Pod::Spec.new do |s| coreml.dependency 'TensorFlowLiteSwift/Core', "#{s.version}" end - s.test_spec 'Tests' do |ts| - ts.source_files = swift_dir + 'Tests/*.swift' - ts.resources = [ - tfl_dir + 'testdata/add.bin', - tfl_dir + 'testdata/add_quantized.bin', - ] + s.subspec 'Metal' do |metal| + metal.source_files = swift_dir + 'Sources/MetalDelegate.swift' + metal.dependency 'TensorFlowLiteC/Metal', "#{s.version}" + metal.dependency 'TensorFlowLiteSwift/Core', "#{s.version}" + + metal.test_spec 'Tests' do |ts| + ts.source_files = swift_dir + 'Tests/{Interpreter,MetalDelegate}Tests.swift' + ts.resources = [ + tfl_dir + 'testdata/add.bin', + tfl_dir + 'testdata/add_quantized.bin', + ] + end end end diff --git a/tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift b/tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift index 09b001cb0cb..8d0140279af 100644 --- a/tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift +++ b/tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift @@ -50,26 +50,6 @@ class InterpreterTests: XCTestCase { XCTAssertNil(interpreter.delegates) } - func testInitWithDelegate() throws { - let metalDelegate = MetalDelegate() - let interpreter = try Interpreter(modelPath: AddQuantizedModel.path, delegates: [metalDelegate]) - XCTAssertEqual(interpreter.delegates?.count, 1) - XCTAssertNil(interpreter.options) - } - - func testInitWithOptionsAndDelegate() throws { - var options = Interpreter.Options() - options.threadCount = 1 - let metalDelegate = MetalDelegate() - let interpreter = try Interpreter( - modelPath: AddQuantizedModel.path, - options: options, - delegates: [metalDelegate] - ) - XCTAssertNotNil(interpreter.options) - XCTAssertEqual(interpreter.delegates?.count, 1) - } - func testInputTensorCount() { XCTAssertEqual(interpreter.inputTensorCount, AddModel.inputTensorCount) } @@ -268,7 +248,7 @@ class InterpreterOptionsTests: XCTestCase { // MARK: - Constants /// Values for the `add.bin` model. -private enum AddModel { +enum AddModel { static let info = (name: "add", extension: "bin") static let inputTensorCount = 1 static let outputTensorCount = 1 @@ -301,7 +281,7 @@ private enum AddModel { } /// Values for the `add_quantized.bin` model. -private enum AddQuantizedModel { +enum AddQuantizedModel { static let info = (name: "add_quantized", extension: "bin") static let inputOutputIndex = 0 static let shape: Tensor.Shape = [2] diff --git a/tensorflow/lite/experimental/swift/Tests/MetalDelegateTests.swift b/tensorflow/lite/experimental/swift/Tests/MetalDelegateTests.swift index 6daa429e2f0..8af43842d7a 100644 --- a/tensorflow/lite/experimental/swift/Tests/MetalDelegateTests.swift +++ b/tensorflow/lite/experimental/swift/Tests/MetalDelegateTests.swift @@ -32,6 +32,26 @@ class MetalDelegateTests: XCTestCase { XCTAssertTrue(delegate.options.allowsPrecisionLoss) XCTAssertEqual(delegate.options.waitType, .active) } + + func testInitInterpreterWithDelegate() throws { + let metalDelegate = MetalDelegate() + let interpreter = try Interpreter(modelPath: AddQuantizedModel.path, delegates: [metalDelegate]) + XCTAssertEqual(interpreter.delegates?.count, 1) + XCTAssertNil(interpreter.options) + } + + func testInitInterpreterWithOptionsAndDelegate() throws { + var options = Interpreter.Options() + options.threadCount = 1 + let metalDelegate = MetalDelegate() + let interpreter = try Interpreter( + modelPath: AddQuantizedModel.path, + options: options, + delegates: [metalDelegate] + ) + XCTAssertNotNil(interpreter.options) + XCTAssertEqual(interpreter.delegates?.count, 1) + } } class MetalDelegateOptionsTests: XCTestCase { diff --git a/tensorflow/lite/g3doc/convert/1x_compatibility.md b/tensorflow/lite/g3doc/convert/1x_compatibility.md index 9f9f277a8d9..ceb99bad5e2 100644 --- a/tensorflow/lite/g3doc/convert/1x_compatibility.md +++ b/tensorflow/lite/g3doc/convert/1x_compatibility.md @@ -34,7 +34,7 @@ input_arrays = ['input_name'] # A list of the names of the model's output tensors output_arrays = ['output_name'] # Load and convert the frozen graph -converter = tf.lite.TFLiteConverter.from_frozen_graph( +converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph( graph_def_file, input_arrays, output_arrays) tflite_model = converter.convert() # Write the converted model to disk diff --git a/tensorflow/lite/g3doc/guide/build_rpi.md b/tensorflow/lite/g3doc/guide/build_rpi.md index 1e04ee77a0e..dfe3709b024 100644 --- a/tensorflow/lite/g3doc/guide/build_rpi.md +++ b/tensorflow/lite/g3doc/guide/build_rpi.md @@ -5,87 +5,98 @@ Raspberry Pi. If you just want to start using TensorFlow Lite to execute your models, the fastest option is to install the TensorFlow Lite runtime package as shown in the [Python quickstart](python.md). -Note: This page shows how to compile only the C++ static library for -TensorFlow Lite. Alternative install options include: [install just the Python -interpreter API](python.md) (for inferencing only); [install the full -TensorFlow package from pip](https://www.tensorflow.org/install/pip); -or [build the full TensorFlow package]( -https://www.tensorflow.org/install/source_rpi). - +**Note:** This page shows how to compile only the C++ static library for +TensorFlow Lite. Alternative install options include: +[install just the Python interpreter API](python.md) (for inferencing only); +[install the full TensorFlow package from pip](https://www.tensorflow.org/install/pip); +or +[build the full TensorFlow package](https://www.tensorflow.org/install/source_rpi). ## Cross-compile for Raspberry Pi -This has been tested on Ubuntu 16.04.3 64bit and TensorFlow devel docker image +Instruction has been tested on Ubuntu 16.04.3 64-bit PC (AMD64) and TensorFlow +devel docker image [tensorflow/tensorflow:nightly-devel](https://hub.docker.com/r/tensorflow/tensorflow/tags/). -To cross compile TensorFlow Lite, first install the toolchain and libs: +To cross compile TensorFlow Lite follow the steps: -```bash -sudo apt-get update -sudo apt-get install crossbuild-essential-armhf -# The following is only needed for Pi Zero build. -sudo apt-get install crossbuild-essential-armel -``` +1. Clone official Raspberry Pi cross-compilation toolchain: -If you are using Docker, you may not use `sudo`. + ```bash + git clone https://github.com/raspberrypi/tools.git rpi_tools + ``` -Now git-clone the TensorFlow repository -(`https://github.com/tensorflow/tensorflow`)—if you're using the TensorFlow -Docker image, the repo is already provided in `/tensorflow_src/`—and then run -this script at the root of the TensorFlow repository to download all the -build dependencies: +2. Clone TensorFlow repository: -```bash -./tensorflow/lite/tools/make/download_dependencies.sh -``` + ```bash + git clone https://github.com/tensorflow/tensorflow.git tensorflow_src -Note that you only need to do this once. + ``` -You should then be able to compile: + **Note:** If you're using the TensorFlow Docker image, the repo is already + provided in `/tensorflow_src/`. -To build ARMv7 binary for Raspberry Pi 2, 3 and 4: +3. Run following script at the root of the TensorFlow repository to download + all the build dependencies: -```bash -./tensorflow/lite/tools/make/build_rpi_lib.sh -``` + ```bash + cd tensorflow_src && ./tensorflow/lite/tools/make/download_dependencies.sh + ``` -This should compile a static library in: -`tensorflow/lite/tools/make/gen/rpi_armv7l/lib/libtensorflow-lite.a`. + **Note:** You only need to do this once. -To build ARMv6 binary for Raspberry Pi Zero: +4. To build ARMv7 binary for Raspberry Pi 2, 3 and 4 execute: -```bash -./tensorflow/lite/tools/make/build_rpi_lib.sh TARGET_ARCH=armv6 -``` + ```bash + PATH=../rpi_tools/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf/bin:$PATH ./tensorflow/lite/tools/make/build_rpi_lib.sh + ``` -This should compile a static library in: -`tensorflow/lite/tools/make/gen/rpi_armv6/lib/libtensorflow-lite.a`. + **Note:** This should compile a static library in: + `tensorflow/lite/tools/make/gen/rpi_armv7l/lib/libtensorflow-lite.a`. + +5. To build ARMv6 binary for Raspberry Pi Zero execute: + + ```bash + PATH=../rpi_tools/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf/bin:$PATH ./tensorflow/lite/tools/make/build_rpi_lib.sh TARGET_ARCH=armv6 + ``` + + **Note:** This should compile a static library in: + `tensorflow/lite/tools/make/gen/rpi_armv6/lib/libtensorflow-lite.a`. ## Compile natively on Raspberry Pi -This has been tested on Raspberry Pi 3b, Raspbian GNU/Linux 9.1 (stretch), gcc version 6.3.0 20170516 (Raspbian 6.3.0-18+rpi1). +Instruction has been tested on Raspberry Pi Zero, Raspbian GNU/Linux 10 +(buster), gcc version 8.3.0 (Raspbian 8.3.0-6+rpi1): -Log in to your Raspberry Pi and install the toolchain: +To natively compile TensorFlow Lite follow the steps: -```bash -sudo apt-get install build-essential -``` +1. Log in to your Raspberry Pi and install the toolchain: -Now git-clone the TensorFlow repository -(`https://github.com/tensorflow/tensorflow`) and run this at the root of -the repository: + ```bash + sudo apt-get install build-essential + ``` -```bash -./tensorflow/lite/tools/make/download_dependencies.sh -``` +2. Clone TensorFlow repository: -Note that you only need to do this once. + ```bash + git clone https://github.com/tensorflow/tensorflow.git tensorflow_src -You should then be able to compile: + ``` -```bash -./tensorflow/lite/tools/make/build_rpi_lib.sh -``` +3. Run following script at the root of the TensorFlow repository to download + all the build dependencies: -This should compile a static library in: -`tensorflow/lite/tools/make/gen/lib/rpi_armv7/libtensorflow-lite.a`. + ```bash + cd tensorflow_src && ./tensorflow/lite/tools/make/download_dependencies.sh + ``` + + **Note:** You only need to do this once. + +4. You should then be able to compile TensorFlow Lite with: + + ```bash + ./tensorflow/lite/tools/make/build_rpi_lib.sh + ``` + + **Note:** This should compile a static library in: + `tensorflow/lite/tools/make/gen/lib/rpi_armv6/libtensorflow-lite.a`. diff --git a/tensorflow/lite/g3doc/guide/roadmap.md b/tensorflow/lite/g3doc/guide/roadmap.md index 35ef44a7dbf..b762db12c44 100644 --- a/tensorflow/lite/g3doc/guide/roadmap.md +++ b/tensorflow/lite/g3doc/guide/roadmap.md @@ -1,4 +1,4 @@ -# TensorFlow Lite 2019 Roadmap +# TensorFlow Lite Roadmap **Updated: April 18, 2020** diff --git a/tensorflow/lite/g3doc/performance/coreml_delegate.md b/tensorflow/lite/g3doc/performance/coreml_delegate.md index c267347cf3f..c3d72b2e01f 100644 --- a/tensorflow/lite/g3doc/performance/coreml_delegate.md +++ b/tensorflow/lite/g3doc/performance/coreml_delegate.md @@ -160,7 +160,7 @@ devices using other libraries such as ### Using older Core ML version -Although iOS 13 supprots Core ML 3, the model might work better when it is +Although iOS 13 supports Core ML 3, the model might work better when it is converted with Core ML 2 model specification. The target conversion version is set to the latest version by default, but you can change this by setting `coreMLVersion` (in Swift, `coreml_version` in C API) in the delegate option to diff --git a/tensorflow/lite/g3doc/performance/gpu.md b/tensorflow/lite/g3doc/performance/gpu.md index 8762afb4c83..b5abf46f845 100644 --- a/tensorflow/lite/g3doc/performance/gpu.md +++ b/tensorflow/lite/g3doc/performance/gpu.md @@ -31,7 +31,7 @@ models. For a step-by-step tutorial, watch the [GPU Delegate for Android](https://youtu.be/Xkhgre8r5G0) video. -Note: This requires OpenGL ES 3.1 or higher. +Note: This requires OpenCL or OpenGL ES (3.1 or higher). #### Step 1. Clone the TensorFlow source code and open it in Android Studio diff --git a/tensorflow/lite/g3doc/performance/gpu_advanced.md b/tensorflow/lite/g3doc/performance/gpu_advanced.md index 9f47c2e55e8..dce3eb8db6b 100644 --- a/tensorflow/lite/g3doc/performance/gpu_advanced.md +++ b/tensorflow/lite/g3doc/performance/gpu_advanced.md @@ -1,9 +1,9 @@ # TensorFlow Lite on GPU [TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/) supports several -hardware accelerators. This document describes how to use the GPU backend using -the TensorFlow Lite delegate APIs on Android (requires OpenGL ES 3.1 or higher) -and iOS (requires iOS 8 or later). +hardware accelerators. This document describes how to use the GPU backend using +the TensorFlow Lite delegate APIs on Android (requires OpenCL or OpenGL ES 3.1 +and higher) and iOS (requires iOS 8 or later). ## Benefits of GPU Acceleration @@ -35,25 +35,33 @@ power and generating less heat than the same task run on a CPU. TensorFlow Lite on GPU supports the following ops in 16-bit and 32-bit float precision: -* `ADD v1` -* `AVERAGE_POOL_2D v1` -* `CONCATENATION v1` -* `CONV_2D v1` -* `DEPTHWISE_CONV_2D v1-2` -* `FULLY_CONNECTED v1` -* `LOGISTIC v1` -* `MAX_POOL_2D v1` -* `MUL v1` -* `PAD v1` -* `PRELU v1` -* `RELU v1` -* `RELU6 v1` -* `RESHAPE v1` -* `RESIZE_BILINEAR v1` -* `SOFTMAX v1` -* `STRIDED_SLICE v1` -* `SUB v1` -* `TRANSPOSE_CONV v1` +* `ADD` +* `AVERAGE_POOL_2D` +* `CONCATENATION` +* `CONV_2D` +* `DEPTHWISE_CONV_2D v1-2` +* `EXP` +* `FULLY_CONNECTED` +* `LOGISTIC` +* `LSTM v2 (Basic LSTM only)` +* `MAX_POOL_2D` +* `MAXIMUM` +* `MINIMUM` +* `MUL` +* `PAD` +* `PRELU` +* `RELU` +* `RELU6` +* `RESHAPE` +* `RESIZE_BILINEAR v1-3` +* `SOFTMAX` +* `STRIDED_SLICE` +* `SUB` +* `TRANSPOSE_CONV` + +By default, all ops are only supported at version 1. Enabling the +[experimental quantization support](gpu_advanced.md#running-quantized-models-experimental-android-only) +allows the appropriate versions; for example, ADD v2. ## Basic Usage @@ -82,8 +90,8 @@ delegate.close(); ### Android (C/C++) For C/C++ usage of TensorFlow Lite GPU on Android, the GPU delegate can be -created with `TfLiteGpuDelegateCreate()` and destroyed with -`TfLiteGpuDelegateDelete()`. +created with `TfLiteGpuDelegateV2Create()` and destroyed with +`TfLiteGpuDelegateV2Delete()`. ```c++ // Set up interpreter. @@ -94,15 +102,7 @@ std::unique_ptr<Interpreter> interpreter; InterpreterBuilder(*model, op_resolver)(&interpreter); // NEW: Prepare GPU delegate. -const TfLiteGpuDelegateOptions options = { - .metadata = NULL, - .compile_options = { - .precision_loss_allowed = 1, // FP16 - .preferred_gl_object_type = TFLITE_GL_OBJECT_TYPE_FASTEST, - .dynamic_batch_enabled = 0, // Not fully functional yet - }, -}; -auto* delegate = TfLiteGpuDelegateCreate(&options); +auto* delegate = TfLiteGpuDelegateV2Create(/*default options=*/nullptr); if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return false; // Run inference. @@ -111,9 +111,13 @@ if (interpreter->Invoke() != kTfLiteOk) return false; ReadFromOutputTensor(interpreter->typed_output_tensor<float>(0)); // NEW: Clean up. -TfLiteGpuDelegateDelete(delegate); +TfLiteGpuDelegateV2Delete(delegate); ``` +Take a look at `TfLiteGpuDelegateOptionsV2` to create a delegate instance with +custom options. You can initialize the default options with +`TfLiteGpuDelegateOptionsV2Default()` and then modify them as necessary. + TFLite GPU for Android C/C++ uses the [Bazel](https://bazel.io) build system. The delegate can be built, for example, using the following command: @@ -165,6 +169,43 @@ called. ## Advanced Usage +### Running quantized models (Experimental, Android only) + +The GPU delegate already supports +[float16 quantized](https://www.tensorflow.org/lite/performance/post_training_float16_quant) +models. There is experimental support on Android to run 8-bit quantized as well. +This includes all flavors of quantization, including: + +* Models trained with + [Quantization-aware training](https://www.tensorflow.org/lite/convert/quantization) +* [Post-training dynamic-range quantization](https://www.tensorflow.org/lite/performance/post_training_quant) +* [Post-training full-integer quantization](https://www.tensorflow.org/lite/performance/post_training_integer_quant) + +To optimize performance, use models that have floating-point input & output +tensors. + +This feature can be enabled using delegate options as follows: + +**C++ API** + +```c++ +// NEW: Prepare custom options with feature enabled. +TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default(); +options.experimental_flags |= TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT; + +auto* delegate = TfLiteGpuDelegateV2Create(options); +if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return false; +``` + +**Java API** + +```java +// NEW: Prepare GPU delegate with feature turned on. +GpuDelegate delegate = new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(true)); + +Interpreter.Options options = (new Interpreter.Options()).addDelegate(delegate); +``` + ### Delegate Options for iOS `NewGpuDelegate()` accepts a `struct` of options. @@ -210,7 +251,7 @@ While it is convenient to use `nullptr`, we recommend that you explicitly set the options, to avoid any unexpected behavior if default values are changed in the future. -### Input/Output Buffers +### Input/Output Buffers (iOS only) To do computation on the GPU, data must be made available to the GPU. This often requires performing a memory copy. It is desirable not to cross the CPU/GPU @@ -229,80 +270,10 @@ To achieve best performance, TensorFlow Lite makes it possible for users to directly read from and write to the TensorFlow hardware buffer and bypass avoidable memory copies. -#### Android - -Assuming the image input is in the GPU memory, it must first be converted to an -OpenGL Shader Storage Buffer Object (SSBO). You can associate a TfLiteTensor to -a user-prepared SSBO with `Interpreter.bindGlBufferToTensor()`. Note that -`Interpreter.bindGlBufferToTensor()` must be called before -`Interpreter.modifyGraphWithDelegate()`. - -```java -// Ensure a valid EGL rendering context. -EGLContext eglContext = eglGetCurrentContext(); -if (eglContext.equals(EGL_NO_CONTEXT)) return false; - -// Create an SSBO. -int[] id = new int[1]; -glGenBuffers(id.length, id, 0); -glBindBuffer(GL_SHADER_STORAGE_BUFFER, id[0]); -glBufferData(GL_SHADER_STORAGE_BUFFER, inputSize, null, GL_STREAM_COPY); -glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); // unbind -int inputSsboId = id[0]; - -// Create interpreter. -Interpreter interpreter = new Interpreter(tfliteModel); -Tensor inputTensor = interpreter.getInputTensor(0); -GpuDelegate gpuDelegate = new GpuDelegate(); -// The buffer must be bound before the delegate is installed. -gpuDelegate.bindGlBufferToTensor(inputTensor, inputSsboId); -interpreter.modifyGraphWithDelegate(gpuDelegate); - -// Run inference; the null input argument indicates use of the bound buffer for input. -fillSsboWithCameraImageTexture(inputSsboId); -float[] outputArray = new float[outputSize]; -interpreter.runInference(null, outputArray); -``` - -A similar approach can be applied to the output tensor. In that case, -`Interpreter.Options.setAllowBufferHandleOutput(true)` should be passed on, to -disable the default copying of the network's output from GPU memory to CPU -memory. - -```java -// Ensure a valid EGL rendering context. -EGLContext eglContext = eglGetCurrentContext(); -if (eglContext.equals(EGL_NO_CONTEXT)) return false; - -// Create a SSBO. -int[] id = new int[1]; -glGenBuffers(id.length, id, 0); -glBindBuffer(GL_SHADER_STORAGE_BUFFER, id[0]); -glBufferData(GL_SHADER_STORAGE_BUFFER, outputSize, null, GL_STREAM_COPY); -glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); // unbind -int outputSsboId = id[0]; - -// Create interpreter. -Interpreter.Options options = (new Interpreter.Options()).setAllowBufferHandleOutput(true); -Interpreter interpreter = new Interpreter(tfliteModel, options); -Tensor outputTensor = interpreter.getOutputTensor(0); -GpuDelegate gpuDelegate = new GpuDelegate(); -// The buffer must be bound before the delegate is installed. -gpuDelegate.bindGlBufferToTensor(outputTensor, outputSsboId); -interpreter.modifyGraphWithDelegate(gpuDelegate); - -// Run inference; the null output argument indicates use of the bound buffer for output. -ByteBuffer input = getCameraImageByteBuffer(); -interpreter.runInference(input, null); -renderOutputSsbo(outputSsboId); -``` - -#### iOS - Assuming the image input is in GPU memory, it must first be converted to a `MTLBuffer` object for Metal. You can associate a TfLiteTensor to a -user-prepared `MTLBuffer` with `BindMetalBufferToTensor()`. Note that -`BindMetalBufferToTensor()` must be called before +user-prepared `MTLBuffer` with `TFLGpuDelegateBindMetalBufferToTensor()`. Note +that `TFLGpuDelegateBindMetalBufferToTensor()` must be called before `Interpreter::ModifyGraphWithDelegate()`. Additionally, the inference output is, by default, copied from GPU memory to CPU memory. This behavior can be turned off by calling `Interpreter::SetAllowBufferHandleOutput(true)` during @@ -312,8 +283,8 @@ initialization. // Prepare GPU delegate. auto* delegate = NewGpuDelegate(nullptr); interpreter->SetAllowBufferHandleOutput(true); // disable default gpu->cpu copy -if (!BindMetalBufferToTensor(delegate, interpreter->inputs()[0], user_provided_input_buffer)) return false; -if (!BindMetalBufferToTensor(delegate, interpreter->outputs()[0], user_provided_output_buffer)) return false; +if (!TFLGpuDelegateBindMetalBufferToTensor(delegate, interpreter->inputs()[0], user_provided_input_buffer)) return false; +if (!TFLGpuDelegateBindMetalBufferToTensor(delegate, interpreter->outputs()[0], user_provided_output_buffer)) return false; if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) return false; // Run inference. diff --git a/tensorflow/lite/g3doc/performance/hexagon_delegate.md b/tensorflow/lite/g3doc/performance/hexagon_delegate.md index 60fe9465bf4..0e947d1d5e1 100644 --- a/tensorflow/lite/g3doc/performance/hexagon_delegate.md +++ b/tensorflow/lite/g3doc/performance/hexagon_delegate.md @@ -22,15 +22,15 @@ are supported, including: **Supported models:** -The Hexagon delegate currently supports quantized models generated using -[quantization-aware training](https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/contrib/quantize), -e.g., -[these quantized models](https://www.tensorflow.org/lite/guide/hosted_models#quantized_models) -hosted on the TensorFlow Lite repo. It does not (yet) support models with -[8-bit symmetric quantization spec](https://www.tensorflow.org/lite/performance/quantization_spec). -Sample models include -[MobileNet V1](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz), -[SSD Mobilenet](https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip). +The Hexagon delegate supports all models that conform to our +[8-bit symmetric quantization spec](https://www.tensorflow.org/lite/performance/quantization_spec), +including those generated using +[post-training integer quantization](https://www.tensorflow.org/lite/performance/post_training_integer_quant). +UInt8 models trained with the legacy +[quantization-aware training](https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/contrib/quantize) +path are also supported, for e.g., +[these quantized versions](https://www.tensorflow.org/lite/guide/hosted_models#quantized_models) +on our Hosted Models page. ## Hexagon Delegate Java API @@ -254,10 +254,6 @@ ro.board.platform`). ## FAQ -* Will the delegate support models created using - [post-training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization)? - * This is tentatively planned for a future release, though there is no - concrete timeline. * Which ops are supported by the delegate? * See the current list of [supported ops and constraints](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/experimental/delegates/hexagon/README.md) * How can I tell that the model is using the DSP when I enable the delegate? diff --git a/tensorflow/lite/g3doc/performance/model_optimization.md b/tensorflow/lite/g3doc/performance/model_optimization.md index feb6cfecea6..c66b06f9b59 100644 --- a/tensorflow/lite/g3doc/performance/model_optimization.md +++ b/tensorflow/lite/g3doc/performance/model_optimization.md @@ -89,9 +89,9 @@ The following types of quantization are available in TensorFlow Lite: Technique | Data requirements | Size reduction | Accuracy | Supported hardware ------------------------------------------------------------------------------------------------------- | -------------------------------- | -------------- | --------------------------- | ------------------ [Post-training float16 quantization](post_training_float16_quant.ipynb) | No data | Up to 50% | Insignificant accuracy loss | CPU, GPU -[Post-training dynamic range quantization](post_training_quant.ipynb) | No data | Up to 75% | Accuracy loss | CPU -[Post-training integer quantization](post_training_integer_quant.ipynb) | Unlabelled representative sample | Up to 75% | Smaller accuracy loss | CPU, EdgeTPU, Hexagon DSP -[Quantization-aware training](http://www.tensorflow.org/model_optimization/guide/quantization/training) | Labelled training data | Up to 75% | Smallest accuracy loss | CPU, EdgeTPU, Hexagon DSP +[Post-training dynamic range quantization](post_training_quant.ipynb) | No data | Up to 75% | Accuracy loss | CPU, GPU (Android) +[Post-training integer quantization](post_training_integer_quant.ipynb) | Unlabelled representative sample | Up to 75% | Smaller accuracy loss | CPU, GPU (Android), EdgeTPU, Hexagon DSP +[Quantization-aware training](http://www.tensorflow.org/model_optimization/guide/quantization/training) | Labelled training data | Up to 75% | Smallest accuracy loss | CPU, GPU (Android), EdgeTPU, Hexagon DSP Below are the latency and accuracy results for post-training quantization and quantization-aware training on a few models. All latency numbers are measured on diff --git a/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb b/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb index 8261d6c9e34..e10507ccac7 100644 --- a/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb +++ b/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb @@ -632,7 +632,7 @@ "id": "EoWiA_zX8rxE" }, "source": [ - "# Advanced Usage\n", + "## Advanced Usage\n", "\n", "The `create` function is the critical part of this library in which parameter `model_spec` defines the specification of the model, currently `AverageWordVecModelSpec` and `BertModelSpec` is supported. The `create` function contains the following steps for `AverageWordVecModelSpec`:\n", "\n", @@ -651,7 +651,7 @@ "id": "mwtiksguDfhl" }, "source": [ - "# Adjust the model\n", + "## Adjust the model\n", "\n", "We could adjust the model infrastructure like variables `wordvec_dim`, `seq_len` in `AverageWordVecModelSpec` class.\n" ] @@ -736,7 +736,7 @@ "id": "LvQuy7RSDir3" }, "source": [ - "## Change the training hyperparameters\n", + "### Change the training hyperparameters\n", "We could also change the training hyperparameters like `epochs` and `batch_size` that could affect the model accuracy. For instance,\n", "\n", "* `epochs`: more epochs could achieve better accuracy, but may lead to overfitting.\n", @@ -788,7 +788,7 @@ "id": "Eq6B9lKMfhS6" }, "source": [ - "## Change the Model\n", + "### Change the Model\n", "\n", "We could change the model by changing the `model_spec`. The following shows how we change to BERT-base model.\n", "\n", diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc index c8ccf671d60..167254a2a62 100644 --- a/tensorflow/lite/interpreter.cc +++ b/tensorflow/lite/interpreter.cc @@ -310,6 +310,8 @@ void Interpreter::SetCancellationFunction(void* data, } } +bool Interpreter::IsCancelled() { return primary_subgraph().IsCancelled(); } + TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { TfLiteStatus status = kTfLiteOk; for (auto& subgraph : subgraphs_) { @@ -340,6 +342,8 @@ TfLiteStatus Interpreter::RemoveAllDelegates() { return kTfLiteOk; } +bool Interpreter::HasDelegates() { return primary_subgraph().HasDelegates(); } + TfLiteStatus Interpreter::SetBufferHandle(int tensor_index, TfLiteBufferHandle buffer_handle, TfLiteDelegate* delegate) { diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h index b93fd76c13b..0e01ce44e0c 100644 --- a/tensorflow/lite/interpreter.h +++ b/tensorflow/lite/interpreter.h @@ -42,6 +42,9 @@ namespace tflite { class InterpreterTest; class TestDelegate; +namespace delegates { +class InterpreterUtils; // Class for friend declarations. +} // namespace delegates namespace impl { @@ -322,10 +325,9 @@ class Interpreter { /// Change the dimensionality of a given tensor. Note, this is only acceptable /// for tensor indices that are inputs or variables. - /// Returns status of failure or success. - /// TODO(aselle): Consider implementing ArraySlice equivalent to make this - /// more adept at accepting data without an extra copy. Use absl::ArraySlice - /// if our partners determine that dependency is acceptable. + /// Returns status of failure or success. Note that this doesn't actually + /// resize any existing buffers. A call to AllocateTensors() is required to + /// change the tensor input buffer. TfLiteStatus ResizeInputTensor(int tensor_index, const std::vector<int>& dims); @@ -334,7 +336,8 @@ class Interpreter { // tensor indices that are inputs or variables. Only unknown dimensions can be // resized with this function. Unknown dimensions are indicated as `-1` in the // `dims_signature` attribute of a `TfLiteTensor`. Returns status of failure - // or success. + // or success. Note that this doesn't actually resize any existing buffers. + /// A call to AllocateTensors() is required to change the tensor input buffer. TfLiteStatus ResizeInputTensorStrict(int tensor_index, const std::vector<int>& dims); @@ -529,6 +532,7 @@ class Interpreter { friend class InterpreterBuilder; friend class tflite::InterpreterTest; friend class tflite::TestDelegate; + friend class tflite::delegates::InterpreterUtils; /// Set the value of an external context. static void SetExternalContext(struct TfLiteContext* context, @@ -542,6 +546,15 @@ class Interpreter { // afterwards. TfLiteStatus RemoveAllDelegates(); + // Returns true if delegates have been applied. + bool HasDelegates(); + + // Returns true if cancellation function returns true. + bool IsCancelled(); + + // Get the error reporter associated with this interpreter. + ErrorReporter* error_reporter() { return error_reporter_; } + // A pure C data structure used to communicate with the pure C plugin // interface. To avoid copying tensor metadata, this is also the definitive // structure to store tensors. diff --git a/tensorflow/lite/interpreter_builder.cc b/tensorflow/lite/interpreter_builder.cc index fb87702fd13..43d81ef0770 100644 --- a/tensorflow/lite/interpreter_builder.cc +++ b/tensorflow/lite/interpreter_builder.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/lite/core/api/flatbuffer_conversions.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/tflite_with_xnnpack_optional.h" #include "tensorflow/lite/util.h" #include "tensorflow/lite/version.h" @@ -108,27 +109,14 @@ TfLiteStatus ParseSparseIndexVector(const DimensionMetadata* src, const char* kEmptyTensorName = ""; -#if TFLITE_HAS_ATTRIBUTE_WEAK // Using weak symbols to create a delegate allows automatic injection of the // delegate simply by adding it as a dependency. - // For flex delegate, see also the strong override in // lite/delegates/flex/delegate.cc. TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() { return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); } -// For XNNPACK delegate, see also the strong override in -// lite/tflite_with_xnnpack.cc. -TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireXNNPACKDelegate( - int num_threads) { - return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); -} -#else -Interpreter::TfLiteDelegatePtr (*AcquireFlexDelegate)() = nullptr; -Interpreter::TfLiteDelegatePtr (*AcquireXNNPACKDelegate)(int) = nullptr; -#endif - namespace impl { InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model, @@ -541,17 +529,17 @@ TfLiteStatus InterpreterBuilder::ParseTensors( TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter, int num_threads) { // First, apply XNNPACK delegate if applicable. - if (AcquireXNNPACKDelegate && num_fp32_tensors_ > 0) { - if (auto xnnpack_delegate = AcquireXNNPACKDelegate(num_threads)) { - // The execution will fall back to default implementation if the XNNPACK - // delegate fails to be applied. Therefore, we ignore the return status - // here and let it fall through the rest of the code. + if (num_fp32_tensors_ > 0) { + // The execution will fall back to default implementation if the XNNPACK + // delegate fails to be applied. Therefore, we ignore the return status + // here and let it fall through the rest of the code. + if (auto xnnpack_delegate = MaybeCreateXNNPACKDelegate(num_threads)) { interpreter->ModifyGraphWithDelegate(std::move(xnnpack_delegate)); } } // Secondly, apply Flex delegate if applicable. - if (has_flex_op_ && AcquireFlexDelegate) { + if (has_flex_op_) { if (auto flex_delegate = AcquireFlexDelegate()) { return interpreter->ModifyGraphWithDelegate(std::move(flex_delegate)); } diff --git a/tensorflow/lite/interpreter_test.cc b/tensorflow/lite/interpreter_test.cc index abd92ad563d..49b8e7bd816 100644 --- a/tensorflow/lite/interpreter_test.cc +++ b/tensorflow/lite/interpreter_test.cc @@ -1304,948 +1304,6 @@ TEST_F(TestExecutionPlan, NullExecutionPlan) { ASSERT_EQ(run_order_, std::vector<int>()); } -// Build a kernel registration for an op that copies its one input -// to an output -TfLiteRegistration AddOpRegistration() { - TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; - - reg.custom_name = "my_add"; - reg.builtin_code = tflite::BuiltinOperator_CUSTOM; - - reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { - // Set output size to input size - const TfLiteTensor* input1 = GetInput(context, node, 0); - const TfLiteTensor* input2 = GetInput(context, node, 1); - TfLiteTensor* output = GetOutput(context, node, 0); - - TF_LITE_ENSURE_EQ(context, input1->dims->size, input2->dims->size); - for (int i = 0; i < input1->dims->size; ++i) { - TF_LITE_ENSURE_EQ(context, input1->dims->data[i], input2->dims->data[i]); - } - - TF_LITE_ENSURE_STATUS(context->ResizeTensor( - context, output, TfLiteIntArrayCopy(input1->dims))); - return kTfLiteOk; - }; - - reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { - // Copy input data to output data. - const TfLiteTensor* a0 = GetInput(context, node, 0); - TF_LITE_ENSURE(context, a0); - TF_LITE_ENSURE(context, a0->data.f); - const TfLiteTensor* a1 = GetInput(context, node, 1); - TF_LITE_ENSURE(context, a1); - TF_LITE_ENSURE(context, a1->data.f); - TfLiteTensor* out = GetOutput(context, node, 0); - TF_LITE_ENSURE(context, out); - TF_LITE_ENSURE(context, out->data.f); - int num = a0->dims->data[0]; - for (int i = 0; i < num; i++) { - out->data.f[i] = a0->data.f[i] + a1->data.f[i]; - } - return kTfLiteOk; - }; - return reg; -} - -} // namespace - -// TestDelegate is a friend of Interpreter to access RemoveAllDelegates(). -class TestDelegate : public ::testing::Test { - protected: - void SetUp() override { - interpreter_.reset(new Interpreter); - interpreter_->AddTensors(5); - interpreter_->SetInputs({0, 1}); - interpreter_->SetOutputs({3, 4}); - TfLiteQuantizationParams quant; - interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, - quant); - interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, - quant); - interpreter_->SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3}, - quant); - interpreter_->SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3}, - quant); - interpreter_->SetTensorParametersReadWrite(4, kTfLiteFloat32, "", {3}, - quant); - TfLiteRegistration reg = AddOpRegistration(); - interpreter_->AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, ®); - interpreter_->AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, ®); - interpreter_->AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, ®); - } - - void TearDown() override { - // Interpreter relies on delegate to free the resources properly. Thus - // the life cycle of delegate must be longer than interpreter. - interpreter_.reset(); - delegate_.reset(); - } - - TfLiteBufferHandle last_allocated_handle_ = kTfLiteNullBufferHandle; - - TfLiteBufferHandle AllocateBufferHandle() { return ++last_allocated_handle_; } - - TfLiteStatus RemoveAllDelegates() { - return interpreter_->RemoveAllDelegates(); - } - - protected: - class SimpleDelegate { - public: - // Create a simple implementation of a TfLiteDelegate. We use the C++ class - // SimpleDelegate and it can produce a handle TfLiteDelegate that is - // value-copyable and compatible with TfLite. - // fail_node_prepare: To simulate failure of Delegate node's Prepare(). - // min_ops_per_subset: If >0, partitioning preview is used to choose only - // those subsets with min_ops_per_subset number of nodes. - // fail_node_invoke: To simulate failure of Delegate node's Invoke(). - explicit SimpleDelegate( - const std::vector<int>& nodes, - TfLiteDelegateFlags delegate_flags = kTfLiteDelegateFlagsNone, - bool fail_node_prepare = false, int min_ops_per_subset = 0, - bool fail_node_invoke = false) - : nodes_(nodes), - fail_delegate_node_prepare_(fail_node_prepare), - min_ops_per_subset_(min_ops_per_subset), - fail_delegate_node_invoke_(fail_node_invoke) { - delegate_.Prepare = [](TfLiteContext* context, - TfLiteDelegate* delegate) -> TfLiteStatus { - auto* simple = static_cast<SimpleDelegate*>(delegate->data_); - TfLiteIntArray* nodes_to_separate = - TfLiteIntArrayCreate(simple->nodes_.size()); - // Mark nodes that we want in TfLiteIntArray* structure. - int index = 0; - for (auto node_index : simple->nodes_) { - nodes_to_separate->data[index++] = node_index; - // make sure node is added - TfLiteNode* node; - TfLiteRegistration* reg; - context->GetNodeAndRegistration(context, node_index, &node, ®); - TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM); - TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0); - } - // Check that all nodes are available - TfLiteIntArray* execution_plan; - TF_LITE_ENSURE_STATUS( - context->GetExecutionPlan(context, &execution_plan)); - for (int exec_index = 0; exec_index < execution_plan->size; - exec_index++) { - int node_index = execution_plan->data[exec_index]; - TfLiteNode* node; - TfLiteRegistration* reg; - context->GetNodeAndRegistration(context, node_index, &node, ®); - if (exec_index == node_index) { - // Check op details only if it wasn't delegated already. - TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM); - TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0); - } - } - - // Get preview of delegate partitioning from the context. - TfLiteDelegateParams* params_array; - int num_partitions; - TFLITE_CHECK_EQ( - context->PreviewDelegatePartitioning( - context, nodes_to_separate, ¶ms_array, &num_partitions), - kTfLiteOk); - - if (simple->min_ops_per_subset() > 0) { - // Build a new vector of ops from subsets with atleast the minimum - // size. - std::vector<int> allowed_ops; - for (int idx = 0; idx < num_partitions; ++idx) { - const auto* nodes_in_subset = params_array[idx].nodes_to_replace; - if (nodes_in_subset->size < simple->min_ops_per_subset()) continue; - allowed_ops.insert(allowed_ops.end(), nodes_in_subset->data, - nodes_in_subset->data + nodes_in_subset->size); - } - - // Free existing nodes_to_separate & initialize a new array with - // allowed_ops. - TfLiteIntArrayFree(nodes_to_separate); - nodes_to_separate = TfLiteIntArrayCreate(allowed_ops.size()); - memcpy(nodes_to_separate->data, allowed_ops.data(), - sizeof(int) * nodes_to_separate->size); - } - - // Another call to PreviewDelegateParitioning should be okay, since - // partitioning memory is managed by context. - TFLITE_CHECK_EQ( - context->PreviewDelegatePartitioning( - context, nodes_to_separate, ¶ms_array, &num_partitions), - kTfLiteOk); - - context->ReplaceNodeSubsetsWithDelegateKernels( - context, simple->FakeFusedRegistration(), nodes_to_separate, - delegate); - TfLiteIntArrayFree(nodes_to_separate); - return kTfLiteOk; - }; - delegate_.CopyToBufferHandle = [](TfLiteContext* context, - TfLiteDelegate* delegate, - TfLiteBufferHandle buffer_handle, - TfLiteTensor* tensor) -> TfLiteStatus { - // TODO(ycling): Implement tests to test buffer copying logic. - return kTfLiteOk; - }; - delegate_.CopyFromBufferHandle = - [](TfLiteContext* context, TfLiteDelegate* delegate, - TfLiteBufferHandle buffer_handle, - TfLiteTensor* output) -> TfLiteStatus { - TFLITE_CHECK_GE(buffer_handle, -1); - TFLITE_CHECK_EQ(output->buffer_handle, buffer_handle); - const float floats[] = {6., 6., 6.}; - int num = output->dims->data[0]; - for (int i = 0; i < num; i++) { - output->data.f[i] = floats[i]; - } - return kTfLiteOk; - }; - - delegate_.FreeBufferHandle = - [](TfLiteContext* context, TfLiteDelegate* delegate, - TfLiteBufferHandle* handle) { *handle = kTfLiteNullBufferHandle; }; - // Store type-punned data SimpleDelegate structure. - delegate_.data_ = static_cast<void*>(this); - delegate_.flags = delegate_flags; - } - - TfLiteRegistration FakeFusedRegistration() { - TfLiteRegistration reg = {nullptr}; - reg.custom_name = "fake_fused_op"; - - reg.invoke = [](TfLiteContext* context, - TfLiteNode* node) -> TfLiteStatus { - // Copy input data to output data. - const TfLiteTensor* a0; - const TfLiteTensor* a1; - if (node->inputs->size == 2) { - a0 = GetInput(context, node, 0); - a1 = GetInput(context, node, 1); - } else { - a0 = GetInput(context, node, 0); - a1 = a0; - } - TfLiteTensor* out = GetOutput(context, node, 0); - int num = 1; - for (int i = 0; i < a0->dims->size; ++i) { - num *= a0->dims->data[i]; - } - for (int i = 0; i < num; i++) { - out->data.f[i] = a0->data.f[i] + a1->data.f[i]; - } - // Make the data stale so that CopyFromBufferHandle can be invoked - out->data_is_stale = true; - return kTfLiteOk; - }; - if (fail_delegate_node_invoke_) { - reg.invoke = [](TfLiteContext* context, - TfLiteNode* node) -> TfLiteStatus { - return kTfLiteError; - }; - } - - reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { - // Set output size to input size - const TfLiteTensor* input1; - const TfLiteTensor* input2; - if (node->inputs->size == 2) { - input1 = GetInput(context, node, 0); - input2 = GetInput(context, node, 1); - } else { - input1 = GetInput(context, node, 0); - input2 = input1; - } - TfLiteTensor* output = GetOutput(context, node, 0); - - TF_LITE_ENSURE_STATUS(context->ResizeTensor( - context, output, TfLiteIntArrayCopy(input1->dims))); - return kTfLiteOk; - }; - if (fail_delegate_node_prepare_) { - reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { - return kTfLiteError; - }; - } - - return reg; - } - - TfLiteDelegate* get_tf_lite_delegate() { return &delegate_; } - - int min_ops_per_subset() { return min_ops_per_subset_; } - - private: - std::vector<int> nodes_; - TfLiteDelegate delegate_; - bool fail_delegate_node_prepare_ = false; - int min_ops_per_subset_ = 0; - bool fail_delegate_node_invoke_ = false; - }; - - std::unique_ptr<Interpreter> interpreter_; - std::unique_ptr<SimpleDelegate> delegate_, delegate2_; -}; -namespace { - -TEST_F(TestDelegate, BasicDelegate) { - delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); - interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()); - - ASSERT_EQ(interpreter_->execution_plan().size(), 1); - int node = interpreter_->execution_plan()[0]; - const auto* node_and_reg = interpreter_->node_and_registration(node); - EXPECT_EQ(node_and_reg->second.custom_name, - delegate_->FakeFusedRegistration().custom_name); - - const TfLiteDelegateParams* params = static_cast<const TfLiteDelegateParams*>( - node_and_reg->first.builtin_data); - ASSERT_EQ(params->nodes_to_replace->size, 3); - EXPECT_EQ(params->nodes_to_replace->data[0], 0); - EXPECT_EQ(params->nodes_to_replace->data[1], 1); - EXPECT_EQ(params->nodes_to_replace->data[2], 2); - - ASSERT_EQ(params->input_tensors->size, 2); - EXPECT_EQ(params->input_tensors->data[0], 0); - EXPECT_EQ(params->input_tensors->data[1], 1); - - ASSERT_EQ(params->output_tensors->size, 2); - EXPECT_EQ(params->output_tensors->data[0], 3); - EXPECT_EQ(params->output_tensors->data[1], 4); -} - -TEST_F(TestDelegate, DelegateNodePrepareFailure) { - delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate( - {0, 1, 2}, kTfLiteDelegateFlagsNone, true /**fail_node_prepare**/)); - // ModifyGraphWithDelegate fails, since the Prepare() method in the node's - // TfLiteRegistration returns an error status. - ASSERT_EQ( - interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), - kTfLiteDelegateError); - // Execution plan should remain unchanged. - ASSERT_EQ(interpreter_->execution_plan().size(), 3); - - std::vector<float> input = {1.0f, 2.0f, 3.0f}; - std::vector<float> expected_output = {2.0f, 4.0f, 6.0f}; - constexpr int kOutputTensorIndex = 3; - TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); - - // Verify Invoke() behavior. - memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); - memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); - interpreter_->Invoke(); - for (int i = 0; i < 3; ++i) { - EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; - } -} - -TEST_F(TestDelegate, DelegateNodeInvokeFailure) { - delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate( - {0, 1, 2}, kTfLiteDelegateFlagsNone, false /**fail_node_prepare**/, - 0 /**min_ops_per_subset**/, true /**fail_node_invoke**/)); - ASSERT_EQ( - interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), - kTfLiteOk); - // Delegation modified execution plan. - ASSERT_EQ(interpreter_->execution_plan().size(), 1); - - std::vector<float> input = {1.0f, 2.0f, 3.0f}; - std::vector<float> expected_output = {2.0f, 4.0f, 6.0f}; - constexpr int kOutputTensorIndex = 3; - - // Verify Invoke() behavior: fails first, succeeds after RemoveAllDelegates(). - memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); - memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); - EXPECT_EQ(interpreter_->Invoke(), kTfLiteError); - ASSERT_EQ(RemoveAllDelegates(), kTfLiteOk); - // Delegation removed, returning to original execution plan. - ASSERT_EQ(interpreter_->execution_plan().size(), 3); - - memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); - memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); - TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); - ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); - for (int i = 0; i < 3; ++i) { - EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; - } -} - -TEST_F(TestDelegate, SecondDelegationPrepareFailure) { - // First delegate only supports nodes 1, 2. Gets applied successfully. - // This delegate should support dynamic tensors, otherwise the second won't be - // applied. - delegate_ = std::unique_ptr<SimpleDelegate>( - new SimpleDelegate({1, 2}, kTfLiteDelegateFlagsAllowDynamicTensors)); - // Second delegate supports node 0, but fails during the delegate-node's - // Prepare. - delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate( - {0}, kTfLiteDelegateFlagsNone, true /**fail_node_prepare**/)); - - // Initially, execution plan has 3 nodes. - ASSERT_EQ(interpreter_->execution_plan().size(), 3); - // First delegate should be applied successfully, yielding a plan with 2 - // nodes. - ASSERT_EQ( - interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), - kTfLiteOk); - ASSERT_EQ(interpreter_->execution_plan().size(), 2); - // Second delegate won't get applied. - // As a result, previous delegate should also get undone, restoring the - // execution plan to its original state. - ASSERT_EQ( - interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()), - kTfLiteDelegateError); - ASSERT_EQ(interpreter_->execution_plan().size(), 3); - - std::vector<float> input = {1.0f, 2.0f, 3.0f}; - std::vector<float> expected_output = {2.0f, 4.0f, 6.0f}; - constexpr int kOutputTensorIndex = 3; - TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); - - // Verify Invoke() behavior. - memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); - memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); - interpreter_->Invoke(); - for (int i = 0; i < 3; ++i) { - EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; - } -} - -TEST_F(TestDelegate, SecondDelegationInvokeFailure) { - delegate_ = std::unique_ptr<SimpleDelegate>( - new SimpleDelegate({1, 2}, kTfLiteDelegateFlagsAllowDynamicTensors)); - delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate( - {0}, kTfLiteDelegateFlagsNone, false /**fail_node_prepare**/, - 0 /**min_ops_per_subset**/, true /**fail_node_invoke**/)); - ASSERT_EQ( - interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), - kTfLiteOk); - ASSERT_EQ( - interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()), - kTfLiteOk); - ASSERT_EQ(interpreter_->execution_plan().size(), 2); - - std::vector<float> input = {1.0f, 2.0f, 3.0f}; - // Outputs match the AddOp path, rather than delegate path. - std::vector<float> expected_output = {2.0f, 4.0f, 6.0f}; - constexpr int kOutputTensorIndex = 3; - - // Verify Invoke() behavior to ensure Interpreter isn't broken. - memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); - memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); - EXPECT_EQ(interpreter_->Invoke(), kTfLiteError); - EXPECT_EQ(RemoveAllDelegates(), kTfLiteOk); - ASSERT_EQ(interpreter_->execution_plan().size(), 3); - ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); - TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); - for (int i = 0; i < 3; ++i) { - EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; - } -} - -TEST_F(TestDelegate, StaticDelegateMakesGraphImmutable) { - delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); - ASSERT_EQ( - interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), - kTfLiteOk); - ASSERT_EQ(interpreter_->execution_plan().size(), 1); - - // Deliberately try to set tensor params with quantization while immutable, - // ensuring quantization is properly freed. - TfLiteQuantization quant = {}; - quant.type = kTfLiteAffineQuantization; - auto quant_params = static_cast<TfLiteAffineQuantization*>( - malloc(sizeof(TfLiteAffineQuantization))); - quant_params->scale = nullptr; - quant_params->zero_point = nullptr; - quant_params->quantized_dimension = 0; - quant.params = quant_params; - ASSERT_NE(interpreter_->SetTensorParametersReadWrite(0, kTfLiteInt8, "", {3}, - quant), - kTfLiteOk); -} - -TEST_F(TestDelegate, ComplexDelegate) { - delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({1, 2})); - interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()); - - ASSERT_EQ(interpreter_->execution_plan().size(), 2); - // 0th should be a non-delegated original op - ASSERT_EQ(interpreter_->execution_plan()[0], 0); - // 1st should be a new macro op (3) which didn't exist) - ASSERT_EQ(interpreter_->execution_plan()[1], 3); - const auto* node_and_reg = interpreter_->node_and_registration(3); - ASSERT_EQ(node_and_reg->second.custom_name, - delegate_->FakeFusedRegistration().custom_name); -} - -TEST_F(TestDelegate, SetBufferHandleToInput) { - delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); - TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); - interpreter_->ModifyGraphWithDelegate(delegate); - - constexpr int kOutputTensorIndex = 0; - TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); - ASSERT_EQ(tensor->delegate, nullptr); - ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); - - TfLiteBufferHandle handle = AllocateBufferHandle(); - TfLiteStatus status = - interpreter_->SetBufferHandle(kOutputTensorIndex, handle, delegate); - ASSERT_EQ(status, kTfLiteOk); - EXPECT_EQ(tensor->delegate, delegate); - EXPECT_EQ(tensor->buffer_handle, handle); -} - -TEST_F(TestDelegate, SetBufferHandleToOutput) { - delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); - TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); - interpreter_->ModifyGraphWithDelegate(delegate); - - constexpr int kOutputTensorIndex = 3; - TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); - // Before setting the buffer handle, the tensor's `delegate` is already set - // because it will be written by the delegate. - ASSERT_EQ(tensor->delegate, delegate); - ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); - - TfLiteBufferHandle handle = AllocateBufferHandle(); - TfLiteStatus status = - interpreter_->SetBufferHandle(kOutputTensorIndex, handle, delegate); - ASSERT_EQ(status, kTfLiteOk); - EXPECT_EQ(tensor->delegate, delegate); - EXPECT_EQ(tensor->buffer_handle, handle); -} - -TEST_F(TestDelegate, SetInvalidHandleToTensor) { - interpreter_->Invoke(); - delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); - TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); - interpreter_->ModifyGraphWithDelegate(delegate); - - SimpleDelegate another_simple_delegate({0, 1, 2}); - - constexpr int kOutputTensorIndex = 3; - TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); - // Before setting the buffer handle, the tensor's `delegate` is already set - // because it will be written by the delegate. - ASSERT_EQ(tensor->delegate, delegate); - ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); - - TfLiteBufferHandle handle = AllocateBufferHandle(); - TfLiteStatus status = interpreter_->SetBufferHandle( - kOutputTensorIndex, handle, - another_simple_delegate.get_tf_lite_delegate()); - // Setting a buffer handle to a tensor with another delegate will fail. - ASSERT_EQ(status, kTfLiteError); - EXPECT_EQ(tensor->delegate, delegate); - EXPECT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); -} - -// We utilize delegation in such a way as to allow node subsets with a minimum -// number of ops only. -TEST_F(TestDelegate, TestDelegationWithPartitionPreview) { - // We set kTfLiteDelegateFlagsAllowDynamicTensors to ensure the second - // delegate can be applied. - // Ops 0 and 2 are delegated but end up in the same partition (based on - // dependency analysis). However, since min_ops_per_subset = 3, no delegation - // takes place. - delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate( - {0, 2}, kTfLiteDelegateFlagsAllowDynamicTensors, - false /**fail_node_prepare**/, 3 /**min_ops_per_subset**/)); - interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()); - - // Original execution plan remains. - ASSERT_EQ(interpreter_->execution_plan().size(), 3); - ASSERT_EQ(interpreter_->execution_plan()[0], 0); - ASSERT_EQ(interpreter_->execution_plan()[1], 1); - ASSERT_EQ(interpreter_->execution_plan()[2], 2); - - // Same ops supported, but min_ops_per_subset = 2. - delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate( - {0, 2}, kTfLiteDelegateFlagsAllowDynamicTensors, - false /**fail_node_prepare**/, 2 /**min_ops_per_subset**/)); - interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()); - - ASSERT_EQ(interpreter_->execution_plan().size(), 2); - ASSERT_EQ(interpreter_->execution_plan()[0], 3); - const auto* node_and_reg = interpreter_->node_and_registration(3); - ASSERT_EQ(node_and_reg->second.custom_name, - delegate2_->FakeFusedRegistration().custom_name); - ASSERT_EQ(interpreter_->execution_plan()[1], 1); -} - -TEST_F(TestDelegate, TestResizeInputWithNonDynamicDelegate) { - delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); - ASSERT_EQ( - interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), - kTfLiteOk); - - // Try resizing input to same shape as before (which should be a No-op). - ASSERT_EQ(interpreter_->ResizeInputTensor(0, {3}), kTfLiteOk); - ASSERT_EQ(interpreter_->execution_plan().size(), 1); - - ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 3}), kTfLiteOk); - ASSERT_EQ(interpreter_->ResizeInputTensor(1, {1, 3}), kTfLiteOk); - ASSERT_EQ(interpreter_->execution_plan().size(), 3); - // This should fail, since the previous application of the delegate will be - // re-done automatically, making the graph immutable again. - ASSERT_NE( - interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), - kTfLiteOk); - // Ensure graph has been restored to its valid delegated state. - ASSERT_EQ(interpreter_->execution_plan().size(), 1); - - std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f}; - std::vector<float> expected_output = {2.0f, 4.0f, 6.0f, 8.0f}; - constexpr int kOutputTensorIndex = 3; - TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); - - // Verify Invoke() behavior. - memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); - memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); - interpreter_->Invoke(); - for (int i = 0; i < 3; ++i) { - EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; - } - - // Resize again, but call AllocateTensors as usual afterwards. - ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 4}), kTfLiteOk); - ASSERT_EQ(interpreter_->ResizeInputTensor(1, {1, 4}), kTfLiteOk); - ASSERT_EQ(interpreter_->execution_plan().size(), 3); - ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); - ASSERT_EQ(interpreter_->execution_plan().size(), 1); - - memcpy(interpreter_->typed_tensor<float>(0), input.data(), 4 * sizeof(float)); - memcpy(interpreter_->typed_tensor<float>(1), input.data(), 4 * sizeof(float)); - interpreter_->Invoke(); - for (int i = 0; i < 4; ++i) { - EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; - } -} - -TEST_F(TestDelegate, TestResizeInputWithMultipleDelegates) { - // First delegate only supports node 0. - // This delegate should support dynamic tensors, otherwise the second won't be - // applied. - delegate_ = std::unique_ptr<SimpleDelegate>( - new SimpleDelegate({0}, kTfLiteDelegateFlagsAllowDynamicTensors)); - // Second delegate supports nodes 1 & 2, and makes the graph immutable. - delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({1, 2})); - ASSERT_EQ( - interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), - kTfLiteOk); - ASSERT_EQ( - interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()), - kTfLiteOk); - // Should be two delegates nodes. - ASSERT_EQ(interpreter_->execution_plan().size(), 2); - - // Try resizing input to same shape as before (which should be a No-op). - ASSERT_EQ(interpreter_->ResizeInputTensor(0, {3}), kTfLiteOk); - ASSERT_EQ(interpreter_->execution_plan().size(), 2); - - // Resizing input tensors should temporarily restore original execution plan - // of 3 nodes. - ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 3}), kTfLiteOk); - ASSERT_EQ(interpreter_->ResizeInputTensor(1, {1, 3}), kTfLiteOk); - ASSERT_EQ(interpreter_->execution_plan().size(), 3); - // This should fail, since the previous application of the delegate will be - // re-done automatically, making the graph immutable again. - ASSERT_NE( - interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), - kTfLiteOk); - // Ensure graph has been restored to its valid delegated state. - ASSERT_EQ(interpreter_->execution_plan().size(), 2); - - std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f}; - std::vector<float> expected_output = {2.0f, 4.0f, 6.0f, 8.0f}; - constexpr int kOutputTensorIndex = 2; - TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); - - // Verify Invoke() behavior. - memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); - memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); - interpreter_->Invoke(); - for (int i = 0; i < 3; ++i) { - EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; - } - - // Resize again, but call AllocateTensors as usual afterwards. - ASSERT_EQ(interpreter_->ResizeInputTensor(0, {1, 4}), kTfLiteOk); - ASSERT_EQ(interpreter_->ResizeInputTensor(1, {1, 4}), kTfLiteOk); - ASSERT_EQ(interpreter_->execution_plan().size(), 3); - ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); - ASSERT_EQ(interpreter_->execution_plan().size(), 2); - - memcpy(interpreter_->typed_tensor<float>(0), input.data(), 4 * sizeof(float)); - memcpy(interpreter_->typed_tensor<float>(1), input.data(), 4 * sizeof(float)); - interpreter_->Invoke(); - for (int i = 0; i < 4; ++i) { - EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; - } -} - -TEST_F(TestDelegate, ReleaseNonPersistentMemoryWithDelegates) { - // First delegate only supports node 0. - // This delegate should support dynamic tensors, otherwise the second won't be - // applied. - delegate_ = std::unique_ptr<SimpleDelegate>( - new SimpleDelegate({0}, kTfLiteDelegateFlagsAllowDynamicTensors)); - // Second delegate supports nodes 1 & 2, and makes the graph immutable. - delegate2_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({1, 2})); - - // No-op. - ASSERT_EQ(interpreter_->ReleaseNonPersistentMemory(), kTfLiteOk); - - ASSERT_EQ( - interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), - kTfLiteOk); - ASSERT_EQ( - interpreter_->ModifyGraphWithDelegate(delegate2_->get_tf_lite_delegate()), - kTfLiteOk); - // Should be two delegates nodes. - ASSERT_EQ(interpreter_->execution_plan().size(), 2); - - ASSERT_EQ(interpreter_->ReleaseNonPersistentMemory(), kTfLiteOk); - ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); - - // This should fail, since the graph is immutable. - ASSERT_NE( - interpreter_->ModifyGraphWithDelegate(delegate_->get_tf_lite_delegate()), - kTfLiteOk); - - std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f}; - std::vector<float> expected_output = {2.0f, 4.0f, 6.0f, 8.0f}; - constexpr int kOutputTensorIndex = 2; - TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); - - // Verify Invoke() behavior. - memcpy(interpreter_->typed_tensor<float>(0), input.data(), 3 * sizeof(float)); - memcpy(interpreter_->typed_tensor<float>(1), input.data(), 3 * sizeof(float)); - interpreter_->Invoke(); - for (int i = 0; i < 3; ++i) { - EXPECT_EQ(tensor->data.f[i], expected_output[i]) << i; - } - - ASSERT_EQ(interpreter_->ReleaseNonPersistentMemory(), kTfLiteOk); -} - -TEST_F(TestDelegate, TestCopyFromBufferInvoke) { - delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); - TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); - interpreter_->ModifyGraphWithDelegate(delegate); - - constexpr int kOutputTensorIndex = 3; - TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); - std::vector<float> floats = {1.0f, 2.0f, 3.0f}; - memcpy(interpreter_->typed_tensor<float>(0), floats.data(), - floats.size() * sizeof(float)); - - memcpy(interpreter_->typed_tensor<float>(1), floats.data(), - floats.size() * sizeof(float)); - - // Before setting the buffer handle, the tensor's `delegate` is already set - // because it will be written by the delegate. - ASSERT_EQ(tensor->delegate, delegate); - ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); - - // Called Invoke without setting the buffer will not call the CopyFromBuffer - interpreter_->Invoke(); - std::vector<float> res = {2.0f, 4.0f, 6.0f}; - for (int i = 0; i < tensor->dims->data[0]; ++i) { - ASSERT_EQ(tensor->data.f[i], res[i]); - } -} - -TEST_F(TestDelegate, TestCopyFromBuffer) { - delegate_ = std::unique_ptr<SimpleDelegate>(new SimpleDelegate({0, 1, 2})); - TfLiteDelegate* delegate = delegate_->get_tf_lite_delegate(); - interpreter_->ModifyGraphWithDelegate(delegate); - - constexpr int kOutputTensorIndex = 3; - TfLiteTensor* tensor = interpreter_->tensor(kOutputTensorIndex); - std::vector<float> floats = {1.0f, 2.0f, 3.0f}; - memcpy(interpreter_->typed_tensor<float>(0), floats.data(), - floats.size() * sizeof(float)); - - memcpy(interpreter_->typed_tensor<float>(1), floats.data(), - floats.size() * sizeof(float)); - - // Before setting the buffer handle, the tensor's `delegate` is already set - // because it will be written by the delegate. - ASSERT_EQ(tensor->delegate, delegate); - ASSERT_EQ(tensor->buffer_handle, kTfLiteNullBufferHandle); - - TfLiteBufferHandle handle = AllocateBufferHandle(); - TfLiteStatus status = - interpreter_->SetBufferHandle(kOutputTensorIndex, handle, delegate); - interpreter_->Invoke(); - ASSERT_EQ(status, kTfLiteOk); - EXPECT_EQ(tensor->delegate, delegate); - EXPECT_EQ(tensor->buffer_handle, handle); - for (int i = 0; i < tensor->dims->data[0]; ++i) { - ASSERT_EQ(tensor->data.f[i], 6.0f); - } -} - -TEST_F(TestDelegate, DelegateCustomOpResolution) { - // Build a flatbuffer model that contains the "my_add" custom op which gets - // resolved only after SimpleDelegate is applied. - flatbuffers::FlatBufferBuilder builder; - // Tensors. - const int32_t shape[1] = {3}; - flatbuffers::Offset<Tensor> tensors[3] = { - CreateTensor(builder, builder.CreateVector<int32_t>(shape, 1), - TensorType_FLOAT32, /*buffer=*/0, builder.CreateString("X")), - CreateTensor(builder, builder.CreateVector<int32_t>(shape, 1), - TensorType_FLOAT32, /*buffer=*/0, builder.CreateString("Y")), - CreateTensor(builder, builder.CreateVector<int32_t>(shape, 1), - TensorType_FLOAT32, /*buffer=*/0, builder.CreateString("Z")), - }; - // Custom op definition. - flatbuffers::Offset<OperatorCode> op_code = - CreateOperatorCodeDirect(builder, BuiltinOperator_CUSTOM, "my_add"); - const int32_t inputs[2] = {0, 1}; - const int32_t outputs[1] = {2}; - flatbuffers::Offset<Operator> op = CreateOperator( - builder, /*opcode_index=*/0, builder.CreateVector<int32_t>(inputs, 2), - builder.CreateVector<int32_t>(outputs, 1), BuiltinOptions_NONE, - /*builtin_options=*/0, - /*custom_options=*/0, tflite::CustomOptionsFormat_FLEXBUFFERS); - // Subgraph & Model. - flatbuffers::Offset<SubGraph> subgraph = - CreateSubGraph(builder, builder.CreateVector(tensors, 3), - builder.CreateVector<int32_t>(inputs, 2), - builder.CreateVector<int32_t>(outputs, 1), - builder.CreateVector(&op, 1), /*name=*/0); - flatbuffers::Offset<Buffer> buffers[1] = { - CreateBuffer(builder, builder.CreateVector({})), - }; - flatbuffers::Offset<Model> model_buffer = CreateModel( - builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&op_code, 1), - builder.CreateVector(&subgraph, 1), builder.CreateString("test_model"), - builder.CreateVector(buffers, 1)); - builder.Finish(model_buffer); - std::vector<char> buffer = - std::vector<char>(builder.GetBufferPointer(), - builder.GetBufferPointer() + builder.GetSize()); - const Model* model = GetModel(buffer.data()); - - // Build an interpreter with the model. Initialization should work fine. - std::unique_ptr<Interpreter> interpreter; - ASSERT_EQ( - InterpreterBuilder( - model, ::tflite::ops::builtin::BuiltinOpResolver())(&interpreter), - kTfLiteOk); - // AllocateTensors should fail, since my_add hasn't been resolved. - ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteError); - - // Applying static delegate won't work, since the interpreter will first try - // to Prepare all original nodes. - std::unique_ptr<SimpleDelegate> static_delegate(new SimpleDelegate({0})); - ASSERT_EQ(interpreter->ModifyGraphWithDelegate( - static_delegate->get_tf_lite_delegate()), - kTfLiteError); - - // Applying delegate that supports dynamic tensors should work. - std::unique_ptr<SimpleDelegate> dynamic_delegate( - new SimpleDelegate({0}, kTfLiteDelegateFlagsAllowDynamicTensors)); - ASSERT_EQ(interpreter->ModifyGraphWithDelegate( - dynamic_delegate->get_tf_lite_delegate()), - kTfLiteOk); - // AllocateTensors will now work. - ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk); -} - -class TestDelegateWithDynamicTensors : public ::testing::Test { - protected: - void SetUp() override { - interpreter_.reset(new Interpreter); - - interpreter_->AddTensors(2); - interpreter_->SetInputs({0}); - interpreter_->SetOutputs({1}); - TfLiteQuantizationParams quant; - interpreter_->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, - quant); - interpreter_->SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, - quant); - TfLiteRegistration reg = DynamicCopyOpRegistration(); - interpreter_->AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®); - - delegate_.Prepare = [](TfLiteContext* context, - TfLiteDelegate* delegate) -> TfLiteStatus { - // In this test, the delegate replaces all the nodes if this function is - // called. - TfLiteIntArray* execution_plan; - TF_LITE_ENSURE_STATUS( - context->GetExecutionPlan(context, &execution_plan)); - context->ReplaceNodeSubsetsWithDelegateKernels( - context, DelegateRegistration(), execution_plan, delegate); - return kTfLiteOk; - }; - delegate_.flags = kTfLiteDelegateFlagsNone; - } - - static TfLiteRegistration DynamicCopyOpRegistration() { - TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; - - reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* output = GetOutput(context, node, 0); - SetTensorToDynamic(output); - return kTfLiteOk; - }; - - reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { - // Not implemented since this isn't required in testing. - return kTfLiteOk; - }; - return reg; - } - - static TfLiteRegistration DelegateRegistration() { - TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; - return reg; - } - - std::unique_ptr<Interpreter> interpreter_; - TfLiteDelegate delegate_; -}; - -TEST_F(TestDelegateWithDynamicTensors, DisallowDynamicTensors) { - interpreter_->ModifyGraphWithDelegate(&delegate_); - - ASSERT_EQ(interpreter_->execution_plan().size(), 1); - // The interpreter should not call delegate's `Prepare` when dynamic tensors - // exist. So the node ID isn't changed. - ASSERT_EQ(interpreter_->execution_plan()[0], 0); -} - -TEST_F(TestDelegateWithDynamicTensors, AllowDynamicTensors) { - delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors; - interpreter_->ModifyGraphWithDelegate(&delegate_); - - ASSERT_EQ(interpreter_->execution_plan().size(), 1); - // The node should be replaced because dynamic tensors are allowed. Therefore - // only node ID in the execution plan is changed from 0 to 1. - ASSERT_EQ(interpreter_->execution_plan()[0], 1); -} - -TEST_F(TestDelegateWithDynamicTensors, ModifyGraphAfterAllocate) { - // Trigger allocation *before* delegate application. - ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); - - delegate_.flags = kTfLiteDelegateFlagsAllowDynamicTensors; - ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(&delegate_), kTfLiteOk); - ASSERT_EQ(interpreter_->execution_plan().size(), 1); - ASSERT_EQ(interpreter_->execution_plan()[0], 1); - - // Allocation should still succeed. - ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); -} - TEST(TestDelegateOwnership, ProperlyDisposed) { struct TfLiteInterpreterOwnedDelegate : public TfLiteDelegate { TfLiteInterpreterOwnedDelegate(bool* destroyed, bool* prepared) diff --git a/tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java b/tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java index 49cf21debc5..839984cfc5d 100644 --- a/tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java +++ b/tensorflow/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java @@ -57,19 +57,19 @@ public abstract class OvicBenchmarker { /** Total runtime in ns. */ protected double totalRuntimeNano = 0.0; /** Total allowed runtime in ms. */ - protected double wallTimeNano = 20000 * 30 * 1.0e6; + protected double wallTimeMilli = 20000 * 30.0; /** Record whether benchmark has started (used to skip the first image). */ protected boolean benchmarkStarted = false; /** * Initializes an {@link OvicBenchmarker} * - * @param wallTimeNano: a double number specifying the total amount of time to benchmark. + * @param wallTimeMilli: a double number specifying the total amount of time to benchmark. */ - public OvicBenchmarker(double wallTimeNano) { + protected OvicBenchmarker(double wallTimeMilli) { benchmarkStarted = false; totalRuntimeNano = 0.0; - this.wallTimeNano = wallTimeNano; + this.wallTimeMilli = wallTimeMilli; } /** Return the cumulative latency of all runs so far. */ @@ -79,13 +79,13 @@ public abstract class OvicBenchmarker { /** Check whether the benchmarker should stop. */ public Boolean shouldStop() { - if (totalRuntimeNano >= wallTimeNano) { + if ((totalRuntimeNano * 1.0 / 1e6) >= wallTimeMilli) { Log.e( TAG, - "Total runtime (ms) " - + (totalRuntimeNano * 1.0e-6) - + " exceeded wall-time " - + (wallTimeNano * 1.0e-6)); + "Total runtime " + + (totalRuntimeNano * 1.0 / 1e6) + + " exceeded walltime (ms) " + + wallTimeMilli); return true; } return false; diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 3a29fee5699..657b5d89a85 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -235,6 +235,15 @@ cc_library( visibility = ["//visibility:private"], ) +cc_library( + name = "tflite_with_ruy_and_caching_enabled", + defines = [ + "TFLITE_WITH_RUY", + "TFLITE_WITH_RUY_GEMV", + ], + visibility = ["//visibility:private"], +) + cc_library( name = "tflite_with_ruy_default", build_for_embedded = True, @@ -423,140 +432,157 @@ cc_library( ], ) +BUILTIN_KERNEL_SRCS = [ + "activations.cc", + "add.cc", + "add_n.cc", + "arg_min_max.cc", + "audio_spectrogram.cc", + "basic_rnn.cc", + "batch_matmul.cc", + "batch_to_space_nd.cc", + "bidirectional_sequence_lstm.cc", + "bidirectional_sequence_rnn.cc", + "cast.cc", + "ceil.cc", + "comparisons.cc", + "concatenation.cc", + "conv.cc", + "densify.cc", + "depth_to_space.cc", + "depthwise_conv.cc", + "dequantize.cc", + "detection_postprocess.cc", + "div.cc", + "elementwise.cc", + "embedding_lookup.cc", + "embedding_lookup_sparse.cc", + "exp.cc", + "expand_dims.cc", + "fake_quant.cc", + "fill.cc", + "floor.cc", + "floor_div.cc", + "floor_mod.cc", + "fully_connected.cc", + "gather.cc", + "gather_nd.cc", + "hashtable_lookup.cc", + "if.cc", + "l2norm.cc", + "local_response_norm.cc", + "logical.cc", + "lsh_projection.cc", + "lstm.cc", + "matrix_diag.cc", + "matrix_set_diag.cc", + "maximum_minimum.cc", + "mfcc.cc", + "mirror_pad.cc", + "mul.cc", + "neg.cc", + "non_max_suppression.cc", + "numeric_verify.cc", + "one_hot.cc", + "pack.cc", + "pad.cc", + "pooling.cc", + "pow.cc", + "quantize.cc", + "range.cc", + "rank.cc", + "reduce.cc", + "reshape.cc", + "resize_bilinear.cc", + "resize_nearest_neighbor.cc", + "reverse.cc", + "reverse_sequence.cc", + "round.cc", + "scatter_nd.cc", + "segment_sum.cc", + "select.cc", + "shape.cc", + "skip_gram.cc", + "slice.cc", + "space_to_batch_nd.cc", + "space_to_depth.cc", + "sparse_to_dense.cc", + "split.cc", + "split_v.cc", + "squared_difference.cc", + "squeeze.cc", + "strided_slice.cc", + "sub.cc", + "svdf.cc", + "tile.cc", + "topk_v2.cc", + "transpose.cc", + "transpose_conv.cc", + "unidirectional_sequence_lstm.cc", + "unidirectional_sequence_rnn.cc", + "unique.cc", + "unpack.cc", + "where.cc", + "while.cc", + "zeros_like.cc", +] + +BUILTIN_KERNEL_DEPS = [ + ":cpu_backend_context", + ":cpu_backend_gemm", + ":cpu_backend_threadpool", + ":eigen_support", + ":kernel_util", + ":lstm_eval", + ":lstm_shared", + ":op_macros", + ":padding", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "//third_party/eigen3", + "@flatbuffers", + "//tensorflow/lite:framework_lib", + "//tensorflow/lite:minimal_logging", + "//tensorflow/lite:string_util", + "//tensorflow/lite/c:common", + "//tensorflow/lite/kernels/internal:audio_utils", + "//tensorflow/lite/kernels/internal:common", + "//tensorflow/lite/kernels/internal:compatibility", + "//tensorflow/lite/kernels/internal:cpu_check", + "//tensorflow/lite/kernels/internal:kernel_utils", + "//tensorflow/lite/kernels/internal:optimized", + "//tensorflow/lite/kernels/internal:optimized_base", + "//tensorflow/lite/kernels/internal:quantization_util", + "//tensorflow/lite/kernels/internal:reference_base", + "//tensorflow/lite/kernels/internal:strided_slice_logic", + "//tensorflow/lite/kernels/internal:tensor", + "//tensorflow/lite/kernels/internal:tensor_utils", + "//tensorflow/lite/kernels/internal:types", +] + cc_library( name = "builtin_op_kernels", - srcs = [ - "activations.cc", - "add.cc", - "add_n.cc", - "arg_min_max.cc", - "audio_spectrogram.cc", - "basic_rnn.cc", - "batch_matmul.cc", - "batch_to_space_nd.cc", - "bidirectional_sequence_lstm.cc", - "bidirectional_sequence_rnn.cc", - "cast.cc", - "ceil.cc", - "comparisons.cc", - "concatenation.cc", - "conv.cc", - "densify.cc", - "depth_to_space.cc", - "depthwise_conv.cc", - "dequantize.cc", - "detection_postprocess.cc", - "div.cc", - "elementwise.cc", - "embedding_lookup.cc", - "embedding_lookup_sparse.cc", - "exp.cc", - "expand_dims.cc", - "fake_quant.cc", - "fill.cc", - "floor.cc", - "floor_div.cc", - "floor_mod.cc", - "fully_connected.cc", - "gather.cc", - "gather_nd.cc", - "hashtable_lookup.cc", - "if.cc", - "l2norm.cc", - "local_response_norm.cc", - "logical.cc", - "lsh_projection.cc", - "lstm.cc", - "matrix_diag.cc", - "matrix_set_diag.cc", - "maximum_minimum.cc", - "mfcc.cc", - "mirror_pad.cc", - "mul.cc", - "neg.cc", - "non_max_suppression.cc", - "numeric_verify.cc", - "one_hot.cc", - "pack.cc", - "pad.cc", - "pooling.cc", - "pow.cc", - "quantize.cc", - "range.cc", - "rank.cc", - "reduce.cc", - "reshape.cc", - "resize_bilinear.cc", - "resize_nearest_neighbor.cc", - "reverse.cc", - "reverse_sequence.cc", - "round.cc", - "scatter_nd.cc", - "segment_sum.cc", - "select.cc", - "shape.cc", - "skip_gram.cc", - "slice.cc", - "space_to_batch_nd.cc", - "space_to_depth.cc", - "sparse_to_dense.cc", - "split.cc", - "split_v.cc", - "squared_difference.cc", - "squeeze.cc", - "strided_slice.cc", - "sub.cc", - "svdf.cc", - "tile.cc", - "topk_v2.cc", - "transpose.cc", - "transpose_conv.cc", - "unidirectional_sequence_lstm.cc", - "unidirectional_sequence_rnn.cc", - "unique.cc", - "unpack.cc", - "where.cc", - "while.cc", - "zeros_like.cc", - ], + srcs = BUILTIN_KERNEL_SRCS, hdrs = [ "dequantize.h", ], copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS, visibility = ["//visibility:private"], - deps = [ - ":cpu_backend_context", - ":cpu_backend_gemm", - ":cpu_backend_threadpool", - ":eigen_support", - ":kernel_util", - ":lstm_eval", - ":lstm_shared", - ":op_macros", - ":padding", - "//tensorflow/lite:framework_lib", - "//tensorflow/lite:minimal_logging", - "//tensorflow/lite:string_util", - "//tensorflow/lite/c:common", - "//tensorflow/lite/kernels/internal:audio_utils", - "//tensorflow/lite/kernels/internal:common", - "//tensorflow/lite/kernels/internal:compatibility", - "//tensorflow/lite/kernels/internal:cpu_check", - "//tensorflow/lite/kernels/internal:kernel_utils", - "//tensorflow/lite/kernels/internal:optimized", - "//tensorflow/lite/kernels/internal:optimized_base", - "//tensorflow/lite/kernels/internal:quantization_util", - "//tensorflow/lite/kernels/internal:reference_base", - "//tensorflow/lite/kernels/internal:strided_slice_logic", - "//tensorflow/lite/kernels/internal:tensor", - "//tensorflow/lite/kernels/internal:tensor_utils", - "//tensorflow/lite/kernels/internal:types", - "//third_party/eigen3", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@farmhash_archive//:farmhash", - "@flatbuffers", + deps = BUILTIN_KERNEL_DEPS + ["@farmhash_archive//:farmhash"], +) + +# Creates a target where Ruy is unconditionally enabled along with caching +# on GEMV operations. This is useful for TF Lite deployments where custom +# copts are not allowed, e.g. b/156119344 +cc_library( + name = "builtin_op_kernels_ruy_and_caching", + srcs = BUILTIN_KERNEL_SRCS, + hdrs = [ + "dequantize.h", ], + copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS, + visibility = ["//visibility:private"], + deps = BUILTIN_KERNEL_DEPS + ["@farmhash_archive//:farmhash"] + [":tflite_with_ruy_and_caching_enabled"], ) cc_library( @@ -673,6 +699,22 @@ cc_library( ], ) +# TODO(b/156664104) Remove once runtime flag available. +cc_library( + name = "builtin_ops_ruy_and_caching_enabled", + srcs = ["register.cc"], + hdrs = [ + "builtin_op_kernels.h", + "fully_connected.h", + "register.h", + ], + deps = [ + ":builtin_op_kernels_ruy_and_caching", + "//tensorflow/lite:framework_lib", + "//tensorflow/lite/c:common", + ], +) + # The builtin_ops target will resolve to optimized kernels when available. This # target uses reference kernels only, and is useful for validation and testing. # It should *not* generally be used in production. diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h index 8937fe2b26e..44479d93a31 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/add.h @@ -18,6 +18,7 @@ limitations under the License. #include "ruy/profiler/instrumentation.h" // from @ruy #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" +#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h" #include "tensorflow/lite/kernels/internal/types.h" @@ -47,6 +48,9 @@ inline void AddElementwise(int size, const ArithmeticParams& params, const int32x4_t input1_left_dup = vdupq_n_s32(input1_left_shift); const int32x4_t input2_left_dup = vdupq_n_s32(input2_left_shift); + const int16x8_t input1_offset_dup = vdupq_n_s16(params.input1_offset); + const int16x8_t input2_offset_dup = vdupq_n_s16(params.input2_offset); + for (; i <= size - 16; i += 16) { const int8x16_t input1_val_original = vld1q_s8(input1_data + i); const int8x16_t input2_val_original = vld1q_s8(input2_data + i); @@ -61,13 +65,13 @@ inline void AddElementwise(int size, const ArithmeticParams& params, const int16x8_t input2_val_s16_low = vmovl_s8(vget_low_s8(input2_val_original)); const int16x8_t input1_val_high = - vaddq_s16(input1_val_s16_high, vdupq_n_s16(params.input1_offset)); + vaddq_s16(input1_val_s16_high, input1_offset_dup); const int16x8_t input2_val_high = - vaddq_s16(input2_val_s16_high, vdupq_n_s16(params.input2_offset)); + vaddq_s16(input2_val_s16_high, input2_offset_dup); const int16x8_t input1_val_low = - vaddq_s16(input1_val_s16_low, vdupq_n_s16(params.input1_offset)); + vaddq_s16(input1_val_s16_low, input1_offset_dup); const int16x8_t input2_val_low = - vaddq_s16(input2_val_s16_low, vdupq_n_s16(params.input2_offset)); + vaddq_s16(input2_val_s16_low, input2_offset_dup); const int16x4_t input1_val_high_high = vget_high_s16(input1_val_high); const int16x4_t input1_val_high_low = vget_low_s16(input1_val_high); const int16x4_t input1_val_low_high = vget_high_s16(input1_val_low); @@ -122,7 +126,7 @@ inline void AddElementwise(int size, const ArithmeticParams& params, vdupq_n_s16(params.output_offset)); const int16x8_t s2 = vaddq_s16(vcombine_s16(s21_narrowed, s22_narrowed), vdupq_n_s16(params.output_offset)); - const int16x8_t s = vcombine_s16(vqmovn_s16(s1), vqmovn_s16(s2)); + const int8x16_t s = vcombine_s8(vqmovn_s16(s1), vqmovn_s16(s2)); const int8x16_t clamped = vmaxq_s8(output_activation_min_vector, @@ -272,101 +276,6 @@ inline void Add(const ArithmeticParams& params, AddElementwise(flat_size, params, input1_data, input2_data, output_data); } -inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params, - const RuntimeShape& unswitched_input1_shape, - const int8* unswitched_input1_data, - const RuntimeShape& unswitched_input2_shape, - const int8* unswitched_input2_data, - const RuntimeShape& output_shape, - int8* output_data) { - ruy::profiler::ScopeLabel label("BroadcastAddFivefoldInt8/8bit"); - - ArithmeticParams switched_params = unswitched_params; - switched_params.input1_offset = unswitched_params.input2_offset; - switched_params.input1_multiplier = unswitched_params.input2_multiplier; - switched_params.input1_shift = unswitched_params.input2_shift; - switched_params.input2_offset = unswitched_params.input1_offset; - switched_params.input2_multiplier = unswitched_params.input1_multiplier; - switched_params.input2_shift = unswitched_params.input1_shift; - - const bool use_unswitched = - unswitched_params.broadcast_category == - tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast; - - const ArithmeticParams& params = - use_unswitched ? unswitched_params : switched_params; - const int8* input1_data = - use_unswitched ? unswitched_input1_data : unswitched_input2_data; - const int8* input2_data = - use_unswitched ? unswitched_input2_data : unswitched_input1_data; - - // Fivefold nested loops. The second input resets its position for each - // iteration of the second loop. The first input resets its position at the - // beginning of the fourth loop. The innermost loop is an elementwise add of - // sections of the arrays. - int8* output_data_ptr = output_data; - const int8* input1_data_ptr = input1_data; - const int8* input2_data_reset = input2_data; - // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so shared - // between input shapes. y3 for input 1 is always broadcast, and so the - // dimension there is 1, whereas optionally y1 might be broadcast for input 2. - // Put another way, - // input1.shape.FlatSize = y0 * y1 * y2 * y4, - // input2.shape.FlatSize = y0 * y2 * y3 * y4. - int y0 = params.broadcast_shape[0]; - int y1 = params.broadcast_shape[1]; - int y2 = params.broadcast_shape[2]; - int y3 = params.broadcast_shape[3]; - int y4 = params.broadcast_shape[4]; - if (y4 > 1) { - // General fivefold pattern, with y4 > 1 so there is a non-broadcast inner - // dimension. - for (int i0 = 0; i0 < y0; ++i0) { - const int8* input2_data_ptr = nullptr; - for (int i1 = 0; i1 < y1; ++i1) { - input2_data_ptr = input2_data_reset; - for (int i2 = 0; i2 < y2; ++i2) { - for (int i3 = 0; i3 < y3; ++i3) { - AddElementwise(y4, params, input1_data_ptr, input2_data_ptr, - output_data_ptr); - input2_data_ptr += y4; - output_data_ptr += y4; - } - // We have broadcast y4 of input1 data y3 times, and now move on. - input1_data_ptr += y4; - } - } - // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on. - input2_data_reset = input2_data_ptr; - } - } else { - // Special case of y4 == 1, in which the innermost loop is a single element - // and can be combined with the next (y3) as an inner broadcast. - // - // Note that this handles the case of pure scalar broadcast when - // y0 == y1 == y2 == 1. With low overhead it handles cases such as scalar - // broadcast with batch (as y2 > 1). - // - // NOTE The process is the same as the above general case except simplified - // for y4 == 1 and the loop over y3 is contained within the - // AddScalarBroadcast function. - for (int i0 = 0; i0 < y0; ++i0) { - const int8* input2_data_ptr = nullptr; - for (int i1 = 0; i1 < y1; ++i1) { - input2_data_ptr = input2_data_reset; - for (int i2 = 0; i2 < y2; ++i2) { - AddScalarBroadcast(y3, params, *input1_data_ptr, input2_data_ptr, - output_data_ptr); - input2_data_ptr += y3; - output_data_ptr += y3; - input1_data_ptr += 1; - } - } - input2_data_reset = input2_data_ptr; - } - } -} - inline void BroadcastAddDispatch(const ArithmeticParams& params, const RuntimeShape& input1_shape, const int8* input1_data, @@ -380,8 +289,9 @@ inline void BroadcastAddDispatch(const ArithmeticParams& params, output_shape, output_data); } - BroadcastAddFivefold(params, input1_shape, input1_data, input2_shape, - input2_data, output_shape, output_data); + optimized_ops::BinaryBroadcastFiveFold( + params, input1_shape, input1_data, input2_shape, input2_data, + output_shape, output_data, AddElementwise, AddScalarBroadcast); } } // namespace optimized_integer_ops diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h index 18aeef4c8b5..0d385ec1656 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h @@ -38,49 +38,81 @@ inline void MulElementwise(int size, const ArithmeticParams& params, TFLITE_DCHECK_GT(params.output_offset, -256); TFLITE_DCHECK_LT(params.output_offset, 256); #ifdef USE_NEON - const auto input1_offset_vector = vdupq_n_s16(params.input1_offset); - const auto input2_offset_vector = vdupq_n_s16(params.input2_offset); - const auto output_offset_vector = vdupq_n_s16(params.output_offset); + const int16x8_t input1_offset_vector = vdupq_n_s16(params.input1_offset); + const int16x8_t input2_offset_vector = vdupq_n_s16(params.input2_offset); + const int16x8_t output_offset_vector = vdupq_n_s16(params.output_offset); const auto output_activation_min_vector = - vdup_n_s8(params.quantized_activation_min); + vdupq_n_s8(params.quantized_activation_min); const auto output_activation_max_vector = - vdup_n_s8(params.quantized_activation_max); + vdupq_n_s8(params.quantized_activation_max); const int left_shift = std::max(0, params.output_shift); const int right_shift = std::max(0, -params.output_shift); const int32x4_t left_shift_vec = vdupq_n_s32(left_shift); - for (; i <= size - 8; i += 8) { - // We load / store 8 at a time, multiplying as two sets of 4 int32s. - const auto input1_val_original = vld1_s8(input1_data + i); - const auto input2_val_original = vld1_s8(input2_data + i); - const auto input1_val_s16 = vmovl_s8(input1_val_original); - const auto input2_val_s16 = vmovl_s8(input2_val_original); - const auto input1_val = vaddq_s16(input1_val_s16, input1_offset_vector); - const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector); + for (; i <= size - 16; i += 16) { + // We load / store 16 at a time, multiplying as four sets of 4 int32s. + const int8x16_t input1_val_original = vld1q_s8(input1_data + i); + const int8x16_t input2_val_original = vld1q_s8(input2_data + i); - const auto input1_val_low = vget_low_s16(input1_val); - const auto input1_val_high = vget_high_s16(input1_val); - const auto input2_val_low = vget_low_s16(input2_val); - const auto input2_val_high = vget_high_s16(input2_val); + const int16x8_t input1_val_s16_high = + vmovl_s8(vget_high_s8(input1_val_original)); + const int16x8_t input1_val_s16_low = + vmovl_s8(vget_low_s8(input1_val_original)); - auto p1 = vmull_s16(input2_val_low, input1_val_low); - auto p2 = vmull_s16(input2_val_high, input1_val_high); + const int16x8_t input2_val_s16_high = + vmovl_s8(vget_high_s8(input2_val_original)); + const int16x8_t input2_val_s16_low = + vmovl_s8(vget_low_s8(input2_val_original)); + const int16x8_t input1_val_high = + vaddq_s16(input1_val_s16_high, input1_offset_vector); + const int16x8_t input2_val_high = + vaddq_s16(input2_val_s16_high, input2_offset_vector); + const int16x8_t input1_val_low = + vaddq_s16(input1_val_s16_low, input1_offset_vector); + const int16x8_t input2_val_low = + vaddq_s16(input2_val_s16_low, input2_offset_vector); + const int16x4_t input1_val_high_high = vget_high_s16(input1_val_high); + const int16x4_t input1_val_high_low = vget_low_s16(input1_val_high); + const int16x4_t input1_val_low_high = vget_high_s16(input1_val_low); + const int16x4_t input1_val_low_low = vget_low_s16(input1_val_low); + const int16x4_t input2_val_high_high = vget_high_s16(input2_val_high); + const int16x4_t input2_val_high_low = vget_low_s16(input2_val_high); + const int16x4_t input2_val_low_high = vget_high_s16(input2_val_low); + const int16x4_t input2_val_low_low = vget_low_s16(input2_val_low); + + auto p1 = vmull_s16(input2_val_high_high, input1_val_high_high); + auto p2 = vmull_s16(input2_val_high_low, input1_val_high_low); + auto p3 = vmull_s16(input2_val_low_high, input1_val_low_high); + auto p4 = vmull_s16(input2_val_low_low, input1_val_low_low); p1 = vshlq_s32(p1, left_shift_vec); p2 = vshlq_s32(p2, left_shift_vec); + p3 = vshlq_s32(p3, left_shift_vec); + p4 = vshlq_s32(p4, left_shift_vec); + p1 = vqrdmulhq_n_s32(p1, params.output_multiplier); p2 = vqrdmulhq_n_s32(p2, params.output_multiplier); + p3 = vqrdmulhq_n_s32(p3, params.output_multiplier); + p4 = vqrdmulhq_n_s32(p4, params.output_multiplier); using gemmlowp::RoundingDivideByPOT; p1 = RoundingDivideByPOT(p1, right_shift); p2 = RoundingDivideByPOT(p2, right_shift); + p3 = RoundingDivideByPOT(p3, right_shift); + p4 = RoundingDivideByPOT(p4, right_shift); const auto p1_narrowed = vqmovn_s32(p1); const auto p2_narrowed = vqmovn_s32(p2); - const auto p = - vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector); - const auto clamped = - vmax_s8(output_activation_min_vector, - vmin_s8(output_activation_max_vector, vqmovn_s16(p))); - vst1_s8(output_data + i, clamped); + const auto p3_narrowed = vqmovn_s32(p3); + const auto p4_narrowed = vqmovn_s32(p4); + + const int16x8_t p_part1 = + vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector); + const int16x8_t p_part2 = + vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector); + const int8x16_t p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1)); + + const auto clamped = vmaxq_s8(output_activation_min_vector, + vminq_s8(output_activation_max_vector, p)); + vst1q_s8(output_data + i, clamped); } #endif // NEON @@ -117,40 +149,63 @@ inline void MulSimpleBroadcast(int size, const ArithmeticParams& params, const auto input2_offset_vector = vdupq_n_s16(params.input2_offset); const auto output_offset_vector = vdupq_n_s16(params.output_offset); const auto output_activation_min_vector = - vdup_n_s8(params.quantized_activation_min); + vdupq_n_s8(params.quantized_activation_min); const auto output_activation_max_vector = - vdup_n_s8(params.quantized_activation_max); + vdupq_n_s8(params.quantized_activation_max); const int left_shift = std::max(0, params.output_shift); const int right_shift = std::max(0, -params.output_shift); const int32x4_t left_shift_vec = vdupq_n_s32(left_shift); - for (; i <= size - 8; i += 8) { - // We load / store 8 at a time, multiplying as two sets of 4 int32s. - const auto input2_val_original = vld1_s8(input2_data + i); - const auto input2_val_s16 = vmovl_s8(input2_val_original); - const auto input2_val = vaddq_s16(input2_val_s16, input2_offset_vector); + for (; i <= size - 16; i += 16) { + // We load / store 16 at a time, multiplying as four sets of 4 int32s. + const auto input2_val_original = vld1q_s8(input2_data + i); + const auto input2_val_s16_high = + vmovl_s8(vget_high_s8(input2_val_original)); + const auto input2_val_s16_low = vmovl_s8(vget_low_s8(input2_val_original)); - const auto input2_val_low = vget_low_s16(input2_val); - const auto input2_val_high = vget_high_s16(input2_val); + const auto input2_val_high = + vaddq_s16(input2_val_s16_high, input2_offset_vector); + const auto input2_val_low = + vaddq_s16(input2_val_s16_low, input2_offset_vector); - auto p1 = vmull_n_s16(input2_val_low, input1_val); - auto p2 = vmull_n_s16(input2_val_high, input1_val); + const auto input2_val_low_low = vget_low_s16(input2_val_low); + const auto input2_val_low_high = vget_high_s16(input2_val_low); + const auto input2_val_high_low = vget_low_s16(input2_val_high); + const auto input2_val_high_high = vget_high_s16(input2_val_high); + + auto p1 = vmull_n_s16(input2_val_high_high, input1_val); + auto p2 = vmull_n_s16(input2_val_high_low, input1_val); + auto p3 = vmull_n_s16(input2_val_low_high, input1_val); + auto p4 = vmull_n_s16(input2_val_low_low, input1_val); p1 = vshlq_s32(p1, left_shift_vec); p2 = vshlq_s32(p2, left_shift_vec); + p3 = vshlq_s32(p3, left_shift_vec); + p4 = vshlq_s32(p4, left_shift_vec); + p1 = vqrdmulhq_n_s32(p1, params.output_multiplier); p2 = vqrdmulhq_n_s32(p2, params.output_multiplier); + p3 = vqrdmulhq_n_s32(p3, params.output_multiplier); + p4 = vqrdmulhq_n_s32(p4, params.output_multiplier); using gemmlowp::RoundingDivideByPOT; p1 = RoundingDivideByPOT(p1, right_shift); p2 = RoundingDivideByPOT(p2, right_shift); + p3 = RoundingDivideByPOT(p3, right_shift); + p4 = RoundingDivideByPOT(p4, right_shift); const auto p1_narrowed = vqmovn_s32(p1); const auto p2_narrowed = vqmovn_s32(p2); - const auto p = - vaddq_s16(vcombine_s16(p1_narrowed, p2_narrowed), output_offset_vector); - const auto clamped = - vmax_s8(output_activation_min_vector, - vmin_s8(output_activation_max_vector, vqmovn_s16(p))); - vst1_s8(output_data + i, clamped); + const auto p3_narrowed = vqmovn_s32(p3); + const auto p4_narrowed = vqmovn_s32(p4); + + const int16x8_t p_part1 = + vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector); + const int16x8_t p_part2 = + vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector); + const int8x16_t p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1)); + + const auto clamped = vmaxq_s8(output_activation_min_vector, + vminq_s8(output_activation_max_vector, p)); + vst1q_s8(output_data + i, clamped); } #endif // NEON diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc index 4c90cd86a56..c96f298370a 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -1466,16 +1466,20 @@ void NeonMatrixBatchVectorMultiplyAccumulate( int i = 0; int32_t* scratch_ptr = scratch; for (; i <= total_size - 8; i += 8, result += 8) { - float batch_scaling_factor0 = scaling_factors[i / m_rows]; - float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows]; - if (per_channel_scale) { - batch_scaling_factor0 *= per_channel_scale[i % m_rows]; - batch_scaling_factor1 *= per_channel_scale[(i + 4) % m_rows]; - } + const float batch_scaling_factor0 = scaling_factors[i / m_rows]; + const float batch_scaling_factor1 = scaling_factors[(i + 4) / m_rows]; const int batch_input_offset0 = -input_offset[i / m_rows]; const int batch_input_offset1 = -input_offset[(i + 4) / m_rows]; - const float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor0); - const float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor1); + float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor0); + float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor1); + if (per_channel_scale) { + const float32x4_t per_channel_scale0 = + vld1q_f32(&per_channel_scale[i % m_rows]); + const float32x4_t per_channel_scale1 = + vld1q_f32(&per_channel_scale[(i + 4) % m_rows]); + scaling_factor0 = vmulq_f32(scaling_factor0, per_channel_scale0); + scaling_factor1 = vmulq_f32(scaling_factor1, per_channel_scale1); + } const int32x4_t input_offset0 = vdupq_n_s32(batch_input_offset0); const int32x4_t input_offset1 = vdupq_n_s32(batch_input_offset1); const int32x4_t row_sum0 = vld1q_s32(row_sums + (i % m_rows)); @@ -1498,7 +1502,10 @@ void NeonMatrixBatchVectorMultiplyAccumulate( scratch_ptr += i; for (; i < total_size; i++) { - const float batch_scaling_factor = scaling_factors[i / m_rows]; + float batch_scaling_factor = scaling_factors[i / m_rows]; + if (per_channel_scale) { + batch_scaling_factor *= per_channel_scale[i % m_rows]; + } const int32_t zero_point = input_offset[i / m_rows]; int32_t dotprod = *(scratch_ptr++); dotprod -= row_sums[i % m_rows] * zero_point; @@ -1514,16 +1521,6 @@ void NeonMatrixBatchVectorMultiplyAccumulate( per_channel_scale, input_offset, row_sums); } -void NeonMatrixBatchVectorMultiplyAccumulate( - const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, - const int8_t* __restrict__ vectors, const float* scaling_factors, - int n_batch, float* __restrict__ result, const float* per_channel_scale, - const int32_t* input_offset) { - NeonMatrixBatchVectorMultiplyAccumulateImpl( - matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result, - per_channel_scale, input_offset, nullptr); -} - inline int64x2x2_t MulAdd(int32x4_t acc, int32x4_t lhs, int32x4_t rhs) { int64x2x2_t result; const int64x2_t lhs_low = vmovl_s32(vget_low_s32(lhs)); diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h index b978bf5f3bb..86951fcd559 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -55,16 +55,6 @@ void MatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix, vectors, scaling_factors, n_batch, scratch, result, context); } -void MatrixBatchVectorMultiplyAccumulate( - const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, - const int8_t* __restrict__ vectors, const float* scaling_factors, - int n_batch, float* __restrict__ result, const float* per_channel_scale, - const int32_t* input_offset) { - NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols, - vectors, scaling_factors, n_batch, result, per_channel_scale, - input_offset); -} - void MatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, const int8_t* __restrict__ vectors, const float* scaling_factors, diff --git a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h index 1b043390c22..1554d07a61c 100644 --- a/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h +++ b/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils_impl.h @@ -62,12 +62,6 @@ void NeonMatrixBatchVectorMultiplyAccumulate( const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, bool* compute_row_sums, CpuBackendContext* context); -void NeonMatrixBatchVectorMultiplyAccumulate( - const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, - const int8_t* __restrict__ vectors, const float* scaling_factors, - int n_batch, float* __restrict__ result, const float* per_channel_scale, - const int32_t* input_offset); - void NeonApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights, const int32_t* bias, int32_t layer_norm_scale_a, int32_t layer_norm_scale_b, int32_t variance_limit, diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index b18f0f4bb5a..746ed622632 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -201,63 +201,35 @@ MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data, // MultiplyByQuantizedMultipler. #ifdef USE_NEON inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows( - int32x4x4_t input_val, int32 quantized_multiplier, int shift) { - using gemmlowp::RoundingDivideByPOT; - using gemmlowp::SaturatingRoundingDoublingHighMul; - const int left_shift = shift > 0 ? shift : 0; - const int right_shift = shift > 0 ? 0 : -shift; + int32x4x4_t input_val, int32 quantized_multiplier, int32 shift) { + const int left_shift = std::max(shift, 0); + const int right_shift = std::min(shift, 0); int32x4x4_t result; - // The vector type support for SaturatingRoundingDoublingHighMulth in gemmlowp - // is limited to NEON. -#ifdef GEMMLOWP_NEON - const int32x4_t left_shifted_one_dup = vdupq_n_s32(1 << left_shift); - result.val[0] = - RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( - vmulq_s32(input_val.val[0], left_shifted_one_dup), - quantized_multiplier), - right_shift); - result.val[1] = - RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( - vmulq_s32(input_val.val[1], left_shifted_one_dup), - quantized_multiplier), - right_shift); - result.val[2] = - RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( - vmulq_s32(input_val.val[2], left_shifted_one_dup), - quantized_multiplier), - right_shift); - result.val[3] = - RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( - vmulq_s32(input_val.val[3], left_shifted_one_dup), - quantized_multiplier), - right_shift); -#else - for (int i = 0; i < 4; ++i) { - int32_t vals[4]; - vals[0] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul( - vgetq_lane_s32(input_val.val[i], 0) * (1 << left_shift), - quantized_multiplier), - right_shift); - vals[1] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul( - vgetq_lane_s32(input_val.val[i], 1) * (1 << left_shift), - quantized_multiplier), - right_shift); - vals[2] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul( - vgetq_lane_s32(input_val.val[i], 2) * (1 << left_shift), - quantized_multiplier), - right_shift); - vals[3] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul( - vgetq_lane_s32(input_val.val[i], 3) * (1 << left_shift), - quantized_multiplier), - right_shift); - result.val[i] = vld1q_s32(reinterpret_cast<int32_t*>(&vals)); - } -#endif + int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier); + int32x4_t left_shift_dup = vdupq_n_s32(left_shift); + int32x4_t right_shift_dup = vdupq_n_s32(right_shift); + + result.val[0] = + vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup), + multiplier_dup), + right_shift_dup); + + result.val[1] = + vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup), + multiplier_dup), + right_shift_dup); + + result.val[2] = + vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup), + multiplier_dup), + right_shift_dup); + + result.val[3] = + vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup), + multiplier_dup), + right_shift_dup); + return result; } #endif @@ -7926,16 +7898,16 @@ inline void MaximumElementwise(int size, const ArithmeticParams& params, const int8* input1_data, const int8* input2_data, int8* output_data) { ruy::profiler::ScopeLabel label("MaximumElementwiseInt8/8bit"); - int i = 0; #ifdef USE_NEON - for (; i <= size - 8; i += 8) { - const int8x8_t input1_val_original = vld1_s8(input1_data + i); - const int8x8_t input2_val_original = vld1_s8(input2_data + i); - const int8x8_t max_data = vmax_s8(input1_val_original, input2_val_original); - vst1_s8(output_data + i, max_data); + for (; i <= size - 16; i += 16) { + const int8x16_t input1_val_original = vld1q_s8(input1_data + i); + const int8x16_t input2_val_original = vld1q_s8(input2_data + i); + const int8x16_t max_data = + vmaxq_s8(input1_val_original, input2_val_original); + vst1q_s8(output_data + i, max_data); } -#endif // NEON +#endif // USE_NEON for (; i < size; ++i) { const int8 input1_val = input1_data[i]; const int8 input2_val = input2_data[i]; @@ -7950,13 +7922,14 @@ inline void MaximumScalarBroadcast(int size, const ArithmeticParams& params, int i = 0; #ifdef USE_NEON - const int8x8_t input1_val_original = vdup_n_s8(input1_data); - for (; i <= size - 8; i += 8) { - const int8x8_t input2_val_original = vld1_s8(input2_data + i); - const int8x8_t max_data = vmax_s8(input1_val_original, input2_val_original); - vst1_s8(output_data + i, max_data); + const int8x16_t input1_val_original = vdupq_n_s8(input1_data); + for (; i <= size - 16; i += 16) { + const int8x16_t input2_val_original = vld1q_s8(input2_data + i); + const int8x16_t max_data = + vmaxq_s8(input1_val_original, input2_val_original); + vst1q_s8(output_data + i, max_data); } -#endif // NEON +#endif // USE_NEON for (; i < size; ++i) { const int8 input2_val = input2_data[i]; output_data[i] = std::max(input1_data, input2_val); @@ -7967,6 +7940,7 @@ inline void MaximumScalarBroadcast(int size, const ArithmeticParams& params, inline void MinimumElementwise(int size, const ArithmeticParams& params, const int8* input1_data, const int8* input2_data, int8* output_data) { + ruy::profiler::ScopeLabel label("MinimumElementwiseInt8/8bit"); int i = 0; #ifdef USE_NEON for (; i <= size - 16; i += 16) { @@ -7987,6 +7961,7 @@ inline void MinimumElementwise(int size, const ArithmeticParams& params, inline void MinimumScalarBroadcast(int size, const ArithmeticParams& params, int8 input1_data, const int8* input2_data, int8* output_data) { + ruy::profiler::ScopeLabel label("MinimumScalarBroadcastInt8/8bit"); int i = 0; #ifdef USE_NEON @@ -8013,10 +7988,7 @@ inline void BinaryBroadcastFiveFold(const ArithmeticParams& unswitched_params, const RuntimeShape& output_shape, int8* output_data, ElementwiseF elementwise_f, - ScalarBroadcastF scalar_broadcast_f, - const std::string& label_name) { - ruy::profiler::ScopeLabel label(label_name); - + ScalarBroadcastF scalar_broadcast_f) { ArithmeticParams switched_params = unswitched_params; switched_params.input1_offset = unswitched_params.input2_offset; switched_params.input1_multiplier = unswitched_params.input2_multiplier; @@ -8118,8 +8090,7 @@ inline void BroadcastMaximumDispatch(const ArithmeticParams& params, BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data, - MaximumElementwise, MaximumScalarBroadcast, - "BroadcastMaximumFivefoldInt8/8bit"); + MaximumElementwise, MaximumScalarBroadcast); } template <typename Op> @@ -8138,8 +8109,7 @@ inline void BroadcastMinimumDispatch(const ArithmeticParams& params, BinaryBroadcastFiveFold(params, input1_shape, input1_data, input2_shape, input2_data, output_shape, output_data, - MinimumElementwise, MinimumScalarBroadcast, - "BroadcastMinimumFivefoldInt8/8bit"); + MinimumElementwise, MinimumScalarBroadcast); } } // namespace optimized_ops diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc index 7fb69e7b4f4..80cc14c6d26 100644 --- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.cc @@ -24,6 +24,7 @@ limitations under the License. #include <cstdint> +#include "tensorflow/lite/kernels/cpu_backend_context.h" #include "tensorflow/lite/kernels/internal/compatibility.h" namespace tflite { @@ -89,18 +90,24 @@ float GetFloatVectorElement(__m128 v) { } // namespace -void SseMatrixBatchVectorMultiplyAccumulate( +void SseMatrixBatchVectorMultiplyAccumulateImpl( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, const int8_t* __restrict__ vectors, const float* __restrict__ scaling_factors, int n_batch, - float* __restrict__ result) { + float* __restrict__ result, const float* per_channel_scale, + const int32_t* input_offset, const int32_t* row_sums) { for (std::intptr_t batch = 0; batch < n_batch; ++batch) { const float batch_scaling_factor = scaling_factors[batch]; + const int32_t batch_offset = input_offset ? input_offset[batch] : 0; // Compute dot-product for every column. for (std::intptr_t row = 0; row < m_rows; ++row) { // Get the address of the first element of the row. const int8_t* __restrict__ row_ptr = matrix + row * m_cols; - + const float row_scale = + per_channel_scale ? per_channel_scale[row] * batch_scaling_factor + : batch_scaling_factor; + const int32_t row_offset = + row_sums && batch_offset ? batch_offset * row_sums[row] : 0; // Initialize the dot product sum for the row to 0. __m128i dotprod_32x4 = _mm_setzero_si128(); std::intptr_t col = 0; @@ -152,8 +159,10 @@ void SseMatrixBatchVectorMultiplyAccumulate( for (; col < m_cols; ++col) { sum += row_ptr[col] * vectors[col]; } // for col - - *result += sum * batch_scaling_factor; + if (row_offset) { + sum -= row_offset; + } + *result += sum * row_scale; ++result; } // for row @@ -165,56 +174,30 @@ void SseMatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, const int8_t* __restrict__ vectors, const float* __restrict__ scaling_factors, int n_batch, - float* __restrict__ result, const float* __restrict__ per_channel_scale, - const int32_t* __restrict__ input_offset) { - if (input_offset == nullptr) { - SseMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors, - scaling_factors, n_batch, result); - return; - } - static constexpr std::intptr_t kBlockSize = 16; - for (std::intptr_t batch = 0; batch < n_batch; ++batch) { - const float batch_scaling_factor = scaling_factors[batch]; - for (std::intptr_t row = 0; row < m_rows; ++row) { - const int8_t* __restrict__ row_ptr = matrix + row * m_cols; - float scale = batch_scaling_factor; - if (per_channel_scale != nullptr) { - scale *= per_channel_scale[row]; - } - __m128i dotprod_32x4 = _mm_setzero_si128(); - __m128i row_sum_16x8 = _mm_setzero_si128(); - std::intptr_t col = 0; - for (; col < (m_cols & ~(kBlockSize - 1)); col += kBlockSize) { - const __m128i vec_8x16 = - _mm_loadu_si128(reinterpret_cast<const __m128i*>(vectors + col)); - const __m128i row_8x16 = - _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col)); - // dotprod += vec · row - dotprod_32x4 = - _mm_add_epi32(dotprod_32x4, DotProdInt8x4x4(vec_8x16, row_8x16)); + float* __restrict__ result) { + SseMatrixBatchVectorMultiplyAccumulateImpl( + matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result, + /*per_channel_scale=*/nullptr, /*input_offset=*/nullptr, + /*row_sums=*/nullptr); +} - // Pairwise add 16x 8-bit values; equivalently, multipy-add with 1. - // Result is 8x 16-bit values. - const __m128i row_16x8 = _mm_maddubs_epi16(_mm_set1_epi8(1), row_8x16); - row_sum_16x8 = _mm_add_epi16(row_sum_16x8, row_16x8); - } // for col - // Pairwise add 8x 16-bit values; equivalently, multipy-add with 1. - // Result is 4x 32-bit values. - const __m128i row_sum_32x4 = - _mm_madd_epi16(row_sum_16x8, _mm_set1_epi16(1)); - int32_t sum = ReduceInt32x4(dotprod_32x4); - int32_t row_sum = ReduceInt32x4(row_sum_32x4); - // Postamble loop. - for (; col < m_cols; ++col) { - sum += row_ptr[col] * vectors[col]; - row_sum += row_ptr[col]; - } // for col - sum -= row_sum * input_offset[batch]; - *result += sum * scale; - ++result; - } // for row - vectors += m_cols; - } // for batch +void SseMatrixBatchVectorMultiplyAccumulate( + const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, + const int8_t* __restrict__ vectors, + const float* __restrict__ scaling_factors, int n_batch, + float* __restrict__ result, const float* per_channel_scale, + const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, + bool* compute_row_sums, CpuBackendContext* context) { + if ((input_offset != nullptr) && (!compute_row_sums || *compute_row_sums)) { + memset(row_sums, 0, sizeof(int32_t) * m_rows); + SseReductionSumVector(matrix, row_sums, m_rows, m_cols); + if (compute_row_sums) { + *compute_row_sums = false; + } + } + SseMatrixBatchVectorMultiplyAccumulateImpl( + matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result, + per_channel_scale, input_offset, row_sums); } namespace { @@ -347,6 +330,44 @@ void SseSparseMatrixBatchVectorMultiplyAccumulate( } // for batch } +void SseReductionSumVector(const int8_t* input_vector, int32_t* output_vector, + const int output_size, const int reduction_size) { + static constexpr std::intptr_t kBlockSize = 16; + for (std::intptr_t row = 0; row < output_size; ++row) { + const int8_t* __restrict__ row_ptr = input_vector + row * reduction_size; + __m128i row_sum_16x8 = _mm_setzero_si128(); + std::intptr_t col = 0; + for (; col < (reduction_size & ~(kBlockSize - 1)); col += kBlockSize) { + const __m128i row_8x16 = + _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ptr + col)); + const __m128i row_16x8 = _mm_maddubs_epi16(_mm_set1_epi8(1), row_8x16); + row_sum_16x8 = _mm_add_epi16(row_sum_16x8, row_16x8); + } // for col +#ifdef __SSE4_1__ + // Postamble for 8x 8-bit inputs. + if (col < (reduction_size & ~7)) { + // _mm_loadu_si64 not supported in gcc versions < 9, breaks kokoro build. + const __m128i row_16x8 = _mm_cvtepi8_epi16( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(row_ptr + col))); + // dotprod += vec · row + row_sum_16x8 = _mm_add_epi16(row_sum_16x8, row_16x8); + col += 8; + } +#endif + const __m128i row_sum_32x4 = + _mm_madd_epi16(row_sum_16x8, _mm_set1_epi16(1)); + int32_t row_sum = ReduceInt32x4(row_sum_32x4); +#if defined(__SSE4_1__) && defined(__clang__) + // SSE 4.1: Don't try to unroll and vectorize this, already done above. +#pragma clang loop unroll(disable) vectorize(disable) +#endif + for (; col < reduction_size; col++) { + row_sum += *(row_ptr + col); + } + *(output_vector + row) += row_sum; + } +} + } // namespace tensor_utils } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h index 986e70a7823..224d811e862 100644 --- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils.h @@ -59,10 +59,9 @@ void MatrixBatchVectorMultiplyAccumulate( int n_batch, float* __restrict__ result, const float* per_channel_scale, const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, bool* compute_row_sums, CpuBackendContext* context) { - PortableMatrixBatchVectorMultiplyAccumulate( - matrix, m_rows, m_cols, vectors, scaling_factors, n_batch, result, - per_channel_scale, input_offset, scratch, row_sums, compute_row_sums, - context); + SSE_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols, + vectors, scaling_factors, n_batch, result, per_channel_scale, + input_offset, scratch, row_sums, compute_row_sums, context); } void MatrixBatchVectorMultiplyAccumulate( @@ -75,17 +74,6 @@ void MatrixBatchVectorMultiplyAccumulate( vectors, scaling_factors, n_batch, result); } -void MatrixBatchVectorMultiplyAccumulate( - const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, - const int8_t* __restrict__ vectors, - const float* __restrict__ scaling_factors, int n_batch, - float* __restrict__ result, const float* __restrict__ per_channel_scale, - const int32_t* __restrict__ input_offset) { - SSE_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols, - vectors, scaling_factors, n_batch, result, per_channel_scale, - input_offset); -} - void SparseMatrixBatchVectorMultiplyAccumulate1x4( const float* __restrict__ matrix, const int32_t* __restrict__ segments, const int32_t* __restrict__ indices, int m_rows, int m_cols, @@ -315,8 +303,8 @@ void ReductionSumVector(const int32_t* input_vector, int32_t* output_vector, void ReductionSumVector(const int8_t* input_vector, int32_t* output_vector, int output_size, int reduction_size) { - NEON_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size, - reduction_size); + SSE_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size, + reduction_size); } void MeanStddevNormalization(const float* input_vector, float* output_vector, diff --git a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils_impl.h b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils_impl.h index 1996b1f30a9..c5ede624762 100644 --- a/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils_impl.h +++ b/tensorflow/lite/kernels/internal/optimized/sse_tensor_utils_impl.h @@ -17,6 +17,8 @@ limitations under the License. #include <cstdint> +#include "tensorflow/lite/kernels/cpu_backend_context.h" + #if defined(_MSC_VER) #define __restrict__ __restrict #endif @@ -38,8 +40,9 @@ void SseMatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, const int8_t* __restrict__ vectors, const float* __restrict__ scaling_factors, int n_batch, - float* __restrict__ result, const float* __restrict__ per_channel_scale, - const int32_t* __restrict__ input_offset); + float* __restrict__ result, const float* per_channel_scale, + const int32_t* input_offset, int32_t* scratch, int32_t* row_sums, + bool* compute_row_sums, CpuBackendContext* context); // Matrix multiplication for quantized values using symmetric quantization. // Sparse version. @@ -49,6 +52,9 @@ void SseSparseMatrixBatchVectorMultiplyAccumulate( const float* __restrict__ scaling_factors, int n_batch, float* __restrict__ result); +void SseReductionSumVector(const int8_t* input_vector, int32_t* output_vector, + const int output_size, const int reduction_size); + #endif // __SSSE3__ } // namespace tensor_utils diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc index 0e66dfee191..4f6db290d4f 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -161,35 +161,6 @@ void PortableMatrixBatchVectorMultiplyAccumulate( } // for batch } -void PortableMatrixBatchVectorMultiplyAccumulate( - const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, - const int8_t* __restrict__ vectors, const float* scaling_factors, - int n_batch, float* __restrict__ result, const float* per_channel_scale, - const int32_t* input_offset) { - for (int batch = 0; batch < n_batch; ++batch, vectors += m_cols) { - const float batch_scaling_factor = scaling_factors[batch]; - const float batch_offset = input_offset[batch]; - const int8_t* row_ptr = matrix; - for (int row = 0; row < m_rows; ++row) { - int32_t dotprod = 0; - float scale = batch_scaling_factor; - if (per_channel_scale) { - scale *= per_channel_scale[row]; - } -#if defined(__GNUC__) - // Prefetch the row to cache. - __builtin_prefetch(row_ptr, 0 /* prefetch for read */, - 3 /* temporal locality */); -#endif - for (int col = 0; col < m_cols; ++col, ++row_ptr) { - dotprod += (*row_ptr) * (vectors[col] - batch_offset); - } // for col - *result += dotprod * scale; - ++result; - } // for row - } // for batch -} - void PortableMatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, const int8_t* __restrict__ vectors, const float* scaling_factors, diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h index f2e6c9b4f7d..0fd7a407595 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h @@ -98,16 +98,6 @@ void MatrixBatchVectorMultiplyAccumulate(const int8_t* __restrict__ matrix, scaling_factors, n_batch, result); } -void MatrixBatchVectorMultiplyAccumulate( - const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, - const int8_t* __restrict__ vectors, const float* scaling_factors, - int n_batch, float* __restrict__ result, const float* per_channel_scale, - const int32_t* input_offset) { - PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vectors, - scaling_factors, n_batch, result, - per_channel_scale, input_offset); -} - void SparseMatrixBatchVectorMultiplyAccumulate1x4( const float* __restrict__ matrix, const int32_t* __restrict__ segments, const int32_t* __restrict__ indices, int m_rows, int m_cols, diff --git a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h index 6c15a6cd919..34767ccd942 100644 --- a/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h +++ b/tensorflow/lite/kernels/internal/reference/portable_tensor_utils_impl.h @@ -83,12 +83,6 @@ void PortableMatrixBatchVectorMultiplyAccumulate( int n_batch, int32_t* scratch, float* __restrict__ result, CpuBackendContext* context); -void PortableMatrixBatchVectorMultiplyAccumulate( - const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, - const int8_t* __restrict__ vectors, const float* scaling_factors, - int n_batch, float* __restrict__ result, const float* per_channel_scale, - const int32_t* input_offset); - void PortableSparseMatrixBatchVectorMultiplyAccumulate1x4( const float* __restrict__ matrix, const int32_t* __restrict__ segments, const int32_t* __restrict__ indices, int m_rows, int m_cols, diff --git a/tensorflow/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/lite/kernels/internal/tensor_utils_test.cc index 3ad59acdb68..878cf0d2618 100644 --- a/tensorflow/lite/kernels/internal/tensor_utils_test.cc +++ b/tensorflow/lite/kernels/internal/tensor_utils_test.cc @@ -1136,11 +1136,15 @@ std::vector<float> TestPerChannelDotprodMatrixBatchVectorMultiply( bool is_per_channel = true) { MatrixVectorData data = SetupMatrixVectorData(rows, cols, batch, negative, is_per_channel); - + std::vector<int32_t> scratch(rows * batch); + std::vector<int32_t> row_sums(rows); + bool compute_row_sums = true; + CpuBackendContext context; MatrixBatchVectorMultiplyAccumulate( data.matrix.data(), rows, cols, data.vectors.data(), data.scale_factors.data(), batch, &data.results[0], - data.per_channel_scales.data(), data.input_offsets.data()); + data.per_channel_scales.data(), data.input_offsets.data(), scratch.data(), + row_sums.data(), &compute_row_sums, &context); return data.results; } diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h index 5793b08616d..d6a2dac8583 100644 --- a/tensorflow/lite/kernels/kernel_util.h +++ b/tensorflow/lite/kernels/kernel_util.h @@ -28,7 +28,7 @@ inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; } inline int SizeOfDimension(const TfLiteTensor* t, int dim) { return t->dims->data[dim]; } -inline const TfLiteTensor* GetInput(TfLiteContext* context, +inline const TfLiteTensor* GetInput(const TfLiteContext* context, const TfLiteNode* node, int index) { return &context ->tensors[flatbuffers::EndianScalar(node->inputs->data[index])]; diff --git a/tensorflow/lite/kernels/lstm_test.cc b/tensorflow/lite/kernels/lstm_test.cc index 2bd31eae8db..62634e6bfbd 100644 --- a/tensorflow/lite/kernels/lstm_test.cc +++ b/tensorflow/lite/kernels/lstm_test.cc @@ -2050,7 +2050,7 @@ TEST_P(CifgPeepholeProjectionNoClippingLayerNormLstmTest, }}; VerifyGoldens(lstm_input_, lstm_golden_output_, &layer_norm_lstm, - /*tolerance=*/0.000902065); + /*tolerance=*/0.0009021); } class CifgPeepholeProjectionNoClippingLayerNormLstmInt8Test diff --git a/tensorflow/lite/kernels/op_macros.h b/tensorflow/lite/kernels/op_macros.h index 33d033b10b6..8c1a6b1be16 100644 --- a/tensorflow/lite/kernels/op_macros.h +++ b/tensorflow/lite/kernels/op_macros.h @@ -19,6 +19,7 @@ limitations under the License. // non-portable function. #ifdef TF_LITE_MCU_DEBUG_LOG +#include "tensorflow/lite/micro/debug_log.h" #include "tensorflow/lite/micro/micro_error_reporter.h" #define DEBUG_LOG(x) \ diff --git a/tensorflow/lite/micro/BUILD b/tensorflow/lite/micro/BUILD index 5742a383b0f..3b05aee30f4 100644 --- a/tensorflow/lite/micro/BUILD +++ b/tensorflow/lite/micro/BUILD @@ -25,20 +25,16 @@ cc_library( cc_library( name = "micro_framework", srcs = [ - "debug_log.cc", "memory_helpers.cc", "micro_allocator.cc", - "micro_error_reporter.cc", "micro_interpreter.cc", "micro_optional_debug_tools.cc", "simple_memory_allocator.cc", "test_helpers.cc", ], hdrs = [ - "debug_log.h", "memory_helpers.h", "micro_allocator.h", - "micro_error_reporter.h", "micro_interpreter.h", "micro_mutable_op_resolver.h", "micro_optional_debug_tools.h", @@ -49,15 +45,46 @@ cc_library( copts = micro_copts(), deps = [ ":micro_compatibility", - ":micro_string", ":micro_utils", "//tensorflow/lite:type_to_tflitetype", "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels/internal:tensor", + "//tensorflow/lite/micro/memory_planner", "//tensorflow/lite/micro/memory_planner:greedy_memory_planner", "//tensorflow/lite/schema:schema_fbs", + "@flatbuffers//:runtime_cc", + ], +) + +cc_library( + name = "debug_log", + srcs = [ + "debug_log.cc", + ], + hdrs = [ + "debug_log.h", + ], + build_for_embedded = True, + copts = micro_copts(), +) + +cc_library( + name = "micro_error_reporter", + srcs = [ + "micro_error_reporter.cc", + ], + hdrs = [ + "micro_error_reporter.h", + ], + build_for_embedded = True, + copts = micro_copts(), + deps = [ + ":debug_log", + ":micro_compatibility", + ":micro_string", + "//tensorflow/lite/core/api", ], ) @@ -109,7 +136,7 @@ tflite_micro_cc_test( "micro_error_reporter_test.cc", ], deps = [ - ":micro_framework", + ":micro_error_reporter", ], ) diff --git a/tensorflow/lite/micro/arc_emsdp/debug_log.cc b/tensorflow/lite/micro/arc_emsdp/debug_log.cc new file mode 100644 index 00000000000..1b4d641e5e9 --- /dev/null +++ b/tensorflow/lite/micro/arc_emsdp/debug_log.cc @@ -0,0 +1,111 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/debug_log.h" + +#include <cstdint> +#include <cstdio> +#include <cstring> + +// Print to debug console by default. One can define next to extend destinations +// set: EMSDP_LOG_TO_MEMORY +// : fill .debug_log memory region (data section) with passed chars. +// EMSDP_LOG_TO_HOST +// : Use MetaWare HostLink to print output log. Requires Synopsys MetaWare +// debugger +// EMSDP_LOG_TO_UART +// : use default debug UART (out to FTDI channel 0). The same USB Port is used +// for JTAG. +#define EMSDP_LOG_TO_UART + +// Memory size for symbols dump in EMSDP_LOG_TO_MEMORY destination +#define EMSDP_LOG_TO_MEMORY_SIZE (2 * 1024) + +// EMSDP Debug UART related defines (registers and bits) +#define EMSDP_DBG_UART_BASE (0xF0004000U) +#define DW_UART_CPR_FIFO_STAT (1 << 10) +#define DW_UART_USR_TFNF (0x02) +#define DW_UART_LSR_TXD_EMPTY (0x20) + +// EMSDP UART registers map (only necessairy fields) +typedef volatile struct dw_uart_reg { + uint32_t DATA; /* data in/out and DLL */ + uint32_t RES1[4]; + uint32_t LSR; /* Line Status Register */ + uint32_t RES2[25]; + uint32_t USR; /* UART status register */ + uint32_t RES3[29]; + uint32_t CPR; /* Component parameter register */ +} DW_UART_REG; + +// For simplicity we assume U-boot has already initialized debug console during +// application loading (or on reset). Hence, we use only status and data +// registers to organize blocking loop for printing symbols. No input and no IRQ +// handling. See embarc_osp repository for full EMSDP uart driver. +// (https://github.com/foss-for-synopsys-dwc-arc-processors/embarc_osp) +void DbgUartSendStr(const char* s) { + DW_UART_REG* uart_reg_ptr = (DW_UART_REG*)(EMSDP_DBG_UART_BASE); + const char* src = s; + while (*src) { + // Check uart status to send char + bool uart_is_ready = false; + if (uart_reg_ptr->CPR & DW_UART_CPR_FIFO_STAT) + uart_is_ready = ((uart_reg_ptr->USR & DW_UART_USR_TFNF) != 0); + else + uart_is_ready = ((uart_reg_ptr->LSR & DW_UART_LSR_TXD_EMPTY) != 0); + + // Send char if uart is ready. + if (uart_is_ready) uart_reg_ptr->DATA = *src++; + } +} + +// Simple dump of symbols to a pre-allocated memory region. +// When total log exceeds memory region size, cursor is moved to its begining. +// The memory region can be viewed afterward with debugger. +// It can be viewed/read with debugger afterward. +void LogToMem(const char* s) { + static int cursor = 0; +#pragma Bss(".debug_log") + static volatile char debug_log_mem[EMSDP_LOG_TO_MEMORY_SIZE]; +#pragma Bss() + + const char* src = s; + while (*src) { + debug_log_mem[cursor] = *src++; + cursor = (cursor < EMSDP_LOG_TO_MEMORY_SIZE) ? cursor + 1 : 0; + } + debug_log_mem[cursor] = '^'; +} + +extern "C" void DebugLog(const char* s) { +#ifndef TF_LITE_STRIP_ERROR_STRINGS + +#if defined EMSDP_LOG_TO_UART + DbgUartSendStr(s); +#endif + +#if defined EMSDP_LOG_TO_MEMORY +#warning \ + "EMSDP_LOG_TO_MEMORY is defined. View .debug_log memory region for stdout" + LogToMem(s); +#endif + +#if defined EMSDP_LOG_TO_HOST +#warning "EMSDP_LOG_TO_HOST is defined. Ensure hostlib is linked." + fprintf(stderr, "%s", s); +#endif + +#endif // TF_LITE_STRIP_ERROR_STRINGS +} diff --git a/tensorflow/lite/micro/benchmarks/BUILD b/tensorflow/lite/micro/benchmarks/BUILD index 4af3267d769..73b288d2bc1 100644 --- a/tensorflow/lite/micro/benchmarks/BUILD +++ b/tensorflow/lite/micro/benchmarks/BUILD @@ -46,6 +46,7 @@ cc_binary( deps = [ ":keyword_scrambled_model_data", "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/kernels:micro_ops", "//tensorflow/lite/micro/testing:micro_benchmark", @@ -58,6 +59,7 @@ cc_binary( deps = [ "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:micro_utils", "//tensorflow/lite/micro/examples/person_detection:model_settings", diff --git a/tensorflow/lite/micro/benchmarks/keyword_scrambled_model_data.cc b/tensorflow/lite/micro/benchmarks/keyword_scrambled_model_data.cc index c1e37dfb37e..834f44ca5ab 100644 --- a/tensorflow/lite/micro/benchmarks/keyword_scrambled_model_data.cc +++ b/tensorflow/lite/micro/benchmarks/keyword_scrambled_model_data.cc @@ -15,19 +15,8 @@ limitations under the License. #include "tensorflow/lite/micro/benchmarks/keyword_scrambled_model_data.h" -// We need to keep the data array aligned on some architectures. -#ifdef __has_attribute -#define HAVE_ATTRIBUTE(x) __has_attribute(x) -#else -#define HAVE_ATTRIBUTE(x) 0 -#endif -#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__)) -#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4))) -#else -#define DATA_ALIGN_ATTRIBUTE -#endif - -const unsigned char g_keyword_scrambled_model_data[] DATA_ALIGN_ATTRIBUTE = { +// Keep model aligned to 8 bytes to guarantee aligned 64-bit accesses. +alignas(8) const unsigned char g_keyword_scrambled_model_data[] = { 0x18, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0xd0, 0x6e, 0x00, 0x00, diff --git a/tensorflow/lite/micro/examples/hello_world/BUILD b/tensorflow/lite/micro/examples/hello_world/BUILD index 155aaafd98c..4488c192abb 100644 --- a/tensorflow/lite/micro/examples/hello_world/BUILD +++ b/tensorflow/lite/micro/examples/hello_world/BUILD @@ -35,6 +35,7 @@ tflite_micro_cc_test( deps = [ ":model", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/kernels:all_ops_resolver", "//tensorflow/lite/micro/kernels:micro_ops", @@ -54,7 +55,7 @@ cc_library( copts = micro_copts(), deps = [ "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", ], ) @@ -86,8 +87,15 @@ cc_binary( ":model", ":output_handler", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/kernels:all_ops_resolver", "//tensorflow/lite/schema:schema_fbs", ], ) + +sh_test( + name = "hello_world_binary_test", + srcs = ["hello_world_binary_test.sh"], + data = [":hello_world"], +) diff --git a/tensorflow/lite/micro/examples/hello_world/README.md b/tensorflow/lite/micro/examples/hello_world/README.md index 020a7d49e88..3b633890306 100644 --- a/tensorflow/lite/micro/examples/hello_world/README.md +++ b/tensorflow/lite/micro/examples/hello_world/README.md @@ -14,6 +14,7 @@ of the device. ## Table of contents +- [Deploy to ARC EM SDP](#deploy-to-arc-em-sdp) - [Deploy to Arduino](#deploy-to-arduino) - [Deploy to ESP32](#deploy-to-esp32) - [Deploy to SparkFun Edge](#deploy-to-sparkfun-edge) @@ -21,6 +22,78 @@ of the device. - [Run the tests on a development machine](#run-the-tests-on-a-development-machine) - [Train your own model](#train-your-own-model) +## Deploy to ARC EM SDP + +The following instructions will help you to build and deploy this example to +[ARC EM SDP](https://www.synopsys.com/dw/ipdir.php?ds=arc-em-software-development-platform) +board. General information and instructions on using the board with TensorFlow +Lite Micro can be found in the common +[ARC targets description](/tensorflow/lite/micro/tools/make/targets/arc/README.md). + +### Initial Setup + +Follow the instructions on the +[ARC EM SDP Initial Setup](/tensorflow/lite/micro/tools/make/targets/arc/README.md#ARC-EM-Software-Development-Platform-ARC-EM-SDP) +to get and install all required tools for work with ARC EM SDP. + +### Generate Example Project + +The example project for ARC EM SDP platform can be generated with the following +command: + +``` +make -f tensorflow/lite/micro/tools/make/Makefile TARGET=arc_emsdp TAGS=no_arc_mli generate_hello_world_make_project +``` + +### Build and Run Example + +For more detailed information on building and running examples see the +appropriate sections of general descriptions of the +[ARC EM SDP usage with TFLM](/tensorflow/lite/micro/tools/make/targets/arc/README.md#ARC-EM-Software-Development-Platform-ARC-EM-SDP). +In the directory with generated project you can also find a +*README_ARC_EMSDP.md* file with instructions and options on building and +running. Here we only briefly mention main steps which are typically enough to +get it started. + +1. You need to + [connect the board](/tensorflow/lite/micro/tools/make/targets/arc/README.md#connect-the-board) + and open an serial connection. + +2. Go to the generated example project director + + ``` + cd tensorflow/lite/micro/tools/make/gen/arc_emsdp_arc/prj/hello_world/make + ``` + +3. Build the example using + + ``` + make app + ``` + +4. To generate artefacts for self-boot of example from the board use + + ``` + make flash + ``` + +5. To run application from the board using microSD card: + + * Copy the content of the created /bin folder into the root of microSD + card. Note that the card must be formatted as FAT32 with default cluster + size (but less than 32 Kbytes) + * Plug in the microSD card into the J11 connector. + * Push the RST button. If a red LED is lit beside RST button, push the CFG + button. + +6. If you have the MetaWare Debugger installed in your environment: + + * To run application from the console using it type `make run`. + * To stop the execution type `Ctrl+C` in the console several times. + +In both cases (step 5 and 6) you will see the application output in the serial +terminal. + ## Deploy to Arduino The following instructions will help you build and deploy this sample diff --git a/tensorflow/lite/micro/examples/hello_world/hello_world_binary_test.sh b/tensorflow/lite/micro/examples/hello_world/hello_world_binary_test.sh new file mode 100755 index 00000000000..fe7683e5c4f --- /dev/null +++ b/tensorflow/lite/micro/examples/hello_world/hello_world_binary_test.sh @@ -0,0 +1,33 @@ +#!/bin/bash +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Bash unit tests for the example binary. + +set -e + +OUTPUT_LOG_FILE=${TEST_TMPDIR}/output_log.txt + +# Needed for copybara compatibility. +SCRIPT_BASE_DIR=/org_"tensor"flow +${TEST_SRCDIR}${SCRIPT_BASE_DIR}/tensorflow/lite/micro/examples/hello_world/hello_world 2>&1 | head > ${OUTPUT_LOG_FILE} + +if ! grep -q 'x_value:.*y_value:' ${OUTPUT_LOG_FILE}; then + echo "ERROR: Expected logs not found in output '${OUTPUT_LOG_FILE}'" + exit 1 +fi + +echo +echo "SUCCESS: hello_world_binary_test PASSED" diff --git a/tensorflow/lite/micro/examples/hello_world/main_functions.cc b/tensorflow/lite/micro/examples/hello_world/main_functions.cc index 404c8542432..d1c2cafe850 100644 --- a/tensorflow/lite/micro/examples/hello_world/main_functions.cc +++ b/tensorflow/lite/micro/examples/hello_world/main_functions.cc @@ -34,8 +34,12 @@ TfLiteTensor* output = nullptr; int inference_count = 0; // Create an area of memory to use for input, output, and intermediate arrays. -// Finding the minimum value for your model may require some trial and error. -constexpr int kTensorArenaSize = 2 * 1024; +// Minimum arena size, at the time of writing. After allocating tensors +// you can retrieve this value by invoking interpreter.arena_used_bytes(). +const int kModelArenaSize = 2352; +// Extra headroom for model + alignment + future interpreter changes. +const int kExtraArenaSize = 560 + 16 + 100; +const int kTensorArenaSize = kModelArenaSize + kExtraArenaSize; uint8_t tensor_arena[kTensorArenaSize]; } // namespace diff --git a/tensorflow/lite/micro/examples/hello_world/model.cc b/tensorflow/lite/micro/examples/hello_world/model.cc index 232e4a14115..f774985fd48 100644 --- a/tensorflow/lite/micro/examples/hello_world/model.cc +++ b/tensorflow/lite/micro/examples/hello_world/model.cc @@ -24,19 +24,8 @@ limitations under the License. #include "tensorflow/lite/micro/examples/hello_world/model.h" -// We need to keep the data array aligned on some architectures. -#ifdef __has_attribute -#define HAVE_ATTRIBUTE(x) __has_attribute(x) -#else -#define HAVE_ATTRIBUTE(x) 0 -#endif -#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__)) -#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4))) -#else -#define DATA_ALIGN_ATTRIBUTE -#endif - -const unsigned char g_model[] DATA_ALIGN_ATTRIBUTE = { +// Keep model aligned to 8 bytes to guarantee aligned 64-bit accesses. +alignas(8) const unsigned char g_model[] = { 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x12, 0x00, 0x1c, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x00, 0x12, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, diff --git a/tensorflow/lite/micro/examples/magic_wand/BUILD b/tensorflow/lite/micro/examples/magic_wand/BUILD index 7d6f3cdcecd..b0be47c1eeb 100644 --- a/tensorflow/lite/micro/examples/magic_wand/BUILD +++ b/tensorflow/lite/micro/examples/magic_wand/BUILD @@ -41,6 +41,7 @@ tflite_micro_cc_test( ":magic_wand_model_data", ":sample_feature_data", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/kernels:all_ops_resolver", "//tensorflow/lite/micro/kernels:micro_ops", @@ -66,7 +67,7 @@ cc_library( ], deps = [ "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", ], ) @@ -78,6 +79,7 @@ tflite_micro_cc_test( deps = [ ":accelerometer_handler", "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], @@ -119,7 +121,7 @@ cc_library( ], deps = [ "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", ], ) @@ -155,6 +157,7 @@ cc_binary( ":magic_wand_model_data", ":output_handler", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/kernels:micro_ops", "//tensorflow/lite/schema:schema_fbs", diff --git a/tensorflow/lite/micro/examples/magic_wand/magic_wand_model_data.cc b/tensorflow/lite/micro/examples/magic_wand/magic_wand_model_data.cc index 1b8dca8eb0a..d56571dfd6f 100644 --- a/tensorflow/lite/micro/examples/magic_wand/magic_wand_model_data.cc +++ b/tensorflow/lite/micro/examples/magic_wand/magic_wand_model_data.cc @@ -19,19 +19,8 @@ limitations under the License. #include "tensorflow/lite/micro/examples/magic_wand/magic_wand_model_data.h" -// We need to keep the data array aligned on some architectures. -#ifdef __has_attribute -#define HAVE_ATTRIBUTE(x) __has_attribute(x) -#else -#define HAVE_ATTRIBUTE(x) 0 -#endif -#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__)) -#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4))) -#else -#define DATA_ALIGN_ATTRIBUTE -#endif - -const unsigned char g_magic_wand_model_data[] DATA_ALIGN_ATTRIBUTE = { +// Keep model aligned to 8 bytes to guarantee aligned 64-bit accesses. +alignas(8) const unsigned char g_magic_wand_model_data[] = { 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x12, 0x00, 0x1c, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x00, 0x12, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, diff --git a/tensorflow/lite/micro/examples/micro_speech/BUILD b/tensorflow/lite/micro/examples/micro_speech/BUILD index d724972fbed..b487b895f7a 100644 --- a/tensorflow/lite/micro/examples/micro_speech/BUILD +++ b/tensorflow/lite/micro/examples/micro_speech/BUILD @@ -50,6 +50,7 @@ tflite_micro_cc_test( ], deps = [ "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_features_test_data", "//tensorflow/lite/micro/examples/micro_speech/micro_features:model", @@ -107,7 +108,7 @@ cc_library( deps = [ ":simple_model_settings", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", ], ) @@ -122,6 +123,7 @@ tflite_micro_cc_test( ":simple_features_generator_test_data", ":simple_model_settings", "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], @@ -138,7 +140,7 @@ cc_library( deps = [ ":simple_model_settings", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", ], ) @@ -153,6 +155,7 @@ tflite_micro_cc_test( ":simple_features_generator_test_data", ":simple_model_settings", "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], @@ -168,7 +171,7 @@ cc_library( ], deps = [ "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_model_settings", ], ) @@ -184,7 +187,7 @@ cc_library( deps = [ ":audio_large_sample_test_data", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_model_settings", ], ) @@ -197,6 +200,7 @@ tflite_micro_cc_test( deps = [ ":audio_provider", "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_model_settings", "//tensorflow/lite/micro/testing:micro_test", @@ -212,6 +216,7 @@ tflite_micro_cc_test( ":audio_large_sample_test_data", ":audio_provider_mock", "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_model_settings", "//tensorflow/lite/micro/testing:micro_test", @@ -229,7 +234,7 @@ cc_library( deps = [ ":audio_provider", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_features_generator", "//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_model_settings", ], @@ -244,6 +249,7 @@ tflite_micro_cc_test( ":audio_provider", ":feature_provider", "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_model_settings", "//tensorflow/lite/micro/testing:micro_test", @@ -261,7 +267,7 @@ cc_library( deps = [ ":audio_provider_mock", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_features_generator", "//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_model_settings", ], @@ -275,6 +281,7 @@ tflite_micro_cc_test( deps = [ ":feature_provider_mock", "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_features_test_data", "//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_model_settings", @@ -292,7 +299,7 @@ cc_library( ], deps = [ "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_model_settings", ], ) @@ -308,6 +315,7 @@ tflite_micro_cc_test( deps = [ ":recognize_commands", "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], @@ -323,7 +331,7 @@ cc_library( ], deps = [ "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", ], ) @@ -335,6 +343,7 @@ tflite_micro_cc_test( deps = [ ":command_responder", "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], @@ -353,6 +362,7 @@ cc_binary( ":feature_provider", ":recognize_commands", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_model_settings", "//tensorflow/lite/micro/examples/micro_speech/micro_features:model", @@ -374,6 +384,7 @@ cc_binary( ":feature_provider", ":recognize_commands", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/examples/micro_speech/micro_features:micro_model_settings", "//tensorflow/lite/micro/examples/micro_speech/micro_features:model", @@ -381,3 +392,9 @@ cc_binary( "//tensorflow/lite/schema:schema_fbs", ], ) + +sh_test( + name = "micro_speech_binary_mock_test", + srcs = ["micro_speech_binary_mock_test.sh"], + data = [":micro_speech_mock"], +) diff --git a/tensorflow/lite/micro/examples/micro_speech/README.md b/tensorflow/lite/micro/examples/micro_speech/README.md index 5c20aa5fe75..e854b74e33b 100644 --- a/tensorflow/lite/micro/examples/micro_speech/README.md +++ b/tensorflow/lite/micro/examples/micro_speech/README.md @@ -16,6 +16,7 @@ kilobytes of Flash. ## Table of contents +- [Deploy to ARC EM SDP](#deploy-to-arc-em-sdp) - [Deploy to Arduino](#deploy-to-arduino) - [Deploy to ESP32](#deploy-to-esp32) - [Deploy to SparkFun Edge](#deploy-to-sparkfun-edge) @@ -25,6 +26,95 @@ kilobytes of Flash. - [Run the tests on a development machine](#run-the-tests-on-a-development-machine) - [Train your own model](#train-your-own-model) +## Deploy to ARC EM SDP + +The following instructions will help you to build and deploy this example to +[ARC EM SDP](https://www.synopsys.com/dw/ipdir.php?ds=arc-em-software-development-platform) +board. General information and instructions on using the board with TensorFlow +Lite Micro can be found in the common +[ARC targets description](/tensorflow/lite/micro/tools/make/targets/arc/README.md). + +This example is quantized with symmetric uint8 scheme. As noted in +[kernels/arc_mli/README.md](/tensorflow/lite/micro/kernels/arc_mli/README.md), +embARC MLI supports optimized kernels for int8 quantization only. Therefore, +this example will only use TFLM reference kernels. + +The ARC EM SDP board contains the rich set of extension interfaces. You can +choose any compatible microphone and modify +[audio_provider.cc](/tensorflow/lite/micro/examples/micro_speech/audio_provider.cc) +file accordingly to use input from your specific camera. By default, results of +running this example are printed to the console. If you would like to instead +implement some target-specific actions, you need to modify +[command_responder.cc](/tensorflow/lite/micro/examples/micro_speech/command_responder.cc) +accordingly. + +The reference implementations of these files are used by default on the EM SDP. + +### Initial setup + +Follow the instructions on the +[ARC EM SDP Initial Setup](/tensorflow/lite/micro/tools/make/targets/arc/README.md#ARC-EM-Software-Development-Platform-ARC-EM-SDP) +to get and install all required tools for work with ARC EM SDP. + +### Generate Example Project + +As default example doesn’t provide any output without real audio, it is +recommended to get started with example for mock data. The project for ARC EM +SDP platform can be generated with the following command: + +``` +make -f tensorflow/lite/micro/tools/make/Makefile TARGET=arc_emsdp TAGS=no_arc_mli generate_micro_speech_mock_make_project +``` + +### Build and Run Example + +For more detailed information on building and running examples see the +appropriate sections of general descriptions of the +[ARC EM SDP usage with TFLM](/tensorflow/lite/micro/tools/make/targets/arc/README.md#ARC-EM-Software-Development-Platform-ARC-EM-SDP). +In the directory with generated project you can also find a +*README_ARC_EMSDP.md* file with instructions and options on building and +running. Here we only briefly mention main steps which are typically enough to +get it started. + +1. You need to + [connect the board](/tensorflow/lite/micro/tools/make/targets/arc/README.md#connect-the-board) + and open an serial connection. + +2. Go to the generated example project director + + ``` + cd tensorflow/lite/micro/tools/make/gen/arc_emsdp_arc/prj/micro_speech_mock/make + ``` + +3. Build the example using + + ``` + make app + ``` + +4. To generate artefacts for self-boot of example from the board use + + ``` + make flash + ``` + +5. To run application from the board using microSD card: + + * Copy the content of the created /bin folder into the root of microSD + card. Note that the card must be formatted as FAT32 with default cluster + size (but less than 32 Kbytes) + * Plug in the microSD card into the J11 connector. + * Push the RST button. If a red LED is lit beside RST button, push the CFG + button. + +6. If you have the MetaWare Debugger installed in your environment: + + * To run application from the console using it type `make run`. + * To stop the execution type `Ctrl+C` in the console several times. + +In both cases (step 5 and 6) you will see the application output in the serial +terminal. + ## Deploy to Arduino The following instructions will help you build and deploy this sample diff --git a/tensorflow/lite/micro/examples/micro_speech/arc_emsdp/Makefile.inc b/tensorflow/lite/micro/examples/micro_speech/arc_emsdp/Makefile.inc new file mode 100644 index 00000000000..850263f0eb9 --- /dev/null +++ b/tensorflow/lite/micro/examples/micro_speech/arc_emsdp/Makefile.inc @@ -0,0 +1,28 @@ +ifeq ($(TARGET), arc_emsdp) + +# Patch of arc make project to adjust it specifically for micro speech example. +# In particular: +# - Extend Heap and stack size for application needs +# - Use Linker command file with better usage of fast memory +# - In case project was generated with MLI usage, reduce scratch buffers. + + MICRO_SPEECH_HDRS += \ + micro_speech_patch.txt + + MICRO_SPEECH_TEST_HDRS += \ + micro_speech_patch.txt + + MICRO_SPEECH_MOCK_HDRS += \ + micro_speech_patch.txt + +%/micro_speech_patch.txt: %/emsdp.lcf %/Makefile + @cp tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp_v2.lcf $< + @echo emsdp.lcf > $@ + @sed -E -i 's#-Hheap=[^ ]*#\-Hheap=16K \-Hstack=16K#g' $(word 2, $^) + @sed -E -i 's#MLI_ONLY *\?= *false#MLI_ONLY \?= false\n\ + CXXFLAGS += -DSCRATCH_MEM_X_SIZE=0 -DSCRATCH_MEM_Y_SIZE=0 -DSCRATCH_MEM_Z_SIZE=0\ + CCFLAGS += -DSCRATCH_MEM_X_SIZE=0 -DSCRATCH_MEM_Y_SIZE=0 -DSCRATCH_MEM_Z_SIZE=0#'\ + $(word 2, $^) + @echo Makefile >> $@ + +endif diff --git a/tensorflow/lite/micro/examples/micro_speech/feature_provider.cc b/tensorflow/lite/micro/examples/micro_speech/feature_provider.cc index 7d917085845..fc2b1420a89 100644 --- a/tensorflow/lite/micro/examples/micro_speech/feature_provider.cc +++ b/tensorflow/lite/micro/examples/micro_speech/feature_provider.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/lite/micro/examples/micro_speech/micro_features/micro_features_generator.h" #include "tensorflow/lite/micro/examples/micro_speech/micro_features/micro_model_settings.h" -FeatureProvider::FeatureProvider(int feature_size, uint8_t* feature_data) +FeatureProvider::FeatureProvider(int feature_size, int8_t* feature_data) : feature_size_(feature_size), feature_data_(feature_data), is_first_run_(true) { @@ -77,10 +77,10 @@ TfLiteStatus FeatureProvider::PopulateFeatureData( // +-----------+ +-----------+ if (slices_to_keep > 0) { for (int dest_slice = 0; dest_slice < slices_to_keep; ++dest_slice) { - uint8_t* dest_slice_data = + int8_t* dest_slice_data = feature_data_ + (dest_slice * kFeatureSliceSize); const int src_slice = dest_slice + slices_to_drop; - const uint8_t* src_slice_data = + const int8_t* src_slice_data = feature_data_ + (src_slice * kFeatureSliceSize); for (int i = 0; i < kFeatureSliceSize; ++i) { dest_slice_data[i] = src_slice_data[i]; @@ -106,7 +106,7 @@ TfLiteStatus FeatureProvider::PopulateFeatureData( audio_samples_size, kMaxAudioSampleSize); return kTfLiteError; } - uint8_t* new_slice_data = feature_data_ + (new_slice * kFeatureSliceSize); + int8_t* new_slice_data = feature_data_ + (new_slice * kFeatureSliceSize); size_t num_samples_read; TfLiteStatus generate_status = GenerateMicroFeatures( error_reporter, audio_samples, audio_samples_size, kFeatureSliceSize, diff --git a/tensorflow/lite/micro/examples/micro_speech/feature_provider.h b/tensorflow/lite/micro/examples/micro_speech/feature_provider.h index fc634ec108d..d086e013dc3 100644 --- a/tensorflow/lite/micro/examples/micro_speech/feature_provider.h +++ b/tensorflow/lite/micro/examples/micro_speech/feature_provider.h @@ -32,7 +32,7 @@ class FeatureProvider { // remain accessible for the lifetime of the provider object, since subsequent // calls will fill it with feature data. The provider does no memory // management of this data. - FeatureProvider(int feature_size, uint8_t* feature_data); + FeatureProvider(int feature_size, int8_t* feature_data); ~FeatureProvider(); // Fills the feature data with information from audio inputs, and returns how @@ -43,7 +43,7 @@ class FeatureProvider { private: int feature_size_; - uint8_t* feature_data_; + int8_t* feature_data_; // Make sure we don't try to use cached information if this is the first call // into the provider. bool is_first_run_; diff --git a/tensorflow/lite/micro/examples/micro_speech/feature_provider_mock_test.cc b/tensorflow/lite/micro/examples/micro_speech/feature_provider_mock_test.cc index 6dcf3da9a3f..aae556bf6e0 100644 --- a/tensorflow/lite/micro/examples/micro_speech/feature_provider_mock_test.cc +++ b/tensorflow/lite/micro/examples/micro_speech/feature_provider_mock_test.cc @@ -27,7 +27,7 @@ TF_LITE_MICRO_TEST(TestFeatureProviderMockYes) { tflite::MicroErrorReporter micro_error_reporter; tflite::ErrorReporter* error_reporter = µ_error_reporter; - uint8_t feature_data[kFeatureElementCount]; + int8_t feature_data[kFeatureElementCount]; FeatureProvider feature_provider(kFeatureElementCount, feature_data); int how_many_new_slices = 0; @@ -47,7 +47,7 @@ TF_LITE_MICRO_TEST(TestFeatureProviderMockNo) { tflite::MicroErrorReporter micro_error_reporter; tflite::ErrorReporter* error_reporter = µ_error_reporter; - uint8_t feature_data[kFeatureElementCount]; + int8_t feature_data[kFeatureElementCount]; FeatureProvider feature_provider(kFeatureElementCount, feature_data); int how_many_new_slices = 0; diff --git a/tensorflow/lite/micro/examples/micro_speech/feature_provider_test.cc b/tensorflow/lite/micro/examples/micro_speech/feature_provider_test.cc index 8e0e1f47d15..5d6816a91e4 100644 --- a/tensorflow/lite/micro/examples/micro_speech/feature_provider_test.cc +++ b/tensorflow/lite/micro/examples/micro_speech/feature_provider_test.cc @@ -26,7 +26,7 @@ TF_LITE_MICRO_TEST(TestFeatureProvider) { tflite::MicroErrorReporter micro_error_reporter; tflite::ErrorReporter* error_reporter = µ_error_reporter; - uint8_t feature_data[kFeatureElementCount]; + int8_t feature_data[kFeatureElementCount]; FeatureProvider feature_provider(kFeatureElementCount, feature_data); int how_many_new_slices = 0; diff --git a/tensorflow/lite/micro/examples/micro_speech/main_functions.cc b/tensorflow/lite/micro/examples/micro_speech/main_functions.cc index d3989c07333..e5e6aa7c1f7 100644 --- a/tensorflow/lite/micro/examples/micro_speech/main_functions.cc +++ b/tensorflow/lite/micro/examples/micro_speech/main_functions.cc @@ -43,8 +43,8 @@ int32_t previous_time = 0; // determined by experimentation. constexpr int kTensorArenaSize = 10 * 1024; uint8_t tensor_arena[kTensorArenaSize]; -uint8_t feature_buffer[kFeatureElementCount]; -uint8_t* model_input_buffer = nullptr; +int8_t feature_buffer[kFeatureElementCount]; +int8_t* model_input_buffer = nullptr; } // namespace // The name of this function is important for Arduino compatibility. @@ -74,19 +74,28 @@ void setup() { // // tflite::ops::micro::AllOpsResolver resolver; // NOLINTNEXTLINE(runtime-global-variables) - static tflite::MicroOpResolver<3> micro_op_resolver(error_reporter); + static tflite::MicroOpResolver<4> micro_op_resolver(error_reporter); if (micro_op_resolver.AddBuiltin( tflite::BuiltinOperator_DEPTHWISE_CONV_2D, - tflite::ops::micro::Register_DEPTHWISE_CONV_2D()) != kTfLiteOk) { + tflite::ops::micro::Register_DEPTHWISE_CONV_2D(), + tflite::MicroOpResolverAnyVersion()) != kTfLiteOk) { return; } if (micro_op_resolver.AddBuiltin( tflite::BuiltinOperator_FULLY_CONNECTED, - tflite::ops::micro::Register_FULLY_CONNECTED()) != kTfLiteOk) { + tflite::ops::micro::Register_FULLY_CONNECTED(), + tflite::MicroOpResolverAnyVersion()) != kTfLiteOk) { return; } if (micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX, - tflite::ops::micro::Register_SOFTMAX()) != + tflite::ops::micro::Register_SOFTMAX(), + tflite::MicroOpResolverAnyVersion()) != + kTfLiteOk) { + return; + } + if (micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_RESHAPE, + tflite::ops::micro::Register_RESHAPE(), + tflite::MicroOpResolverAnyVersion()) != kTfLiteOk) { return; } @@ -105,15 +114,15 @@ void setup() { // Get information about the memory area to use for the model's input. model_input = interpreter->input(0); - if ((model_input->dims->size != 4) || (model_input->dims->data[0] != 1) || - (model_input->dims->data[1] != kFeatureSliceCount) || - (model_input->dims->data[2] != kFeatureSliceSize) || - (model_input->type != kTfLiteUInt8)) { + if ((model_input->dims->size != 2) || (model_input->dims->data[0] != 1) || + (model_input->dims->data[1] != + (kFeatureSliceCount * kFeatureSliceSize)) || + (model_input->type != kTfLiteInt8)) { TF_LITE_REPORT_ERROR(error_reporter, "Bad input tensor parameters in model"); return; } - model_input_buffer = model_input->data.uint8; + model_input_buffer = model_input->data.int8; // Prepare to access the audio spectrograms from a microphone or other source // that will provide the inputs to the neural network. diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_features/BUILD b/tensorflow/lite/micro/examples/micro_speech/micro_features/BUILD index 71010493102..0aa7ff14f73 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_features/BUILD +++ b/tensorflow/lite/micro/examples/micro_speech/micro_features/BUILD @@ -59,7 +59,7 @@ cc_library( ":micro_model_settings", "//tensorflow/lite/c:common", "//tensorflow/lite/experimental/microfrontend/lib:frontend", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", ], ) @@ -85,6 +85,7 @@ tflite_micro_cc_test( ":micro_features_generator_test_data", ":micro_model_settings", "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/examples/micro_speech:audio_sample_test_data", "//tensorflow/lite/micro/testing:micro_test", diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_features/micro_features_generator.cc b/tensorflow/lite/micro/examples/micro_speech/micro_features/micro_features_generator.cc index 6a01124ed86..fbb6e6e4a9f 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_features/micro_features_generator.cc +++ b/tensorflow/lite/micro/examples/micro_speech/micro_features/micro_features_generator.cc @@ -69,7 +69,7 @@ void SetMicroFeaturesNoiseEstimates(const uint32_t* estimate_presets) { TfLiteStatus GenerateMicroFeatures(tflite::ErrorReporter* error_reporter, const int16_t* input, int input_size, - int output_size, uint8_t* output, + int output_size, int8_t* output, size_t* num_samples_read) { const int16_t* frontend_input; if (g_is_first_time) { @@ -84,16 +84,30 @@ TfLiteStatus GenerateMicroFeatures(tflite::ErrorReporter* error_reporter, for (int i = 0; i < frontend_output.size; ++i) { // These scaling values are derived from those used in input_data.py in the // training pipeline. - constexpr int32_t value_scale = (10 * 255); - constexpr int32_t value_div = (256 * 26); + // The feature pipeline outputs 16-bit signed integers in roughly a 0 to 670 + // range. In training, these are then arbitrarily divided by 25.6 to get + // float values in the rough range of 0.0 to 26.0. This scaling is performed + // for historical reasons, to match up with the output of other feature + // generators. + // The process is then further complicated when we quantize the model. This + // means we have to scale the 0.0 to 26.0 real values to the -128 to 127 + // signed integer numbers. + // All this means that to get matching values from our integer feature + // output into the tensor input, we have to perform: + // input = (((feature / 25.6) / 26.0) * 256) - 128 + // To simplify this and perform it in 32-bit integer math, we rearrange to: + // input = (feature * 256) / (25.6 * 26.0) - 128 + constexpr int32_t value_scale = 256; + constexpr int32_t value_div = static_cast<int32_t>((25.6f * 26.0f) + 0.5f); int32_t value = ((frontend_output.values[i] * value_scale) + (value_div / 2)) / value_div; - if (value < 0) { - value = 0; + value -= 128; + if (value < -128) { + value = -128; } - if (value > 255) { - value = 255; + if (value > 127) { + value = 127; } output[i] = value; } diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_features/micro_features_generator.h b/tensorflow/lite/micro/examples/micro_speech/micro_features/micro_features_generator.h index 7b9bc5faec8..29304239332 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_features/micro_features_generator.h +++ b/tensorflow/lite/micro/examples/micro_speech/micro_features/micro_features_generator.h @@ -26,7 +26,7 @@ TfLiteStatus InitializeMicroFeatures(tflite::ErrorReporter* error_reporter); // feeding into a neural network. TfLiteStatus GenerateMicroFeatures(tflite::ErrorReporter* error_reporter, const int16_t* input, int input_size, - int output_size, uint8_t* output, + int output_size, int8_t* output, size_t* num_samples_read); #endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_MICRO_FEATURES_GENERATOR_H_ diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_features/micro_features_generator_test.cc b/tensorflow/lite/micro/examples/micro_speech/micro_features/micro_features_generator_test.cc index f88f12a5562..ee3ee03763f 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_features/micro_features_generator_test.cc +++ b/tensorflow/lite/micro/examples/micro_speech/micro_features/micro_features_generator_test.cc @@ -48,7 +48,7 @@ TF_LITE_MICRO_TEST(TestMicroFeaturesGeneratorYes) { }; SetMicroFeaturesNoiseEstimates(yes_estimate_presets); - uint8_t yes_calculated_data[g_yes_feature_data_slice_size]; + int8_t yes_calculated_data[g_yes_feature_data_slice_size]; size_t num_samples_read; TfLiteStatus yes_status = GenerateMicroFeatures( error_reporter, g_yes_30ms_sample_data, g_yes_30ms_sample_data_size, @@ -56,11 +56,12 @@ TF_LITE_MICRO_TEST(TestMicroFeaturesGeneratorYes) { TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, yes_status); for (int i = 0; i < g_yes_feature_data_slice_size; ++i) { - TF_LITE_MICRO_EXPECT_EQ(g_yes_feature_data_slice[i], - yes_calculated_data[i]); - if (g_yes_feature_data_slice[i] != yes_calculated_data[i]) { + const int expected = g_yes_feature_data_slice[i]; + const int actual = yes_calculated_data[i]; + TF_LITE_MICRO_EXPECT_EQ(expected, actual); + if (expected != actual) { TF_LITE_REPORT_ERROR(error_reporter, "Expected value %d but found %d", - g_yes_feature_data_slice[i], yes_calculated_data[i]); + expected, actual); } } } @@ -81,7 +82,7 @@ TF_LITE_MICRO_TEST(TestMicroFeaturesGeneratorNo) { }; SetMicroFeaturesNoiseEstimates(no_estimate_presets); - uint8_t no_calculated_data[g_no_feature_data_slice_size]; + int8_t no_calculated_data[g_no_feature_data_slice_size]; size_t num_samples_read; TfLiteStatus no_status = GenerateMicroFeatures( error_reporter, g_no_30ms_sample_data, g_no_30ms_sample_data_size, @@ -89,10 +90,12 @@ TF_LITE_MICRO_TEST(TestMicroFeaturesGeneratorNo) { TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, no_status); for (int i = 0; i < g_no_feature_data_slice_size; ++i) { - TF_LITE_MICRO_EXPECT_EQ(g_no_feature_data_slice[i], no_calculated_data[i]); - if (g_no_feature_data_slice[i] != no_calculated_data[i]) { + const int expected = g_no_feature_data_slice[i]; + const int actual = no_calculated_data[i]; + TF_LITE_MICRO_EXPECT_EQ(expected, actual); + if (expected != actual) { TF_LITE_REPORT_ERROR(error_reporter, "Expected value %d but found %d", - g_no_feature_data_slice[i], no_calculated_data[i]); + expected, actual); } } } diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_features/model.cc b/tensorflow/lite/micro/examples/micro_speech/micro_features/model.cc index 45198c781b2..d1e797fcf7d 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_features/model.cc +++ b/tensorflow/lite/micro/examples/micro_speech/micro_features/model.cc @@ -33,1528 +33,1564 @@ limitations under the License. #endif const unsigned char g_model[] DATA_ALIGN_ATTRIBUTE = { - 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x12, 0x00, - 0x1c, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, - 0x00, 0x00, 0x18, 0x00, 0x12, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, - 0x1c, 0x47, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, - 0x2c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0xc0, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x94, 0x00, 0x00, 0x00, - 0x0f, 0x00, 0x00, 0x00, 0x54, 0x4f, 0x43, 0x4f, 0x20, 0x43, 0x6f, 0x6e, - 0x76, 0x65, 0x72, 0x74, 0x65, 0x64, 0x2e, 0x00, 0x0a, 0x00, 0x00, 0x00, - 0x60, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, - 0x3c, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, - 0x20, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x0e, 0xba, 0xff, 0xff, 0x38, 0x00, 0x00, 0x00, - 0xbc, 0xb9, 0xff, 0xff, 0xc0, 0xb9, 0xff, 0xff, 0x1e, 0xba, 0xff, 0xff, - 0xe0, 0x01, 0x00, 0x00, 0xcc, 0xb9, 0xff, 0xff, 0xd0, 0xb9, 0xff, 0xff, - 0x2e, 0xba, 0xff, 0xff, 0x60, 0x03, 0x00, 0x00, 0x36, 0xba, 0xff, 0xff, - 0x7c, 0x06, 0x00, 0x00, 0x3e, 0xba, 0xff, 0xff, 0x68, 0x45, 0x00, 0x00, - 0xec, 0xb9, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, 0x31, 0x2e, 0x35, 0x2e, - 0x30, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x08, 0x00, - 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, - 0x13, 0x00, 0x00, 0x00, 0x6d, 0x69, 0x6e, 0x5f, 0x72, 0x75, 0x6e, 0x74, - 0x69, 0x6d, 0x65, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x00, - 0x10, 0xfa, 0xff, 0xff, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x08, 0x00, 0x00, 0x00, 0x2c, 0x45, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x05, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, - 0x08, 0x00, 0x00, 0x00, 0x04, 0x01, 0x00, 0x00, 0x9c, 0x44, 0x00, 0x00, - 0x8c, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0xdc, 0x01, 0x00, 0x00, - 0x68, 0x01, 0x00, 0x00, 0x3c, 0x02, 0x00, 0x00, 0x50, 0x05, 0x00, 0x00, - 0x8e, 0xbb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, 0x10, 0x00, 0x00, 0x00, - 0x08, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, - 0x28, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, - 0x52, 0x65, 0x73, 0x68, 0x61, 0x70, 0x65, 0x5f, 0x32, 0x00, 0x00, 0x00, - 0x94, 0xfa, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x12, 0x00, 0x1c, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, + 0x10, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x00, 0x12, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x94, 0x48, 0x00, 0x00, 0x34, 0x42, 0x00, 0x00, + 0x1c, 0x42, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, + 0x04, 0x00, 0x08, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, 0x6d, 0x69, 0x6e, 0x5f, + 0x72, 0x75, 0x6e, 0x74, 0x69, 0x6d, 0x65, 0x5f, 0x76, 0x65, 0x72, 0x73, + 0x69, 0x6f, 0x6e, 0x00, 0x0c, 0x00, 0x00, 0x00, 0xd4, 0x41, 0x00, 0x00, + 0xb4, 0x41, 0x00, 0x00, 0x24, 0x03, 0x00, 0x00, 0xf4, 0x02, 0x00, 0x00, + 0xec, 0x02, 0x00, 0x00, 0xe4, 0x02, 0x00, 0x00, 0xc4, 0x02, 0x00, 0x00, + 0xbc, 0x02, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x16, 0xbd, 0xff, 0xff, + 0x04, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x31, 0x2e, 0x35, 0x2e, + 0x30, 0x00, 0x00, 0x00, 0x94, 0xba, 0xff, 0xff, 0x98, 0xba, 0xff, 0xff, + 0x32, 0xbd, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x00, + 0xfa, 0xee, 0x28, 0xc4, 0xee, 0xfe, 0xcf, 0x0f, 0x1e, 0xf7, 0x1f, 0x06, + 0x0d, 0xed, 0xe9, 0x83, 0x5c, 0xc9, 0x18, 0xe3, 0xf9, 0x14, 0x28, 0x2a, + 0x09, 0xf2, 0x18, 0x34, 0x62, 0xea, 0xef, 0xd6, 0x36, 0xb7, 0x1e, 0xf7, + 0x3b, 0x22, 0x28, 0x39, 0xc2, 0x9d, 0xf1, 0x07, 0x5e, 0x0b, 0x1e, 0x2c, + 0x07, 0xdd, 0xfd, 0xc3, 0xd8, 0x4a, 0xf3, 0x28, 0xa7, 0x16, 0xd5, 0xf1, + 0xc3, 0x05, 0xfd, 0x27, 0xcc, 0xba, 0x1e, 0xcb, 0xd7, 0x3d, 0xd4, 0x29, + 0x00, 0xfd, 0x28, 0x44, 0xfb, 0xf2, 0xf3, 0xb6, 0x4f, 0xcf, 0x09, 0xf0, + 0xfa, 0x45, 0x41, 0x49, 0x05, 0xc5, 0x17, 0x5d, 0x64, 0x00, 0xf8, 0xee, + 0x48, 0x17, 0xf4, 0xe9, 0x2e, 0x4b, 0x2e, 0x3f, 0xdf, 0xee, 0xe4, 0x08, + 0x38, 0xf1, 0x16, 0x13, 0x2f, 0x2a, 0xed, 0xc2, 0xbf, 0x36, 0xf4, 0x02, + 0xcf, 0xaa, 0xd2, 0xfa, 0xac, 0x13, 0xf6, 0xe8, 0xb5, 0x68, 0x12, 0xb6, + 0xce, 0x0e, 0xdf, 0x58, 0xe4, 0x49, 0x14, 0x15, 0x03, 0xed, 0xfa, 0xd4, + 0x40, 0xa7, 0xf6, 0xca, 0xfb, 0x00, 0x4d, 0x5e, 0xe4, 0x55, 0x1d, 0x30, + 0x45, 0xe2, 0xfc, 0x01, 0x48, 0x81, 0xe9, 0xf1, 0x1e, 0xfc, 0x21, 0x32, + 0xed, 0x4b, 0xed, 0xfa, 0x2f, 0xd2, 0xfa, 0xfb, 0x4d, 0xa7, 0xed, 0xc7, + 0x92, 0xdf, 0xe6, 0xdb, 0xf8, 0x1f, 0xd9, 0xfa, 0x91, 0xf5, 0xe5, 0xc5, + 0x8c, 0x17, 0x0f, 0xb9, 0xd2, 0xc7, 0xfe, 0x68, 0xd3, 0x51, 0x2e, 0x49, + 0x1f, 0xbd, 0x01, 0xeb, 0x31, 0x17, 0xf0, 0xef, 0xff, 0xb8, 0x5d, 0x62, + 0x02, 0x0f, 0x1f, 0x78, 0x6a, 0xb0, 0xf9, 0xfe, 0x4f, 0xcc, 0xd3, 0xff, + 0x0a, 0x96, 0x1e, 0x2c, 0xed, 0xbc, 0xf4, 0x0b, 0x42, 0xc8, 0xf1, 0xea, + 0x6e, 0x58, 0xec, 0xc4, 0x99, 0xae, 0xdc, 0xd7, 0x12, 0x87, 0xd8, 0x06, + 0xa2, 0xc2, 0xe6, 0xa2, 0x81, 0x24, 0xe9, 0xac, 0xce, 0xb6, 0x15, 0x6b, + 0xba, 0x00, 0x19, 0x58, 0x29, 0xb6, 0xfe, 0x01, 0x25, 0x96, 0xd2, 0xec, + 0x0e, 0x9c, 0x60, 0x5f, 0xe9, 0xf4, 0xf5, 0x69, 0x6b, 0xb5, 0xe1, 0xf6, + 0x5e, 0xb7, 0xb1, 0xe5, 0x11, 0x9b, 0x18, 0x10, 0xe3, 0xe1, 0xe0, 0x0d, + 0x4f, 0xa5, 0xde, 0xe5, 0x6f, 0xe2, 0xfb, 0x99, 0x82, 0xa5, 0xc9, 0xb6, + 0x1f, 0x46, 0xf3, 0x04, 0xc6, 0xca, 0xd6, 0x97, 0x90, 0x1d, 0xc0, 0x95, + 0xf0, 0x19, 0x30, 0x77, 0xc2, 0x3c, 0xfa, 0x24, 0x02, 0x4d, 0x06, 0x07, + 0x15, 0x02, 0xb0, 0xe7, 0x27, 0x22, 0x67, 0x4d, 0xf1, 0xc2, 0xf4, 0x64, + 0x38, 0x40, 0xdf, 0xf6, 0x3a, 0x43, 0xb8, 0xe1, 0x0d, 0x15, 0x11, 0xfe, + 0xf5, 0xec, 0xf9, 0xe5, 0x22, 0x36, 0xe4, 0xfd, 0x6d, 0xbf, 0x0d, 0x8e, + 0xb7, 0x15, 0xbf, 0x9f, 0x16, 0xad, 0x0a, 0x02, 0x8e, 0x14, 0xda, 0x9b, + 0x8e, 0xc3, 0xa6, 0xca, 0xf5, 0x7f, 0x51, 0x56, 0xc1, 0xb3, 0xd9, 0x35, + 0xf8, 0x7f, 0x04, 0x0a, 0x03, 0x3f, 0xbe, 0xee, 0x19, 0x68, 0x78, 0x50, + 0xf9, 0xa7, 0xf7, 0x7f, 0x1d, 0x76, 0xdb, 0xe8, 0x33, 0xb9, 0xd7, 0xe7, + 0xe8, 0x69, 0x15, 0xf7, 0xf5, 0xb2, 0xfe, 0xe8, 0xf3, 0x5b, 0xe2, 0x06, + 0x6e, 0x09, 0x36, 0xb7, 0xcc, 0x38, 0xbf, 0x8a, 0x28, 0x14, 0x2e, 0x18, + 0xa7, 0x26, 0xcb, 0xb2, 0x95, 0x37, 0xac, 0xcd, 0xd7, 0x51, 0x67, 0x44, + 0xcd, 0x31, 0xde, 0x04, 0xe9, 0x6a, 0x00, 0x13, 0x0a, 0x0c, 0xdd, 0x16, + 0xe0, 0x24, 0x7e, 0x49, 0xf1, 0xb5, 0x04, 0x52, 0x01, 0x50, 0xdd, 0xf5, + 0x26, 0xc9, 0xf4, 0xf8, 0xd6, 0x31, 0x1b, 0xd0, 0xef, 0x03, 0x0a, 0xc0, + 0xd4, 0x4f, 0xe2, 0xfd, 0x72, 0xf4, 0x5a, 0xc9, 0xd7, 0x31, 0xc0, 0x8e, + 0x17, 0x5e, 0x57, 0x00, 0xb4, 0x3a, 0xc8, 0xd2, 0x92, 0x32, 0xcb, 0xd8, + 0xc3, 0xa6, 0x63, 0x26, 0xcf, 0xbc, 0xe8, 0x57, 0x9b, 0xe9, 0xf7, 0x1c, + 0xea, 0x12, 0xf1, 0xf7, 0xdb, 0xb9, 0x7f, 0x16, 0xf6, 0xe0, 0x08, 0x70, + 0xa2, 0xed, 0xcc, 0xf1, 0x1e, 0x10, 0x04, 0xf7, 0xa9, 0xb7, 0x34, 0xaa, + 0x0a, 0xdb, 0x2a, 0xa6, 0xb6, 0x10, 0xea, 0xf8, 0x5e, 0x06, 0x72, 0xdd, + 0xd0, 0xb9, 0xd6, 0xa0, 0x10, 0x9f, 0x5a, 0x17, 0xb1, 0xe7, 0xc0, 0x01, + 0x9d, 0x01, 0xe0, 0xe0, 0xaf, 0x9c, 0x46, 0xd8, 0xaf, 0xe8, 0xce, 0x02, + 0x8a, 0xbb, 0xe4, 0xf6, 0xf3, 0x36, 0x07, 0xca, 0xcb, 0x87, 0x6e, 0xcc, + 0xd6, 0x9e, 0x0a, 0x2a, 0x81, 0xd7, 0xcf, 0xc0, 0x04, 0xeb, 0x24, 0xcc, + 0xc9, 0x95, 0x33, 0x81, 0xf7, 0xad, 0x1c, 0x9c, 0xa4, 0xd6, 0xf9, 0xe6, + 0x3d, 0x84, 0x7f, 0xcc, 0xd4, 0xb0, 0xf4, 0xa2, 0xe9, 0x3c, 0x36, 0xee, + 0xd5, 0xcf, 0xcd, 0x2d, 0x28, 0xbd, 0xff, 0xff, 0xc2, 0xbf, 0xff, 0xff, + 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, + 0x31, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x48, 0xbd, 0xff, 0xff, 0x4c, 0xbd, 0xff, 0xff, 0xe6, 0xbf, 0xff, 0xff, + 0x04, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x8a, 0xfe, 0xff, 0xff, + 0xa9, 0x00, 0x00, 0x00, 0xd0, 0xff, 0xff, 0xff, 0xd0, 0x00, 0x00, 0x00, + 0x52, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x4f, 0xfb, 0xff, 0xff, + 0x4a, 0xfd, 0xff, 0xff, 0x12, 0xc0, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, + 0x80, 0x3e, 0x00, 0x00, 0xff, 0xf9, 0xfd, 0x0a, 0x07, 0x08, 0x07, 0x03, + 0x07, 0xf2, 0xd1, 0x09, 0xf0, 0xe9, 0x28, 0x09, 0xdf, 0x05, 0xfa, 0xf0, + 0xe8, 0xe3, 0x13, 0x0e, 0x08, 0xef, 0xd3, 0xee, 0x0f, 0xe8, 0xeb, 0x14, + 0xf7, 0xed, 0xfd, 0x1f, 0xe8, 0xd5, 0xeb, 0xfc, 0x0e, 0xf4, 0xf7, 0x07, + 0x05, 0xea, 0xf6, 0x1f, 0xf8, 0xdb, 0xdc, 0x0b, 0x03, 0xdd, 0xd8, 0xf3, + 0x0f, 0x19, 0xe1, 0x09, 0xfc, 0xe4, 0x02, 0x04, 0xf1, 0x04, 0xeb, 0xf3, + 0x1e, 0x06, 0xfd, 0x11, 0xfc, 0xfa, 0xf6, 0x1f, 0x0f, 0x02, 0xf5, 0xf7, + 0xff, 0x24, 0xdf, 0xf7, 0xf8, 0xf3, 0xf6, 0xe9, 0xef, 0x03, 0xdd, 0xf2, + 0x28, 0xe1, 0xf2, 0x22, 0xf4, 0x09, 0xf7, 0xf9, 0xf0, 0xd4, 0xf9, 0xee, + 0xff, 0x14, 0xda, 0xf3, 0x11, 0xe2, 0xf6, 0x0c, 0xf2, 0xeb, 0xf8, 0xe8, + 0xe3, 0x08, 0x02, 0x17, 0xf4, 0x0b, 0x0c, 0x27, 0xe6, 0x02, 0x03, 0xf9, + 0x14, 0x18, 0xf6, 0xeb, 0x1f, 0x0c, 0xf1, 0xee, 0xfc, 0x08, 0xf0, 0xfe, + 0xfd, 0xee, 0x17, 0xfd, 0x1c, 0xef, 0xfd, 0xde, 0x04, 0x05, 0xf0, 0x31, + 0xfa, 0x0b, 0xdc, 0x0d, 0xed, 0xf5, 0xfa, 0xf4, 0x08, 0x0c, 0xd7, 0x1e, + 0x15, 0x03, 0xf5, 0x02, 0xf4, 0xfb, 0xed, 0x01, 0xfe, 0xd6, 0x1f, 0xfd, + 0xfd, 0x0e, 0xfa, 0x06, 0xf1, 0xf9, 0xe2, 0x16, 0xe9, 0xf1, 0x03, 0x0d, + 0x0d, 0xdf, 0xf9, 0x1a, 0x0e, 0xf6, 0xfc, 0x0a, 0x19, 0xe2, 0xe0, 0x09, + 0x15, 0xf0, 0xf1, 0x06, 0xf1, 0xe1, 0xef, 0x1a, 0x08, 0xe8, 0xfd, 0x12, + 0x14, 0x06, 0xf1, 0xfc, 0xea, 0xfb, 0xf7, 0xea, 0x1d, 0x09, 0xfa, 0xf6, + 0x08, 0xf2, 0xe7, 0xf8, 0xfc, 0x16, 0xf5, 0x0e, 0x08, 0xf9, 0x0a, 0x03, + 0x26, 0xd8, 0x02, 0xf5, 0xf6, 0xf6, 0xef, 0x1f, 0xe4, 0xe2, 0xfb, 0x02, + 0x1b, 0xe6, 0xde, 0x00, 0xf2, 0xed, 0xfb, 0x18, 0xe4, 0x16, 0x1a, 0x1d, + 0xf1, 0xf6, 0xea, 0x16, 0x05, 0xde, 0xfb, 0x18, 0xf5, 0xe4, 0xfe, 0xe2, + 0x1b, 0x1c, 0x0c, 0xe8, 0x02, 0xee, 0xfb, 0x07, 0x24, 0xf2, 0xe9, 0xfa, + 0x0d, 0x05, 0xf1, 0x03, 0xfe, 0xf6, 0x19, 0x06, 0xff, 0xf9, 0x04, 0xfb, + 0x15, 0xef, 0xf1, 0xf8, 0xe9, 0xe1, 0x10, 0x04, 0xfc, 0xe6, 0x1f, 0xed, + 0x0b, 0xef, 0x00, 0x1e, 0xe6, 0x16, 0xf3, 0x09, 0xfd, 0x08, 0x08, 0x06, + 0x06, 0x23, 0xdf, 0xfc, 0x08, 0xf4, 0xea, 0x0c, 0xf2, 0xe6, 0x18, 0xf5, + 0x02, 0xf9, 0x50, 0x09, 0x01, 0xda, 0x0b, 0x05, 0x12, 0x18, 0xef, 0x04, + 0x0e, 0xd9, 0xff, 0xdc, 0xf6, 0x16, 0xf9, 0xf4, 0xec, 0xff, 0xea, 0xe6, + 0xfa, 0x0a, 0xed, 0xef, 0x02, 0xf0, 0x25, 0x21, 0xf1, 0x26, 0xf5, 0xed, + 0x09, 0xea, 0xea, 0x24, 0xfa, 0x11, 0xfc, 0xdf, 0xf3, 0x0a, 0x28, 0x0c, + 0x19, 0xff, 0xf5, 0xd6, 0x0e, 0xe2, 0x2a, 0x06, 0xfa, 0x03, 0xf9, 0xe6, + 0xef, 0x23, 0xf9, 0xfa, 0xe6, 0xfe, 0xfc, 0x03, 0x06, 0x1a, 0xf9, 0x08, + 0xe0, 0xe5, 0xff, 0x05, 0x01, 0xe7, 0x12, 0x02, 0x1d, 0x05, 0x03, 0x05, + 0x0b, 0xee, 0xed, 0xfc, 0x0f, 0xf3, 0x02, 0xe0, 0x15, 0xdf, 0x02, 0xed, + 0x10, 0x26, 0xef, 0x0d, 0x06, 0xee, 0xef, 0xf6, 0xeb, 0x11, 0x09, 0xf4, + 0xf7, 0x06, 0x0f, 0x01, 0x2a, 0x0b, 0x01, 0xdd, 0xfc, 0xf4, 0xf1, 0x17, + 0x03, 0x04, 0x07, 0xfc, 0x22, 0xfc, 0xde, 0xfe, 0x0b, 0x03, 0xf3, 0xfb, + 0x0c, 0x25, 0x04, 0x19, 0x04, 0x03, 0x01, 0xfa, 0xfb, 0xf7, 0xf6, 0x0e, + 0x15, 0x0e, 0x09, 0xff, 0x06, 0xfa, 0xfb, 0x1e, 0xfb, 0x05, 0x22, 0xf9, + 0xfe, 0xf7, 0x1d, 0xed, 0xdf, 0x18, 0x09, 0xeb, 0xef, 0x04, 0x12, 0xea, + 0xdf, 0xfb, 0xda, 0xf6, 0xdf, 0x17, 0xef, 0xef, 0xe1, 0x1a, 0xd9, 0xe2, + 0xe2, 0xfc, 0x05, 0x11, 0xf6, 0xee, 0xe8, 0xf2, 0xe1, 0x08, 0x26, 0x04, + 0xed, 0x03, 0xe0, 0xfb, 0xee, 0x0c, 0xee, 0xf6, 0x04, 0x2d, 0xf2, 0xd3, + 0xf4, 0xe0, 0xf8, 0x0c, 0xfe, 0x11, 0x0b, 0xd7, 0xfd, 0x18, 0x07, 0x0d, + 0x07, 0x08, 0xf4, 0xc6, 0x0a, 0x0a, 0x1f, 0x0c, 0xf4, 0x1d, 0x02, 0x0b, + 0x09, 0x0e, 0x21, 0xff, 0x17, 0x0b, 0x0d, 0xf2, 0xed, 0xd7, 0x0a, 0xf8, + 0x03, 0x06, 0xfa, 0xe5, 0xfd, 0x03, 0x14, 0x0f, 0xe9, 0x1a, 0xf4, 0xda, + 0x01, 0xe6, 0x09, 0x06, 0x11, 0x0d, 0xfd, 0xeb, 0x16, 0x23, 0xfa, 0x00, + 0x0b, 0x17, 0xf7, 0xda, 0xd7, 0x1b, 0xfa, 0x01, 0x03, 0x05, 0xfe, 0xd6, + 0x02, 0xee, 0xee, 0x02, 0xf3, 0x06, 0xed, 0x03, 0xec, 0x01, 0xf2, 0x0f, + 0x05, 0x17, 0x0b, 0xfb, 0x0f, 0x05, 0x03, 0x13, 0xff, 0x06, 0x02, 0xf5, + 0xf4, 0x18, 0x2b, 0xf0, 0x00, 0x17, 0xfc, 0xfd, 0x05, 0x0b, 0x0e, 0x14, + 0xe1, 0x24, 0x08, 0x24, 0xe6, 0xeb, 0x21, 0x12, 0xfb, 0x12, 0xe7, 0xf4, + 0xe8, 0x0e, 0x18, 0xee, 0xf5, 0xf3, 0xd9, 0xf3, 0xdb, 0xec, 0x0c, 0x1e, + 0xcf, 0x14, 0xdb, 0xe3, 0xdc, 0x02, 0x0c, 0xfb, 0xdb, 0x1b, 0xd0, 0xfe, + 0xf9, 0xfe, 0x2a, 0xf5, 0x00, 0x0b, 0xcd, 0xe0, 0xe2, 0x0e, 0x04, 0xf8, + 0xda, 0x1c, 0xe5, 0x0f, 0xe8, 0xf4, 0xf7, 0x15, 0x06, 0xf8, 0x02, 0xf7, + 0x0f, 0xfb, 0x17, 0xf9, 0xda, 0x01, 0xda, 0xd1, 0xf6, 0x02, 0xfd, 0x16, + 0xf1, 0xe4, 0xfa, 0x07, 0xee, 0x0a, 0xf3, 0xfd, 0xf2, 0x23, 0xf0, 0xe1, + 0x0a, 0x1a, 0x12, 0x1f, 0xef, 0x27, 0x09, 0xf1, 0x0c, 0x13, 0x23, 0xfd, + 0xf5, 0x03, 0xfe, 0x09, 0xfd, 0x16, 0xf8, 0x07, 0x08, 0x25, 0x08, 0xf8, + 0xf6, 0x0a, 0xf1, 0xf5, 0x07, 0x09, 0x05, 0xcc, 0xf8, 0x08, 0x13, 0xf9, + 0x1d, 0x11, 0x0f, 0xdc, 0xee, 0xf3, 0x27, 0xf9, 0xf9, 0x22, 0xfa, 0x0d, + 0xe2, 0x13, 0xfb, 0x11, 0x03, 0x1e, 0xff, 0xfb, 0xed, 0xf1, 0x0e, 0x0b, + 0x0f, 0x00, 0x06, 0xe0, 0x15, 0xf3, 0x13, 0xfc, 0x18, 0xf9, 0xff, 0x09, + 0xfa, 0x1f, 0x12, 0xe5, 0xe2, 0x06, 0xf9, 0xf4, 0x07, 0x15, 0x0b, 0x04, + 0xdb, 0x0d, 0xeb, 0xf3, 0xe6, 0x06, 0xe5, 0xee, 0xd8, 0x22, 0xd8, 0x10, + 0xea, 0xf9, 0x1c, 0xf7, 0xd3, 0x11, 0xc3, 0xf8, 0xde, 0x05, 0x00, 0xe6, + 0x07, 0xfd, 0xd3, 0x03, 0xea, 0xe0, 0x13, 0x14, 0xcf, 0xeb, 0xcd, 0xd3, + 0xde, 0xf5, 0xf0, 0x0c, 0x0c, 0xfa, 0xeb, 0xd3, 0xfb, 0xfd, 0x08, 0xf9, + 0xf4, 0x10, 0xfa, 0xd3, 0xf4, 0x11, 0x11, 0xf8, 0xef, 0xf8, 0xf8, 0xf1, + 0xfc, 0xe1, 0xf7, 0x12, 0x04, 0xf4, 0xfb, 0xed, 0xef, 0x0c, 0xfd, 0x1c, + 0xfe, 0x0e, 0xfd, 0xe2, 0xfe, 0x0a, 0x02, 0xfe, 0xe6, 0x1f, 0xef, 0xe5, + 0xe6, 0xf8, 0x16, 0x27, 0xe8, 0x20, 0x05, 0xe3, 0xf1, 0xef, 0xee, 0xed, + 0x0d, 0x11, 0x16, 0xfb, 0xf3, 0xff, 0x14, 0x01, 0xff, 0x15, 0x10, 0x02, + 0xe5, 0x28, 0x29, 0x13, 0x13, 0x16, 0xe6, 0x00, 0xd2, 0x26, 0xfd, 0x03, + 0x04, 0x05, 0x07, 0x06, 0xf1, 0x0e, 0x05, 0x0d, 0xe2, 0x0f, 0x02, 0xe1, + 0x07, 0xf7, 0x1c, 0xfa, 0x14, 0x30, 0xf7, 0xee, 0x00, 0xfa, 0x3d, 0x06, + 0x1c, 0x04, 0x06, 0x07, 0x05, 0x1a, 0x10, 0xf6, 0xee, 0x0a, 0xeb, 0x04, + 0xeb, 0xdf, 0x1d, 0x09, 0xd5, 0xe8, 0xd6, 0xf4, 0xf0, 0x0f, 0x1d, 0xea, + 0xf2, 0xf8, 0xa6, 0x0b, 0xdc, 0x09, 0x08, 0x24, 0xee, 0x24, 0xaa, 0xe4, + 0xcb, 0x15, 0xef, 0xe7, 0xe9, 0x0c, 0xcf, 0x06, 0xe3, 0x12, 0x11, 0x00, + 0x07, 0x14, 0xd7, 0xde, 0xf6, 0x0f, 0x0b, 0x04, 0xfb, 0x0d, 0xf8, 0x0d, + 0xf6, 0x1b, 0xf1, 0x21, 0xdd, 0xfc, 0xf4, 0xe9, 0xf8, 0xe8, 0xf7, 0x06, + 0x03, 0x1e, 0xce, 0xe1, 0xea, 0xf6, 0x05, 0xf9, 0x16, 0x15, 0x04, 0xe0, + 0x14, 0xf7, 0x1e, 0x1c, 0x0a, 0x27, 0xef, 0xf3, 0x0f, 0xf3, 0xee, 0x04, + 0xf8, 0xf1, 0x07, 0xe3, 0x05, 0x0b, 0x00, 0x1c, 0x15, 0x27, 0x07, 0xf7, + 0xfa, 0x0b, 0xfa, 0xfa, 0x17, 0x13, 0xe1, 0xf5, 0xfb, 0x0c, 0x21, 0x2f, + 0xd7, 0xfb, 0xf5, 0xfd, 0xd3, 0xf4, 0x07, 0x0e, 0xfd, 0x0b, 0xfc, 0xfa, + 0xf5, 0x0e, 0x02, 0xfa, 0xfa, 0x19, 0xfd, 0xfa, 0xfc, 0x13, 0x24, 0x0c, + 0xe4, 0x31, 0xf8, 0x12, 0xf4, 0x04, 0x18, 0x29, 0x27, 0x19, 0xfc, 0x08, + 0x11, 0xe3, 0x07, 0xfe, 0x26, 0x40, 0x05, 0x02, 0x04, 0x02, 0x0f, 0xee, + 0xf4, 0x27, 0xea, 0xf4, 0xf5, 0x11, 0x26, 0x0b, 0xe7, 0x05, 0xd2, 0xf6, + 0xea, 0xfa, 0x0b, 0xf9, 0xfa, 0x16, 0xba, 0x00, 0xfb, 0x0d, 0x0b, 0xf9, + 0xe6, 0xf6, 0xc5, 0xf8, 0xf6, 0x01, 0x0f, 0xed, 0xed, 0x13, 0xcd, 0x0d, + 0xda, 0x06, 0x17, 0xee, 0x07, 0x1d, 0xb8, 0xfa, 0xe2, 0xea, 0xf2, 0xee, + 0x04, 0x00, 0xdc, 0xd0, 0xfb, 0xf5, 0xec, 0xfe, 0xf1, 0x0d, 0xf0, 0xdb, + 0xf9, 0x0d, 0x03, 0x03, 0x0e, 0x0a, 0xda, 0xd6, 0x01, 0xf2, 0x06, 0x14, + 0x1c, 0x1f, 0xe8, 0xe8, 0x0e, 0xfd, 0x0c, 0xf5, 0xf3, 0x3d, 0xf3, 0x05, + 0x10, 0xfa, 0x1b, 0x18, 0x08, 0x36, 0x09, 0xf1, 0xeb, 0xf9, 0x22, 0x01, + 0xf3, 0xf7, 0xff, 0xf0, 0x0c, 0xe9, 0x01, 0x29, 0x21, 0x15, 0x03, 0xee, + 0xe9, 0x1a, 0xf7, 0x15, 0x06, 0x25, 0xfa, 0xf0, 0xe4, 0xf1, 0x1f, 0x01, + 0xdc, 0x2d, 0xce, 0xe9, 0xea, 0x0b, 0x06, 0x2c, 0x0a, 0x30, 0xe7, 0x09, + 0xf4, 0xf0, 0x10, 0x29, 0xf9, 0x3d, 0xe7, 0xdc, 0xe4, 0xf7, 0x3b, 0x27, + 0x23, 0x3a, 0x0a, 0x06, 0x0e, 0xfd, 0x2c, 0x07, 0x2b, 0x1c, 0xfa, 0x00, + 0xf9, 0x11, 0xea, 0x14, 0xeb, 0xfc, 0x18, 0x03, 0xf1, 0x16, 0x12, 0x04, + 0xcf, 0x12, 0xdd, 0xe4, 0x0e, 0xf0, 0x09, 0xe8, 0xf3, 0xfb, 0xa8, 0xf9, + 0xee, 0xfb, 0x1e, 0x1d, 0xfd, 0x05, 0xab, 0xe5, 0xff, 0x01, 0xfe, 0x04, + 0xf9, 0x02, 0xb9, 0xdc, 0xdf, 0x05, 0xf1, 0xef, 0xf1, 0x1e, 0xc7, 0xee, + 0xf7, 0x1e, 0x00, 0x00, 0xf8, 0x10, 0xec, 0xe8, 0x04, 0x0f, 0xf6, 0xff, + 0x04, 0x09, 0xe0, 0x0a, 0x0e, 0xe4, 0xf0, 0xf1, 0x16, 0x2b, 0xd3, 0xe1, + 0x0a, 0xef, 0xf9, 0xfe, 0x0b, 0x22, 0xf5, 0x01, 0x0a, 0xf8, 0x02, 0x00, + 0x17, 0x19, 0xf3, 0x05, 0x21, 0xfa, 0xee, 0xee, 0x12, 0xf2, 0xfa, 0xf5, + 0x05, 0x12, 0xee, 0xe4, 0x28, 0xfa, 0xf1, 0x03, 0x15, 0x16, 0x18, 0xfd, + 0x0f, 0x21, 0x04, 0xf4, 0xe5, 0x0c, 0x06, 0x13, 0xde, 0x36, 0xe8, 0xfb, + 0xe7, 0xfd, 0xf6, 0x12, 0x0e, 0x1d, 0xea, 0xf8, 0xd4, 0xe8, 0x19, 0x07, + 0xe5, 0x1c, 0xf7, 0x0c, 0xef, 0x05, 0x0f, 0x09, 0xdd, 0x1a, 0xea, 0xd7, + 0xf9, 0xf9, 0x12, 0x17, 0x2e, 0x10, 0x08, 0xfe, 0x14, 0xf5, 0x1d, 0xfa, + 0x06, 0x33, 0xed, 0xfe, 0xf7, 0x11, 0xf0, 0x15, 0xe2, 0x24, 0xf6, 0x0a, + 0xe2, 0xfc, 0x23, 0x12, 0xdd, 0x11, 0xfd, 0xe5, 0x08, 0xff, 0x15, 0xf6, + 0xf1, 0x1b, 0xae, 0xfe, 0xe6, 0x15, 0x2c, 0x2d, 0x15, 0x15, 0xc5, 0xf8, + 0xea, 0xe7, 0x07, 0x04, 0xfe, 0x28, 0xa1, 0xf2, 0xe1, 0xf9, 0xf8, 0xff, + 0xf4, 0x22, 0xb4, 0xdb, 0x03, 0x20, 0xe6, 0xf3, 0x0e, 0x19, 0xe3, 0x0a, + 0xfa, 0xee, 0xf3, 0xe5, 0xd8, 0xf9, 0xf1, 0xde, 0x06, 0x05, 0xf2, 0xf5, + 0xe7, 0x16, 0xd8, 0xfe, 0x07, 0xea, 0xee, 0x0e, 0xfa, 0xff, 0xdb, 0xe7, + 0x03, 0xed, 0x01, 0xfd, 0x09, 0x1a, 0xfa, 0xe6, 0x05, 0x10, 0xe9, 0x01, + 0x1f, 0x13, 0xf7, 0xf6, 0xfb, 0x13, 0xff, 0xdb, 0xed, 0xfe, 0x0a, 0x10, + 0x09, 0x29, 0xf5, 0x04, 0xf5, 0x26, 0x0d, 0x0c, 0xf9, 0x16, 0xfa, 0x02, + 0xf4, 0x2e, 0xde, 0xf5, 0xe1, 0x1d, 0xfb, 0x02, 0x0b, 0x23, 0x07, 0xea, + 0xd9, 0x0a, 0xf3, 0x0a, 0x0f, 0x1e, 0xe7, 0xf1, 0xd7, 0x0b, 0xf6, 0xff, + 0x0d, 0x24, 0xcc, 0x0a, 0xee, 0xda, 0x14, 0x12, 0x11, 0x29, 0xf4, 0x1a, + 0xef, 0x0b, 0xfa, 0xec, 0x0c, 0x1b, 0xf4, 0xff, 0xf5, 0xef, 0x0f, 0x10, + 0xd4, 0x04, 0xf9, 0xf8, 0xec, 0xf9, 0x21, 0x05, 0xd3, 0x27, 0xf3, 0x17, + 0xff, 0xf6, 0x15, 0xf9, 0xed, 0x0a, 0xac, 0x02, 0xfd, 0xfb, 0x04, 0x29, + 0x06, 0x03, 0xb8, 0xe6, 0xd5, 0x17, 0x09, 0x1b, 0xf6, 0x1b, 0xab, 0xdc, + 0xdf, 0xfd, 0x06, 0x09, 0x09, 0x37, 0xbb, 0xed, 0x19, 0xd7, 0xe2, 0xdd, + 0x05, 0x01, 0xec, 0xfb, 0xe4, 0x0e, 0xeb, 0xf0, 0x03, 0x17, 0x04, 0xeb, + 0x09, 0xee, 0xeb, 0xe7, 0x0c, 0x16, 0xcb, 0x0e, 0x17, 0xd8, 0xe1, 0xf8, + 0x2b, 0x19, 0xde, 0xeb, 0x10, 0xf2, 0xff, 0xf8, 0xee, 0x0e, 0xe7, 0xf0, + 0x15, 0x08, 0xf8, 0xdf, 0x06, 0x0d, 0xf9, 0x14, 0xfa, 0x0b, 0x04, 0xfd, + 0x15, 0x23, 0x20, 0xff, 0xfd, 0x1d, 0x0c, 0xf1, 0xfe, 0x15, 0x0a, 0x02, + 0xed, 0xfe, 0xfb, 0x04, 0xfb, 0x1e, 0xdd, 0x05, 0xe0, 0x16, 0xf9, 0xf6, + 0xfd, 0x32, 0xdc, 0xf2, 0xd3, 0x08, 0xf4, 0xec, 0x17, 0x25, 0xe2, 0xf0, + 0xee, 0xf1, 0x0d, 0xfe, 0x13, 0x2d, 0x01, 0x11, 0xd4, 0xe4, 0x07, 0xfb, + 0x32, 0x11, 0x14, 0x07, 0xd7, 0x02, 0x10, 0xeb, 0x2b, 0x1d, 0x01, 0xfc, + 0xf3, 0xf0, 0x13, 0x1a, 0xdb, 0x20, 0x00, 0xf0, 0xf0, 0x05, 0x16, 0x03, + 0xd4, 0xe3, 0xc2, 0xf0, 0x06, 0x02, 0x1e, 0x0a, 0xec, 0x1f, 0xab, 0xea, + 0xfa, 0xe3, 0x20, 0x22, 0x03, 0x1b, 0xb3, 0x0e, 0xe3, 0xf3, 0x1d, 0x27, + 0xe3, 0x10, 0xa7, 0xda, 0xf3, 0x00, 0x0a, 0x0a, 0x04, 0xfb, 0xb2, 0x0f, + 0x0c, 0xf5, 0x07, 0xff, 0x13, 0x1e, 0xdb, 0xf6, 0xf9, 0xef, 0xe8, 0xe7, + 0xfb, 0x18, 0xeb, 0xec, 0x09, 0xda, 0xf1, 0xf0, 0x0b, 0x04, 0xe1, 0xfa, + 0x1c, 0x25, 0xee, 0x01, 0x0b, 0x29, 0xd7, 0x0c, 0x04, 0x0b, 0xef, 0xfd, + 0x1c, 0xfc, 0xf1, 0xfb, 0x0b, 0x0f, 0xdf, 0xed, 0x17, 0x38, 0x0c, 0xd7, + 0xff, 0xfd, 0x01, 0xfc, 0xfb, 0xfb, 0x18, 0x1a, 0x18, 0xe3, 0xf9, 0xf4, + 0xfa, 0x20, 0x06, 0x09, 0x11, 0x08, 0x1d, 0xf8, 0xfa, 0x1d, 0xf5, 0x1c, + 0xf5, 0xfe, 0x03, 0x07, 0xe4, 0x33, 0xc8, 0x0c, 0xe1, 0x13, 0xff, 0xe5, + 0x10, 0x2c, 0xd3, 0xf0, 0xed, 0x04, 0x07, 0x01, 0xf1, 0x16, 0xe0, 0x13, + 0xfa, 0x11, 0x07, 0xfa, 0x19, 0x16, 0x01, 0x00, 0x07, 0x26, 0x00, 0xec, + 0x1d, 0x23, 0x05, 0xf4, 0x07, 0x17, 0x2c, 0x1d, 0xee, 0xf0, 0x0c, 0x09, + 0xe3, 0x1a, 0x24, 0x0b, 0xf3, 0x1e, 0xce, 0xfe, 0xfe, 0x12, 0x21, 0x1a, + 0xf6, 0x23, 0xc3, 0x03, 0xf4, 0x10, 0x1a, 0x2a, 0xf4, 0x08, 0xbf, 0xff, + 0x04, 0xf4, 0x0b, 0x1d, 0x1a, 0xf8, 0xcc, 0x00, 0xf7, 0x13, 0xf4, 0xfd, + 0xf4, 0x19, 0xbd, 0xef, 0x0c, 0x0d, 0x02, 0xfc, 0x12, 0x13, 0xe9, 0xe7, + 0xf5, 0xfa, 0xfa, 0xf6, 0x1a, 0x2e, 0xce, 0xd4, 0x01, 0x12, 0xfd, 0xfc, + 0x26, 0x10, 0xcc, 0xe7, 0xee, 0x13, 0xee, 0xff, 0xef, 0xea, 0x00, 0x0e, + 0x1a, 0x17, 0x04, 0x0c, 0x04, 0x0c, 0xe6, 0xf3, 0xf6, 0xdb, 0xdd, 0x04, + 0xf4, 0x22, 0x11, 0x16, 0xf3, 0x07, 0xec, 0xf8, 0xf2, 0x07, 0x03, 0x02, + 0xf5, 0x0a, 0xf6, 0x02, 0x1d, 0x1b, 0x11, 0x06, 0xf8, 0x06, 0x02, 0xea, + 0xf3, 0x1d, 0xce, 0x00, 0xed, 0xf9, 0xef, 0xf6, 0xec, 0x22, 0xc7, 0xf0, + 0xed, 0xdb, 0xe0, 0x02, 0x11, 0x07, 0xe8, 0xf0, 0xd1, 0xed, 0xff, 0xfd, + 0x0c, 0x2e, 0xd4, 0xed, 0xec, 0x0e, 0xf1, 0x07, 0x01, 0x0e, 0x0e, 0xfe, + 0xda, 0x0b, 0x0a, 0x0a, 0x1f, 0x2e, 0x13, 0x07, 0x00, 0x07, 0x14, 0x21, + 0xe9, 0xfc, 0xf0, 0x1e, 0xd7, 0xea, 0x34, 0x07, 0xc6, 0x0c, 0xd4, 0xec, + 0xfd, 0x06, 0x24, 0x0a, 0xf3, 0x15, 0xaf, 0xff, 0xe9, 0xf1, 0x0d, 0x3e, + 0xe9, 0x18, 0xba, 0x13, 0xed, 0xd7, 0x0b, 0x31, 0x05, 0x0e, 0xaf, 0x13, + 0xd6, 0x0e, 0x10, 0x02, 0x02, 0x14, 0xcb, 0xd5, 0xf9, 0x0c, 0xf9, 0x0e, + 0x1f, 0x24, 0xd5, 0xeb, 0xff, 0xf1, 0xf5, 0x0c, 0x08, 0x07, 0xf4, 0xd7, + 0x06, 0x10, 0xe8, 0xef, 0xfc, 0x2f, 0xee, 0xf1, 0x18, 0xf8, 0xf4, 0x02, + 0x11, 0x21, 0xd3, 0x12, 0x14, 0xe4, 0xf4, 0x02, 0x05, 0x24, 0xca, 0xf2, + 0xf3, 0xeb, 0xe7, 0xf8, 0x16, 0x1a, 0xeb, 0x0d, 0x05, 0x16, 0xf1, 0xec, + 0x11, 0x1c, 0x09, 0x1e, 0xe0, 0xe6, 0xfa, 0x0e, 0x0d, 0x2a, 0xea, 0x2e, + 0xed, 0xf9, 0xf7, 0x16, 0x09, 0x05, 0xdd, 0xd6, 0x02, 0xeb, 0xf5, 0xf3, + 0xe4, 0x3b, 0xed, 0x04, 0xe0, 0x0e, 0xfd, 0x09, 0xfd, 0x35, 0xdc, 0x18, + 0xf3, 0x04, 0xfa, 0x05, 0x15, 0x34, 0xe5, 0xe1, 0xe4, 0xf4, 0xe0, 0xf9, + 0x08, 0x32, 0x04, 0x08, 0xf4, 0x0f, 0xff, 0x08, 0x09, 0x2f, 0x06, 0x02, + 0xfd, 0x05, 0x0c, 0x24, 0xe3, 0x1e, 0xf5, 0x0c, 0xdd, 0xf8, 0x18, 0x20, + 0xd8, 0x14, 0xef, 0xf4, 0x17, 0x08, 0x25, 0x14, 0x04, 0x06, 0xb0, 0xf5, + 0xf5, 0x09, 0x0f, 0x3e, 0xff, 0x28, 0xb3, 0xf5, 0x19, 0xd8, 0x14, 0x21, + 0xd9, 0xf7, 0xb7, 0xe5, 0xfe, 0xe7, 0x07, 0x1e, 0x04, 0x15, 0xc5, 0xf9, + 0x14, 0x20, 0xeb, 0x01, 0x01, 0x18, 0xce, 0x00, 0xe6, 0xe2, 0xf7, 0xfb, + 0xf3, 0x0d, 0xd3, 0xf3, 0x04, 0xf8, 0xf0, 0x03, 0xf1, 0x25, 0xb5, 0xef, + 0x05, 0xe0, 0x01, 0xf6, 0x04, 0x16, 0xd1, 0x01, 0x0a, 0x21, 0x01, 0x05, + 0x0e, 0x01, 0xf0, 0x0a, 0xf3, 0x00, 0x03, 0xf8, 0xfa, 0x03, 0x0b, 0xde, + 0xfe, 0xff, 0xfb, 0xea, 0x09, 0x02, 0xf5, 0xe8, 0xe7, 0x08, 0x00, 0xf5, + 0xf8, 0x0f, 0x13, 0xfa, 0xeb, 0xe8, 0xfb, 0x1f, 0x08, 0x16, 0xe6, 0xfa, + 0xe1, 0x00, 0x03, 0xdd, 0xf1, 0x26, 0xe5, 0x1d, 0xd9, 0xff, 0xf2, 0xf8, + 0xff, 0x33, 0xea, 0xe5, 0x03, 0x0c, 0x07, 0xf9, 0xf8, 0x0f, 0xe1, 0x1e, + 0xdd, 0x0f, 0x00, 0xf1, 0x06, 0x21, 0x09, 0x05, 0xf3, 0xec, 0xe6, 0x04, + 0x07, 0x32, 0xf1, 0xf9, 0xf2, 0x01, 0x18, 0x1f, 0xd2, 0xe2, 0x0a, 0xf4, + 0xca, 0xfc, 0x28, 0x16, 0xc2, 0x10, 0xf2, 0xfc, 0x08, 0xe9, 0x2a, 0x0f, + 0xfa, 0xf5, 0xa9, 0x07, 0xec, 0xe9, 0x19, 0x43, 0x0b, 0x1c, 0xa6, 0xe9, + 0xf4, 0x16, 0x0d, 0x2b, 0xfc, 0x11, 0x9a, 0xe1, 0xf1, 0x1c, 0xf5, 0x0f, + 0xe4, 0x18, 0xc0, 0xd9, 0x14, 0x26, 0xe6, 0xf8, 0x0a, 0x17, 0xec, 0xfb, + 0xe1, 0x22, 0xdf, 0xf2, 0xfe, 0x1e, 0xd4, 0xeb, 0xd7, 0x0e, 0x08, 0xf6, + 0xef, 0xfc, 0xe6, 0xd4, 0xf7, 0x0b, 0xfb, 0xf5, 0x01, 0x25, 0xd7, 0xfb, + 0x0d, 0xfe, 0xff, 0xf3, 0x1d, 0x32, 0xfe, 0xee, 0x12, 0xf2, 0x0c, 0xec, + 0x02, 0x10, 0xef, 0x01, 0xf2, 0x0b, 0xf3, 0xf7, 0xfa, 0x25, 0xfb, 0x0d, + 0x11, 0x15, 0x04, 0xfc, 0x0c, 0x21, 0x12, 0x29, 0x00, 0xfa, 0xf6, 0xf5, + 0x06, 0x22, 0xea, 0xe2, 0xee, 0x00, 0xfd, 0xf0, 0x0b, 0x1d, 0xd3, 0xe4, + 0xe4, 0x0a, 0xfc, 0xe8, 0xea, 0x2c, 0xed, 0xed, 0xef, 0xe8, 0xf2, 0x05, + 0xfd, 0x15, 0xd8, 0xda, 0xca, 0xee, 0xfa, 0x00, 0xfe, 0x0e, 0xf2, 0xf0, + 0x0e, 0xf5, 0x04, 0x03, 0x1d, 0x2b, 0xee, 0x05, 0x0f, 0x10, 0x13, 0x35, + 0xe2, 0x04, 0x10, 0xdf, 0xcf, 0xeb, 0x40, 0x26, 0xe4, 0x03, 0xf3, 0xf9, + 0xf5, 0x14, 0x24, 0x2a, 0xdf, 0xfe, 0xab, 0xe5, 0xfe, 0x1c, 0x27, 0x35, + 0xdb, 0xff, 0xac, 0x01, 0xf6, 0xfc, 0x19, 0x1a, 0x11, 0x1f, 0xa8, 0xf5, + 0x02, 0x0f, 0x1a, 0x1f, 0xf7, 0xf2, 0xa2, 0x00, 0x15, 0x22, 0xe4, 0x13, + 0x00, 0x09, 0xd9, 0xd5, 0x02, 0x19, 0xfd, 0xf8, 0xe7, 0xff, 0xfb, 0xe0, + 0xef, 0xf7, 0xee, 0xf3, 0xf3, 0x19, 0xb0, 0xdf, 0x00, 0x0f, 0x08, 0xf3, + 0x15, 0x17, 0xec, 0x0f, 0x11, 0x14, 0x02, 0x08, 0x10, 0x17, 0xe6, 0x08, + 0xf7, 0x00, 0xed, 0xf7, 0x29, 0x07, 0x10, 0x05, 0x05, 0xe7, 0xed, 0xf4, + 0xf9, 0x15, 0xf9, 0xf0, 0x08, 0x00, 0x03, 0x09, 0x21, 0x28, 0xf6, 0x0e, + 0xfb, 0xf3, 0x03, 0xf7, 0x0f, 0x0c, 0xf0, 0xf5, 0xe3, 0xd8, 0xf8, 0xf2, + 0x09, 0x1c, 0xe7, 0xfb, 0xe4, 0xf6, 0xfa, 0xf8, 0xf1, 0x42, 0xf6, 0xda, + 0xdd, 0xd7, 0xfa, 0xff, 0x2f, 0x2c, 0xda, 0x0a, 0xde, 0xec, 0xf1, 0x14, + 0xfb, 0x1d, 0xeb, 0xee, 0xf2, 0xeb, 0xf3, 0xed, 0x0e, 0x35, 0xf0, 0x06, + 0x19, 0x04, 0x2f, 0x23, 0xe2, 0x07, 0x13, 0x0f, 0xe9, 0xf0, 0x22, 0x2e, + 0xd9, 0x1a, 0xcb, 0xed, 0xfd, 0x04, 0x27, 0x1e, 0xf6, 0x07, 0x96, 0xd6, + 0xd8, 0x11, 0x18, 0x56, 0xd2, 0xfb, 0x92, 0xfc, 0x0b, 0x0a, 0x17, 0x2c, + 0xe5, 0x04, 0xa2, 0xf8, 0xe2, 0x04, 0x1a, 0x0d, 0xeb, 0x11, 0xa2, 0xe5, + 0xe5, 0xf8, 0x02, 0xf7, 0x17, 0x03, 0xca, 0xe9, 0x0c, 0x1f, 0xfe, 0xf5, + 0x18, 0x12, 0xdd, 0x08, 0x15, 0xff, 0xfc, 0xf6, 0xe1, 0x1d, 0xe2, 0xe1, + 0xfe, 0xfc, 0x03, 0xff, 0xf2, 0x23, 0xd2, 0x01, 0x13, 0xdd, 0xf3, 0xf4, + 0xf2, 0x07, 0xef, 0x03, 0x15, 0x21, 0xd8, 0xf8, 0x09, 0xf3, 0xe8, 0xea, + 0xe8, 0xf2, 0x08, 0xf0, 0x04, 0x1a, 0xf2, 0x19, 0xfb, 0x1b, 0x15, 0xfc, + 0x1d, 0x30, 0xe5, 0x1e, 0x09, 0xe8, 0xe9, 0x09, 0xf7, 0x2a, 0xe1, 0x0e, + 0x00, 0x21, 0xf3, 0xff, 0xfb, 0x01, 0xdf, 0xf2, 0xfe, 0xf4, 0xfc, 0xf0, + 0x0b, 0x0b, 0xdd, 0xe4, 0xd2, 0x14, 0xf7, 0xfe, 0x0b, 0x39, 0x01, 0xe6, + 0xe4, 0x27, 0xfa, 0xe4, 0x04, 0x2c, 0xe2, 0x04, 0xf5, 0x07, 0xf2, 0x03, + 0xf0, 0x10, 0xf5, 0xf6, 0xfc, 0x16, 0x22, 0x1b, 0xf8, 0x11, 0xe4, 0x09, + 0xf6, 0xf0, 0x41, 0x1e, 0xcf, 0x04, 0xea, 0xee, 0x0e, 0xf6, 0x1b, 0x2f, + 0xc7, 0xf1, 0xba, 0xef, 0x0f, 0x16, 0x1e, 0x39, 0x05, 0x1e, 0x90, 0xe6, + 0x0d, 0xfa, 0x22, 0x3f, 0xe3, 0x23, 0xa5, 0xe3, 0xe9, 0x0f, 0x05, 0x27, + 0x02, 0x11, 0x99, 0x05, 0xfa, 0x05, 0x03, 0x01, 0xff, 0x26, 0xd3, 0xf7, + 0xf7, 0xf9, 0x05, 0xf4, 0xef, 0x23, 0xd2, 0xdd, 0x05, 0x08, 0xfa, 0xff, + 0x03, 0x04, 0xbd, 0xd7, 0x14, 0x06, 0xef, 0x06, 0xe5, 0x05, 0xea, 0xea, + 0x02, 0xfd, 0x0d, 0x00, 0x08, 0xff, 0xe7, 0xfb, 0xfe, 0x13, 0xfe, 0xec, + 0xf9, 0x02, 0xf3, 0xff, 0xff, 0x08, 0x04, 0xed, 0x19, 0x1d, 0xfa, 0x0a, + 0x0d, 0xf2, 0x0f, 0xec, 0x25, 0x1c, 0xec, 0x0b, 0x01, 0xff, 0x01, 0xf6, + 0x08, 0x09, 0xe8, 0xe2, 0xec, 0x23, 0xe5, 0xe9, 0xf0, 0x2e, 0xbd, 0xe1, + 0xef, 0x14, 0xe9, 0xf6, 0xf5, 0x1d, 0xdc, 0xe3, 0xd7, 0xfc, 0xf9, 0xf2, + 0xfe, 0x24, 0xf2, 0x05, 0xd5, 0xed, 0xe9, 0xf9, 0xfa, 0x2d, 0xf0, 0xfe, + 0xee, 0xf2, 0xe8, 0xf7, 0x06, 0x14, 0x01, 0x10, 0x06, 0xf3, 0x0e, 0x0e, + 0xc2, 0x1d, 0xf2, 0x1c, 0xed, 0xe3, 0x53, 0x21, 0xb8, 0x0c, 0xde, 0x03, + 0x15, 0xeb, 0x46, 0x39, 0xdf, 0xf6, 0xa3, 0xee, 0xf6, 0xe0, 0x33, 0x50, + 0xdd, 0x27, 0x9f, 0x07, 0x13, 0xe2, 0x1f, 0x35, 0xed, 0x1f, 0xb7, 0x07, + 0x11, 0xed, 0x17, 0x28, 0xf4, 0x20, 0xc1, 0xec, 0xef, 0x16, 0x02, 0xfa, + 0xe0, 0x1b, 0xf7, 0xdb, 0xfd, 0x0a, 0xe7, 0xfb, 0xe7, 0x25, 0xe2, 0xe7, + 0xf8, 0xf0, 0xee, 0xe9, 0x02, 0x06, 0xc9, 0xe4, 0x14, 0xe3, 0xe2, 0xf7, + 0xf8, 0xfd, 0xdd, 0xe2, 0x08, 0x0a, 0xe4, 0x05, 0xf5, 0x16, 0xe7, 0x01, + 0x00, 0x1c, 0xe7, 0xf0, 0xf6, 0x19, 0xfe, 0x0c, 0xf2, 0x06, 0x03, 0xe8, + 0x0b, 0xfe, 0xe3, 0x19, 0x08, 0x1a, 0x10, 0xfd, 0x00, 0x21, 0xf0, 0xeb, + 0x18, 0x02, 0xf3, 0x04, 0xf0, 0x18, 0xdb, 0x05, 0x01, 0xde, 0xed, 0xe9, + 0x23, 0x15, 0xaf, 0xe6, 0xf1, 0x0a, 0xe6, 0xea, 0x01, 0x18, 0xd8, 0xfd, + 0xf1, 0xe6, 0xec, 0xf5, 0x0e, 0x1e, 0xcc, 0xfc, 0xe7, 0x00, 0xe9, 0x11, + 0x00, 0x30, 0xf9, 0x14, 0xf4, 0x19, 0xdd, 0xf7, 0xf7, 0x2f, 0xf4, 0xf2, + 0xff, 0x27, 0x15, 0x1c, 0xbc, 0x2f, 0xe9, 0x14, 0xf5, 0xe8, 0x44, 0x30, + 0xe8, 0x1d, 0xe4, 0x18, 0x11, 0x00, 0x0c, 0x2b, 0xf3, 0x29, 0x96, 0xe0, + 0x06, 0xee, 0x3e, 0x55, 0xdc, 0x13, 0x98, 0xdf, 0xf0, 0xfe, 0x17, 0x33, + 0xe8, 0x09, 0xa3, 0x07, 0xef, 0x0e, 0x1d, 0x37, 0xdd, 0xfe, 0xb5, 0x00, + 0xf7, 0xe0, 0xea, 0xfd, 0xfd, 0x19, 0xbc, 0xfd, 0x15, 0xfe, 0x01, 0xf3, + 0xd5, 0x20, 0xbf, 0xe3, 0x15, 0x0e, 0xf0, 0xf6, 0xf2, 0x14, 0xcc, 0xf0, + 0xf7, 0x04, 0xf2, 0xff, 0x0b, 0x02, 0xd2, 0xd8, 0xfa, 0xfc, 0xe5, 0x02, + 0x00, 0xfb, 0xf0, 0xdc, 0x1e, 0x10, 0x02, 0x01, 0x00, 0x18, 0xe9, 0xdb, + 0x1e, 0xf6, 0xfc, 0x03, 0xef, 0x0a, 0x00, 0x16, 0x00, 0x0f, 0xf4, 0x16, + 0xfa, 0x0b, 0xe2, 0xfa, 0xe0, 0x07, 0xfb, 0x02, 0x21, 0x0e, 0xdd, 0x0b, + 0xea, 0xf0, 0xeb, 0xfb, 0x19, 0x09, 0xd4, 0xf2, 0xef, 0x0b, 0x00, 0xeb, + 0x1a, 0x2f, 0xea, 0x06, 0x03, 0xf6, 0xf8, 0xfb, 0xfe, 0x1d, 0xea, 0xdd, + 0xed, 0xfd, 0xfb, 0xe7, 0xfe, 0x18, 0xf4, 0xfc, 0x0b, 0xf6, 0xfc, 0x0b, + 0xfb, 0x28, 0x07, 0xff, 0x07, 0x1e, 0x03, 0x21, 0xcf, 0x22, 0x05, 0xe6, + 0xea, 0xe7, 0x43, 0x2e, 0xe7, 0x14, 0xfb, 0x0a, 0x1e, 0xfe, 0x2c, 0x24, + 0xd5, 0xfd, 0x9e, 0xd1, 0xf2, 0x1c, 0x32, 0x51, 0x01, 0xf3, 0xac, 0xe1, + 0xf4, 0xe5, 0x1c, 0x37, 0xf1, 0x0f, 0xa7, 0xdb, 0x00, 0xf6, 0x0f, 0x18, + 0xe1, 0x10, 0xc9, 0xc5, 0xe8, 0xeb, 0xf2, 0xfd, 0xf6, 0x02, 0xc2, 0xff, + 0x00, 0x19, 0x03, 0x0f, 0x02, 0x22, 0xd4, 0xe7, 0x07, 0x0f, 0xe5, 0x1a, + 0x09, 0x0b, 0xdc, 0xd2, 0x00, 0x05, 0xee, 0xf8, 0xdc, 0x14, 0xd0, 0x0a, + 0x0a, 0xfa, 0xeb, 0x04, 0xf3, 0x06, 0xde, 0x05, 0xfb, 0xfd, 0xe3, 0xec, + 0xfd, 0x14, 0xd7, 0x11, 0x0e, 0xe6, 0x06, 0xec, 0xde, 0x22, 0xd7, 0x00, + 0x03, 0xf5, 0xf5, 0x0d, 0x01, 0x05, 0xea, 0x0b, 0x16, 0x04, 0xff, 0x13, + 0xf3, 0x12, 0xd2, 0xdf, 0x0b, 0xe4, 0x06, 0xf6, 0x08, 0x2d, 0xd3, 0xd6, + 0xe7, 0x0a, 0xec, 0xff, 0xfe, 0x01, 0xdf, 0xf4, 0xdf, 0x1c, 0xfe, 0xf9, + 0xf7, 0x13, 0xca, 0xff, 0x03, 0x06, 0xe9, 0xf7, 0x06, 0x08, 0xd7, 0xf3, + 0xed, 0x08, 0xe3, 0xfd, 0x0c, 0x11, 0x15, 0xfb, 0x15, 0x08, 0x28, 0x40, + 0xe7, 0x0d, 0x08, 0xec, 0xe8, 0x16, 0x67, 0x46, 0xc8, 0x16, 0xf1, 0x02, + 0x24, 0x00, 0x3a, 0x43, 0xd6, 0x12, 0xae, 0xe7, 0xf4, 0xf8, 0x3a, 0x65, + 0xe4, 0x0c, 0xb2, 0xef, 0x1f, 0xe8, 0x29, 0x59, 0xf8, 0x11, 0xc4, 0xe1, + 0xfe, 0xfa, 0x27, 0x43, 0xc9, 0x1e, 0xbb, 0xfb, 0xf3, 0x13, 0x15, 0x0d, + 0xf1, 0x13, 0xcd, 0xf0, 0x07, 0x19, 0x07, 0x00, 0xd8, 0xeb, 0xbf, 0xf0, + 0xfc, 0xf6, 0xef, 0x16, 0x01, 0x02, 0xc1, 0xdf, 0xfd, 0xe9, 0x06, 0x06, + 0xf1, 0x08, 0xd7, 0xcc, 0xfb, 0x0e, 0xfc, 0x14, 0xf2, 0x1a, 0xe2, 0x0d, + 0xeb, 0x09, 0x07, 0x10, 0xe6, 0x13, 0xeb, 0xf5, 0x15, 0x14, 0xeb, 0xfe, + 0xf9, 0x17, 0xd2, 0xe3, 0x1e, 0xf5, 0x04, 0x0a, 0xf1, 0x0e, 0xde, 0xe7, + 0x01, 0x20, 0x0c, 0xfc, 0xdc, 0xf9, 0xe5, 0xe9, 0xff, 0x1d, 0x0a, 0xfe, + 0xec, 0x25, 0xaf, 0xd2, 0x01, 0x16, 0xfc, 0x17, 0xe8, 0x1e, 0xcd, 0xd9, + 0xe2, 0xf1, 0xeb, 0x08, 0xff, 0x33, 0xe5, 0xfb, 0xeb, 0x04, 0xfe, 0xf7, + 0xfd, 0x1f, 0xee, 0xff, 0xed, 0xf8, 0xe0, 0xff, 0xfd, 0x2b, 0x0a, 0xf5, + 0x15, 0x1d, 0xf3, 0x3f, 0x16, 0xf6, 0xf2, 0xee, 0xf4, 0xef, 0xf0, 0x56, + 0x0a, 0x1a, 0xbc, 0xfc, 0x2f, 0xfb, 0xf0, 0x56, 0x1e, 0x0e, 0xc6, 0xe8, + 0x06, 0x0b, 0x11, 0x62, 0x3e, 0xf9, 0xb8, 0xc9, 0xed, 0xeb, 0x02, 0x63, + 0x2c, 0xfd, 0xc5, 0xe9, 0x00, 0x17, 0x0f, 0x37, 0xfe, 0x20, 0xcc, 0xe0, + 0xe0, 0x0e, 0xe6, 0x20, 0x0a, 0xfd, 0xdf, 0xee, 0x0b, 0x02, 0xee, 0x1f, + 0xfb, 0x06, 0xd2, 0xed, 0xfe, 0xeb, 0xfc, 0x12, 0xfd, 0x14, 0x00, 0xd8, + 0x08, 0xf6, 0xec, 0x17, 0xf9, 0x10, 0x00, 0xd9, 0x18, 0xf1, 0xee, 0x0f, + 0xf4, 0x03, 0xee, 0xeb, 0xf0, 0xef, 0xf2, 0x06, 0x04, 0x00, 0xf4, 0x0f, + 0x09, 0x06, 0xf7, 0x0b, 0xfd, 0x01, 0x03, 0x03, 0xf4, 0xf6, 0xdd, 0x14, + 0x1c, 0xef, 0xf1, 0xdd, 0xf7, 0x13, 0xd9, 0x15, 0xef, 0x02, 0xd2, 0xe7, + 0x05, 0x05, 0xe2, 0x09, 0xf2, 0x11, 0xf5, 0xba, 0xf0, 0x04, 0xe0, 0x01, + 0x06, 0x10, 0xe6, 0xef, 0xfc, 0x12, 0xf9, 0xf4, 0x1b, 0x2f, 0xe3, 0x0f, + 0xd7, 0xf6, 0x0b, 0x11, 0xf7, 0x0c, 0x00, 0x06, 0x18, 0xef, 0x06, 0x03, + 0x0a, 0x09, 0xf6, 0x1a, 0x0d, 0xed, 0xfe, 0x2c, 0x43, 0xf4, 0xe5, 0xde, + 0xf5, 0x02, 0x25, 0x5a, 0x49, 0xd4, 0xe6, 0x24, 0x1e, 0xf7, 0x0e, 0x5c, + 0x5d, 0xf0, 0xf9, 0xe4, 0x1c, 0xeb, 0x28, 0x7f, 0x5b, 0xec, 0xfa, 0xdb, + 0x0c, 0xf5, 0x20, 0x49, 0x51, 0xe1, 0xed, 0xe6, 0x0e, 0x26, 0x28, 0x33, + 0x35, 0x05, 0xe1, 0xe4, 0x1f, 0xfc, 0xf9, 0x39, 0x18, 0x04, 0xed, 0xed, + 0x01, 0xe7, 0xe6, 0x08, 0x09, 0x03, 0xe7, 0xf9, 0x0e, 0x06, 0xec, 0x08, + 0x12, 0x1a, 0xda, 0xef, 0xdf, 0xf9, 0xe2, 0x1e, 0x1c, 0x00, 0x12, 0xd7, + 0x01, 0xf7, 0x21, 0x17, 0x13, 0x19, 0xde, 0xe0, 0xec, 0x16, 0x01, 0x1b, + 0x06, 0x0c, 0xf0, 0xe8, 0x18, 0x03, 0x06, 0x0e, 0x09, 0xfa, 0x03, 0xf3, + 0xdd, 0x01, 0xfb, 0x0a, 0x2a, 0xf4, 0xf6, 0xda, 0xe9, 0xfe, 0xe9, 0x12, + 0x19, 0xe9, 0x05, 0xdf, 0x00, 0xeb, 0xf2, 0x10, 0x0c, 0xe1, 0xcd, 0xcb, + 0xf2, 0x1f, 0xd9, 0x0c, 0xfa, 0xfb, 0xe8, 0xde, 0x00, 0xfc, 0xe5, 0x00, + 0x11, 0x02, 0xe6, 0x17, 0x14, 0x00, 0xf2, 0xfd, 0x00, 0xe1, 0x10, 0x24, + 0x12, 0xec, 0xed, 0x1e, 0x09, 0x18, 0x03, 0x0c, 0x04, 0xf4, 0x15, 0x0f, + 0x10, 0x18, 0xd6, 0x29, 0x10, 0x04, 0x1c, 0xef, 0x0f, 0x0c, 0xc7, 0x04, + 0xfe, 0xeb, 0xff, 0xf5, 0xe3, 0x15, 0xfe, 0xcb, 0x10, 0xff, 0x12, 0xfb, + 0xe4, 0xeb, 0xf9, 0x00, 0x02, 0xf1, 0x14, 0x13, 0x01, 0x02, 0xf9, 0x01, + 0x06, 0x0c, 0xf5, 0x0a, 0x1e, 0x01, 0x19, 0x0e, 0x05, 0xf5, 0x0a, 0xff, + 0xff, 0xf2, 0xfb, 0xdb, 0xf8, 0x06, 0x17, 0xf2, 0xf7, 0x0d, 0x0e, 0xf4, + 0xfa, 0xf7, 0x14, 0xdb, 0xe0, 0xfd, 0x08, 0x16, 0xf7, 0x16, 0xfc, 0x09, + 0x27, 0x07, 0x09, 0xfb, 0x0a, 0xfc, 0x0c, 0xe4, 0xdb, 0xee, 0xff, 0x10, + 0xf3, 0x09, 0xfa, 0xf4, 0x23, 0xf3, 0xf4, 0x19, 0xff, 0xfa, 0xff, 0x19, + 0x0f, 0x11, 0xed, 0xec, 0xf8, 0x0f, 0x10, 0xf3, 0xff, 0x0b, 0xf7, 0x06, + 0x0b, 0x0e, 0x07, 0xe4, 0x18, 0x0a, 0x08, 0x0e, 0x02, 0x0a, 0x05, 0x19, + 0x02, 0xf3, 0xfe, 0xfe, 0x0b, 0x0f, 0xfc, 0xfa, 0x05, 0xf9, 0xe2, 0xf9, + 0x1b, 0xf7, 0x0f, 0x07, 0xfc, 0x12, 0xfe, 0x01, 0xfd, 0xf0, 0x04, 0xf4, + 0xfd, 0x07, 0xf2, 0x04, 0x04, 0x07, 0xef, 0x0c, 0xed, 0x0e, 0xf6, 0xef, + 0x08, 0x07, 0x04, 0xe9, 0xf3, 0x20, 0xda, 0x15, 0xf8, 0xff, 0xec, 0xe0, + 0xf6, 0xff, 0xe9, 0x08, 0x01, 0x10, 0xf0, 0xfc, 0xe9, 0x08, 0xe8, 0xf5, + 0xf8, 0xe5, 0x17, 0xe6, 0x03, 0xfc, 0x09, 0xf5, 0xdd, 0xf2, 0xff, 0x05, + 0xf6, 0xf8, 0xf5, 0x07, 0xfc, 0xf1, 0x04, 0xf3, 0x13, 0xe1, 0x0f, 0xf2, + 0x0a, 0xf9, 0xfd, 0x1c, 0xe0, 0x11, 0x1b, 0xe6, 0xef, 0x05, 0x05, 0x0c, + 0x23, 0x10, 0x09, 0xfe, 0xf7, 0x1a, 0xf1, 0xfc, 0x11, 0x1d, 0xff, 0x03, + 0x03, 0xe6, 0x07, 0x11, 0x0c, 0x0d, 0x16, 0x05, 0x05, 0x25, 0xf3, 0x10, + 0x10, 0x06, 0x09, 0xe8, 0x1a, 0xf0, 0xee, 0x09, 0xff, 0x24, 0xf7, 0xfb, + 0xe6, 0x06, 0xfa, 0x08, 0x03, 0x00, 0xf2, 0x04, 0xf0, 0xeb, 0x14, 0x1c, + 0x03, 0x21, 0x14, 0x1d, 0xfe, 0x03, 0xf6, 0x02, 0x09, 0xff, 0x00, 0x13, + 0xef, 0x10, 0x1e, 0x0b, 0x1d, 0x1c, 0xf1, 0xf6, 0xe7, 0xfd, 0x14, 0x01, + 0xff, 0x13, 0xf7, 0xfc, 0x00, 0x21, 0xe3, 0xeb, 0x07, 0x0e, 0x09, 0xf1, + 0xf8, 0xfd, 0x03, 0xee, 0x19, 0xfd, 0xff, 0xfb, 0xff, 0xea, 0xfb, 0x07, + 0xf0, 0x0a, 0x04, 0x04, 0x0b, 0x12, 0xfe, 0x0b, 0xe0, 0xff, 0xf6, 0xe5, + 0xfc, 0x11, 0xed, 0xfd, 0x15, 0x03, 0xdd, 0xdb, 0x04, 0xfe, 0xff, 0x0e, + 0xff, 0xfa, 0xfb, 0xe5, 0xef, 0xf6, 0xfe, 0x22, 0x0f, 0xe8, 0xfe, 0xf4, + 0xfd, 0xd9, 0x03, 0x0a, 0xdf, 0xcf, 0xf1, 0x14, 0x05, 0xfd, 0xfb, 0xf3, + 0xfb, 0xfb, 0x0f, 0xf8, 0x05, 0x09, 0x03, 0xf7, 0x05, 0x05, 0x13, 0xfb, + 0xeb, 0x23, 0xe7, 0x18, 0xfb, 0x00, 0xfe, 0xdd, 0xe9, 0xea, 0xd3, 0xe8, + 0x1a, 0xef, 0x01, 0xf1, 0x09, 0x1d, 0xd8, 0xfc, 0xda, 0x19, 0x03, 0xec, + 0xe5, 0xf3, 0xed, 0x0a, 0xf4, 0x13, 0x0b, 0xf7, 0x0c, 0x00, 0xf9, 0xea, + 0xe3, 0xfe, 0xff, 0x0d, 0x0a, 0x1b, 0xd7, 0x17, 0xeb, 0xe9, 0x00, 0x0e, + 0xee, 0x24, 0xef, 0x09, 0x07, 0xf0, 0xf5, 0x07, 0xf5, 0xf5, 0x10, 0x17, + 0x06, 0xf7, 0xfc, 0x02, 0xfb, 0xf9, 0xe7, 0x0a, 0x26, 0xf3, 0x01, 0x01, + 0x09, 0x0b, 0x02, 0x27, 0xf8, 0xee, 0xfd, 0x1c, 0xf8, 0xf2, 0x0f, 0xfc, + 0x0d, 0xe0, 0xea, 0x02, 0x0b, 0x00, 0xe0, 0x08, 0xfe, 0x10, 0x04, 0xfe, + 0xeb, 0x13, 0x01, 0x0c, 0x0e, 0xed, 0x09, 0x01, 0x0c, 0xe3, 0x10, 0xdf, + 0xd1, 0x14, 0xf3, 0xef, 0x09, 0xf0, 0xee, 0xe5, 0x11, 0xf4, 0xf6, 0x00, + 0xe8, 0x20, 0x0a, 0xfc, 0xea, 0xf7, 0x02, 0x16, 0xe7, 0xf3, 0x0d, 0xe4, + 0x04, 0xe6, 0xef, 0xf8, 0x0f, 0x23, 0x02, 0xe0, 0x01, 0x01, 0x01, 0x05, + 0xf5, 0x0d, 0xf5, 0xf5, 0xe1, 0xff, 0x04, 0x00, 0xf4, 0x0d, 0xee, 0xf1, + 0xef, 0xf7, 0x0b, 0xff, 0x1b, 0xec, 0x05, 0xe7, 0xf3, 0x13, 0x12, 0xf2, + 0xf3, 0xfc, 0xea, 0x06, 0xfe, 0x13, 0x12, 0xdb, 0x11, 0xe2, 0xfc, 0x0d, + 0x1c, 0xe8, 0x1d, 0xfc, 0xf2, 0xe2, 0x13, 0x1d, 0xda, 0xf6, 0x1c, 0x18, + 0x1e, 0xf4, 0xfa, 0x03, 0xdc, 0x0f, 0xff, 0xff, 0x18, 0x0b, 0xed, 0xf1, + 0xf8, 0x02, 0xf4, 0x10, 0xf9, 0xeb, 0x0b, 0x0e, 0x0f, 0x01, 0x02, 0x1b, + 0x06, 0x10, 0x00, 0xe7, 0x23, 0x0d, 0xf6, 0x11, 0x08, 0xf5, 0x0f, 0x05, + 0x13, 0xf7, 0x01, 0x01, 0x0c, 0xf6, 0xf9, 0xf0, 0x29, 0x01, 0xe9, 0x11, + 0x02, 0xfa, 0xeb, 0x16, 0x0e, 0x10, 0x09, 0x0e, 0x1c, 0x0a, 0xe3, 0xd3, + 0x01, 0xe3, 0x00, 0x06, 0xe2, 0xe9, 0x19, 0xef, 0x12, 0xf3, 0xfc, 0x02, + 0x0b, 0x0c, 0x0d, 0xed, 0xfd, 0xf6, 0xf9, 0xe9, 0xf2, 0x28, 0xfe, 0x03, + 0xec, 0x03, 0x00, 0xf8, 0xde, 0x0d, 0x25, 0x07, 0x1a, 0xe7, 0xfd, 0x29, + 0xd8, 0xf7, 0xfb, 0xde, 0x0c, 0x08, 0x06, 0x22, 0xee, 0x1d, 0x05, 0x07, + 0xf0, 0xfb, 0xfe, 0x07, 0xf1, 0x04, 0xe9, 0x01, 0xfc, 0xf1, 0x00, 0xeb, + 0xe3, 0x08, 0xec, 0xfe, 0x04, 0xeb, 0xfc, 0x01, 0xf6, 0x0e, 0xdf, 0xf8, + 0x12, 0xe3, 0x16, 0xdc, 0x21, 0x0a, 0xe6, 0x06, 0xe5, 0x10, 0x07, 0xf7, + 0x1e, 0xde, 0xe3, 0x07, 0x16, 0xed, 0x23, 0xf2, 0x12, 0x0d, 0xe9, 0xf9, + 0xe8, 0xfe, 0x0e, 0x02, 0x18, 0x0a, 0xea, 0xec, 0xfb, 0xfe, 0x0c, 0x1b, + 0x19, 0x20, 0xfa, 0x07, 0xe5, 0x0c, 0x04, 0x27, 0xdb, 0xe6, 0xfe, 0x0d, + 0x0a, 0x0a, 0xfe, 0x39, 0xdd, 0xde, 0x05, 0xec, 0x09, 0x05, 0x0a, 0x2c, + 0xf4, 0x02, 0x1f, 0xd3, 0x24, 0xee, 0x0f, 0x3c, 0xf5, 0xfd, 0xf8, 0xf8, + 0x12, 0xf5, 0xf3, 0x19, 0xf9, 0xda, 0xf6, 0x0a, 0x0a, 0xf4, 0x09, 0x0f, + 0xfc, 0x00, 0x01, 0x01, 0xf3, 0xf8, 0x05, 0xf3, 0x0c, 0x19, 0x0e, 0xfd, + 0xfa, 0xe1, 0xfc, 0x0c, 0x03, 0xfb, 0x1b, 0x06, 0xcc, 0xe4, 0x08, 0xf9, + 0x10, 0xe9, 0x06, 0x00, 0x17, 0xe8, 0x0d, 0x12, 0xca, 0xf5, 0x23, 0xe4, + 0x21, 0xf6, 0x19, 0x33, 0xdd, 0xfa, 0x0c, 0x01, 0x14, 0x07, 0x00, 0x34, + 0xda, 0x05, 0x07, 0x01, 0x07, 0xe4, 0x06, 0x24, 0x02, 0xff, 0xf0, 0x09, + 0xfc, 0xf4, 0x03, 0x06, 0xee, 0x08, 0xe2, 0x1d, 0xfa, 0x0c, 0xfc, 0x02, + 0x03, 0xe5, 0xf0, 0xe2, 0x0a, 0x18, 0x12, 0x0c, 0x1e, 0x20, 0xed, 0x20, + 0xe4, 0x01, 0x2a, 0x09, 0x0d, 0x0e, 0xd0, 0xf4, 0xdd, 0xfd, 0x2b, 0xf2, + 0x08, 0x0c, 0xf8, 0xf7, 0xfc, 0xf9, 0x15, 0xef, 0x19, 0x1c, 0x01, 0xff, + 0xe2, 0x01, 0xf3, 0x30, 0x0e, 0xfb, 0x15, 0xe8, 0x1c, 0x00, 0xfa, 0x16, + 0xef, 0xea, 0xfb, 0x05, 0xf0, 0x0e, 0x02, 0x13, 0xf4, 0x01, 0x03, 0xe5, + 0x29, 0x07, 0x09, 0x24, 0xf9, 0xe3, 0xf8, 0xde, 0x2d, 0xf4, 0xf5, 0x40, + 0xed, 0xdf, 0x07, 0xef, 0x0f, 0x0a, 0x0b, 0x32, 0x0d, 0xe8, 0x00, 0xe6, + 0xf6, 0xfc, 0xfd, 0x19, 0x11, 0x09, 0xf3, 0x03, 0xea, 0xf1, 0xfb, 0x02, + 0xfd, 0x06, 0xff, 0xfe, 0x09, 0xec, 0x06, 0x0c, 0x15, 0xf9, 0x06, 0xd7, + 0xe3, 0xf7, 0xed, 0x01, 0x03, 0xfd, 0x14, 0x01, 0x0e, 0xe0, 0x37, 0x0d, + 0xd2, 0x18, 0x2f, 0xea, 0x12, 0x0d, 0x05, 0x3a, 0xd5, 0x07, 0x1e, 0xf2, + 0x21, 0x11, 0xf9, 0x36, 0xd3, 0xf5, 0x12, 0xf6, 0xfb, 0xf6, 0x06, 0x0f, + 0xde, 0xf9, 0x06, 0x09, 0xdf, 0xff, 0x0b, 0xf3, 0xf5, 0x01, 0xf1, 0xea, + 0xf2, 0x02, 0x12, 0xfc, 0x0e, 0xee, 0xf8, 0xeb, 0x00, 0xef, 0x21, 0x0f, + 0x09, 0xef, 0xeb, 0x1e, 0xef, 0xf2, 0x26, 0xf9, 0x17, 0xf1, 0xf1, 0xf0, + 0x0c, 0x10, 0x1d, 0xff, 0x1d, 0x06, 0x03, 0xf6, 0xfb, 0x14, 0x1b, 0x03, + 0x22, 0xfd, 0xec, 0x03, 0xfa, 0xf8, 0x01, 0x2b, 0x1e, 0x1b, 0x09, 0x09, + 0x07, 0xff, 0xf0, 0x20, 0xee, 0x14, 0xfb, 0xf6, 0xf8, 0x11, 0xd9, 0x29, + 0xf4, 0xfa, 0x07, 0xef, 0x20, 0xf9, 0xf2, 0x30, 0xee, 0xf0, 0xf3, 0xd6, + 0x0d, 0xfe, 0x03, 0x36, 0xf5, 0xd7, 0x01, 0xe6, 0x04, 0xf0, 0x05, 0x1f, + 0x0f, 0xdd, 0xff, 0xf8, 0x1f, 0xf2, 0x04, 0x37, 0xfa, 0x00, 0xfd, 0xf8, + 0x10, 0xe1, 0xfb, 0x0d, 0xed, 0xf6, 0xe2, 0xfe, 0x08, 0xfe, 0x07, 0x08, + 0x08, 0x11, 0x0a, 0xf0, 0xf8, 0xf5, 0x04, 0xea, 0x08, 0x12, 0x06, 0x0d, + 0x0f, 0x10, 0x40, 0x28, 0xc0, 0xfb, 0x3f, 0x08, 0x1d, 0x09, 0x1b, 0x3d, + 0xee, 0xf4, 0x29, 0x13, 0x20, 0xfc, 0x11, 0x4c, 0xdb, 0x02, 0x15, 0x05, + 0xec, 0xeb, 0x0a, 0x22, 0xe7, 0x00, 0x02, 0x01, 0xd4, 0xea, 0x0a, 0xf3, + 0xe3, 0xf8, 0xf5, 0xfa, 0x01, 0x0d, 0x19, 0x06, 0x24, 0x13, 0x02, 0xf5, + 0xf1, 0xf1, 0x1b, 0x0f, 0x19, 0x04, 0xe3, 0xf9, 0xe7, 0x02, 0x29, 0xfc, + 0x29, 0xec, 0xe9, 0x04, 0xdc, 0x22, 0x1d, 0xfd, 0x1f, 0x01, 0xec, 0xe8, + 0xf5, 0x14, 0x1b, 0x19, 0x06, 0x0e, 0x02, 0x0d, 0xf9, 0x06, 0xfc, 0x15, + 0x07, 0xfa, 0x0c, 0xe1, 0x18, 0x1a, 0xe8, 0x1b, 0xe9, 0xef, 0x0a, 0x18, + 0xfc, 0x05, 0xf9, 0x14, 0xdc, 0x04, 0x01, 0xff, 0x07, 0xfd, 0xf0, 0x2c, + 0xf2, 0xec, 0x0e, 0xe7, 0x1a, 0x05, 0xe8, 0x35, 0x13, 0x09, 0xf9, 0x07, + 0xfe, 0xfa, 0x0d, 0x40, 0x0c, 0xea, 0xf4, 0x04, 0x01, 0x11, 0xfc, 0x23, + 0xeb, 0xf4, 0xe9, 0x04, 0xeb, 0xe7, 0x07, 0x09, 0xfb, 0xf1, 0xf6, 0xfd, + 0x02, 0xfa, 0x02, 0xff, 0x00, 0xff, 0xf1, 0xf1, 0x1a, 0xe9, 0x10, 0xe3, + 0x0b, 0x0c, 0x08, 0x04, 0x1b, 0x0a, 0x2b, 0x10, 0xe1, 0x01, 0x1f, 0x06, + 0x04, 0xec, 0x19, 0x49, 0xee, 0xf8, 0x22, 0x0c, 0x20, 0x02, 0x07, 0x31, + 0xe7, 0xff, 0x0f, 0xf0, 0xfd, 0xea, 0x13, 0x26, 0xce, 0xfa, 0xff, 0xee, + 0xe9, 0xfe, 0x15, 0x08, 0x04, 0x05, 0x0d, 0xfa, 0xdd, 0xf8, 0x07, 0x0b, + 0x33, 0xef, 0xec, 0xf9, 0xd9, 0xe6, 0x1d, 0x10, 0x41, 0xf6, 0xdf, 0x11, + 0xe3, 0x14, 0x1d, 0xfb, 0x2b, 0x15, 0xdc, 0x09, 0xf6, 0x05, 0x16, 0x00, + 0x1c, 0x27, 0xe4, 0xfc, 0xf7, 0x16, 0x08, 0x08, 0x2f, 0xdd, 0xf8, 0xfa, + 0xe9, 0x0e, 0x0b, 0x0b, 0x02, 0x12, 0x02, 0xfd, 0x19, 0x03, 0xeb, 0x11, + 0xf4, 0x09, 0x09, 0x15, 0x12, 0x0d, 0xef, 0x1c, 0xe4, 0xfe, 0x17, 0x0c, + 0x09, 0x04, 0xea, 0x2f, 0xf2, 0x1e, 0x02, 0xfb, 0xfe, 0xe3, 0x00, 0x2e, + 0x04, 0xf9, 0x0c, 0x05, 0x27, 0x0c, 0x07, 0x2d, 0xf7, 0x0b, 0xfb, 0xf9, + 0x1c, 0xdf, 0x11, 0x36, 0x05, 0xf2, 0x02, 0xf8, 0x0b, 0x07, 0x05, 0xfb, + 0xfc, 0x0e, 0x13, 0xfa, 0xfb, 0x09, 0xf5, 0xfd, 0x06, 0x15, 0xf9, 0x03, + 0x18, 0xfd, 0x1a, 0x0a, 0x03, 0xe2, 0xfb, 0x00, 0x1e, 0xfe, 0x4f, 0x27, + 0xe1, 0xf7, 0x31, 0xf0, 0x1b, 0xec, 0x07, 0x5f, 0xe2, 0xf8, 0x40, 0x05, + 0x17, 0x24, 0x0c, 0x3c, 0xf3, 0x10, 0x13, 0xf8, 0x0b, 0xf3, 0xf9, 0x36, + 0xe1, 0xf3, 0xf4, 0xe8, 0xef, 0xf8, 0xfc, 0xeb, 0xe3, 0xfb, 0xf0, 0xee, + 0xdb, 0x06, 0x0c, 0x11, 0x1e, 0x10, 0xe2, 0xe9, 0xeb, 0x0d, 0x34, 0x0f, + 0x43, 0xd9, 0xef, 0x08, 0xec, 0x05, 0x1d, 0x02, 0x33, 0xef, 0xf4, 0xf7, + 0xe6, 0xf9, 0x22, 0x07, 0x04, 0x06, 0xe9, 0x02, 0xf0, 0xfc, 0x24, 0x20, + 0x24, 0x17, 0xe6, 0x0f, 0x05, 0xf6, 0xfc, 0x1f, 0xf2, 0x01, 0x0d, 0xe7, + 0xff, 0x1d, 0xf0, 0xfa, 0xd0, 0x00, 0xff, 0x0e, 0x23, 0xf9, 0xf3, 0x11, + 0xde, 0x0d, 0x05, 0x04, 0x0b, 0x0b, 0xfb, 0x26, 0x0d, 0x0d, 0xff, 0xe8, + 0x16, 0xe8, 0x0b, 0x3c, 0x18, 0xe4, 0x04, 0xff, 0xfa, 0xf3, 0xff, 0x40, + 0xee, 0x06, 0xfc, 0x0d, 0x00, 0xf7, 0x13, 0x3f, 0xf7, 0x13, 0x06, 0x08, + 0xf9, 0x13, 0xf2, 0x19, 0xfd, 0xf9, 0xf3, 0xe6, 0xfc, 0x07, 0xf6, 0xfd, + 0x0a, 0x22, 0x00, 0x01, 0x19, 0xff, 0xe7, 0xff, 0x08, 0xfd, 0x03, 0xfd, + 0x1f, 0xe7, 0x28, 0x08, 0xde, 0xf3, 0x43, 0xf6, 0x0c, 0xfe, 0x1e, 0x52, + 0xf2, 0x04, 0x17, 0xf2, 0x08, 0x0d, 0x04, 0x38, 0xde, 0x0c, 0x10, 0xef, + 0xdf, 0x0f, 0x01, 0x24, 0xde, 0xe1, 0x0d, 0xfd, 0xd4, 0xf6, 0x12, 0x0e, + 0xed, 0x01, 0xf0, 0xf3, 0xfd, 0xff, 0x18, 0xf3, 0x36, 0xda, 0xf6, 0xef, + 0xe8, 0xef, 0x37, 0x27, 0x4e, 0xf8, 0xf4, 0xff, 0xe5, 0xf3, 0x32, 0x0b, + 0x36, 0x08, 0xe9, 0xf6, 0xe2, 0x13, 0x21, 0xfe, 0x12, 0xed, 0xdd, 0xfb, + 0xf8, 0x05, 0x0f, 0x03, 0x1c, 0x04, 0xfc, 0xf2, 0x23, 0x0e, 0x03, 0xfc, + 0xf9, 0x18, 0xf7, 0x01, 0x1b, 0x03, 0xf5, 0xfd, 0xde, 0xf3, 0x19, 0xfc, + 0x11, 0x02, 0xe7, 0x13, 0xde, 0xd8, 0xf2, 0x05, 0x28, 0x02, 0x02, 0x27, + 0x07, 0x08, 0xff, 0x07, 0x27, 0x0e, 0x19, 0x40, 0xfb, 0x02, 0x0c, 0xf6, + 0x0d, 0x07, 0x0f, 0x47, 0xf8, 0x05, 0x0e, 0xfd, 0x03, 0x1e, 0x07, 0x32, + 0xe7, 0xf6, 0x24, 0x01, 0x01, 0x02, 0x0a, 0xff, 0xf6, 0x26, 0x15, 0xf0, + 0x04, 0x13, 0x03, 0xfa, 0xfe, 0xf6, 0xf1, 0x09, 0x2a, 0xe6, 0xea, 0xf6, + 0x17, 0x13, 0xeb, 0xff, 0x15, 0xeb, 0x23, 0x06, 0xc8, 0xf6, 0x33, 0xeb, + 0xf4, 0xe7, 0x12, 0x2a, 0xe3, 0xe6, 0x32, 0xfa, 0x16, 0x15, 0x17, 0x40, + 0xf1, 0x08, 0x1a, 0xf3, 0xf6, 0x0c, 0x0c, 0x11, 0xd0, 0x22, 0x02, 0xee, + 0xea, 0xf4, 0xf8, 0xf9, 0x13, 0x10, 0x17, 0xf5, 0xf1, 0x0a, 0x0e, 0xfd, + 0x32, 0xda, 0xf1, 0xe2, 0xdb, 0xf2, 0x34, 0x1f, 0x53, 0xfc, 0xe4, 0xf2, + 0xf6, 0xf2, 0x1d, 0x04, 0x4a, 0xec, 0xee, 0x06, 0xdf, 0x01, 0x1a, 0x04, + 0x27, 0xfc, 0xe6, 0xfd, 0xd9, 0xfd, 0x0e, 0x00, 0x0c, 0x16, 0xf3, 0x03, + 0xf7, 0xfc, 0x0e, 0x0f, 0x09, 0x06, 0x06, 0x04, 0x08, 0x02, 0xed, 0xf5, + 0xe4, 0xe6, 0x07, 0x06, 0x03, 0x18, 0xea, 0x13, 0xe2, 0xfa, 0x10, 0xf2, + 0x02, 0xec, 0x03, 0x3c, 0xf6, 0xf6, 0x0a, 0x10, 0x09, 0xf8, 0x15, 0x24, + 0xfd, 0x0d, 0x09, 0x01, 0x00, 0xff, 0x00, 0x1a, 0xf0, 0xee, 0x08, 0x03, + 0x1d, 0x05, 0x16, 0x46, 0xe6, 0xf8, 0x08, 0x00, 0x09, 0x09, 0xff, 0x01, + 0xfc, 0x20, 0xfc, 0xec, 0x05, 0x1b, 0x03, 0xf1, 0x12, 0xe4, 0xfa, 0x24, + 0x1c, 0xf5, 0xf2, 0x05, 0x11, 0xe7, 0xfa, 0x02, 0x20, 0xea, 0x31, 0x10, + 0xcf, 0xd8, 0x33, 0xee, 0xff, 0x09, 0x20, 0x3f, 0xe2, 0x0a, 0x29, 0xee, + 0x3a, 0xf2, 0x1e, 0x39, 0x02, 0x1e, 0xfe, 0xf2, 0xef, 0xe2, 0x0d, 0x0f, + 0xf1, 0x19, 0x02, 0xe7, 0xec, 0xff, 0xfe, 0xe4, 0xfe, 0xfb, 0x02, 0xf6, + 0xf1, 0xf4, 0x07, 0x1a, 0x2a, 0xf9, 0x06, 0xf9, 0xda, 0xf4, 0x22, 0x02, + 0x4f, 0x0a, 0xf3, 0xfc, 0xf3, 0xf6, 0x25, 0x0a, 0x28, 0x01, 0xf7, 0x09, + 0xe6, 0x05, 0x28, 0xf7, 0x1e, 0xf2, 0xee, 0x13, 0xee, 0x05, 0x0f, 0x0a, + 0x09, 0xe8, 0xe8, 0x0e, 0x05, 0x12, 0x0f, 0x15, 0x02, 0xec, 0xf8, 0x02, + 0xf7, 0x05, 0xf8, 0xff, 0xdc, 0x00, 0x01, 0x00, 0x12, 0x17, 0xec, 0x19, + 0xfa, 0x09, 0xfa, 0xf3, 0x1d, 0x0b, 0x07, 0x25, 0xea, 0x0c, 0xf5, 0xfa, + 0x04, 0xf7, 0xfe, 0x33, 0xfe, 0x14, 0xef, 0x04, 0xf0, 0x00, 0x00, 0x3a, + 0xea, 0xfa, 0x10, 0x01, 0xe4, 0x00, 0xff, 0x23, 0xe9, 0x26, 0x15, 0x10, + 0x04, 0x14, 0x0d, 0x08, 0xf8, 0xfd, 0x10, 0xfb, 0x00, 0x21, 0x06, 0xfa, + 0x0f, 0x08, 0xf1, 0x09, 0x28, 0xf0, 0xd8, 0x0d, 0x08, 0x09, 0x02, 0xfb, + 0x12, 0x03, 0x0e, 0xfb, 0xce, 0xf0, 0x39, 0xe5, 0x09, 0xf6, 0x1f, 0x35, + 0xdd, 0x1c, 0x25, 0xef, 0x17, 0x0c, 0xf6, 0x3e, 0xf0, 0x21, 0x08, 0xff, + 0xd7, 0xfc, 0xfd, 0x1f, 0xe5, 0x18, 0x12, 0xe9, 0xf5, 0xe9, 0x12, 0xf6, + 0x02, 0x13, 0xf4, 0x0a, 0xfd, 0x03, 0x09, 0x08, 0x2f, 0x07, 0xee, 0xfd, + 0xd7, 0x00, 0x2b, 0x29, 0x3b, 0xdb, 0xde, 0xf1, 0xe1, 0xf7, 0x47, 0x12, + 0x35, 0x0c, 0xe4, 0x09, 0xef, 0x17, 0x2b, 0xea, 0x2d, 0xf8, 0xe8, 0x18, + 0xef, 0x03, 0x11, 0x0a, 0x10, 0xff, 0xe8, 0x07, 0x0c, 0x07, 0x03, 0x18, + 0x05, 0x08, 0xf8, 0xf8, 0x06, 0x18, 0xe9, 0xf9, 0xe0, 0x0f, 0x0d, 0x18, + 0x04, 0x01, 0xf0, 0x1c, 0xf6, 0x14, 0xfd, 0x12, 0x0c, 0x0c, 0x02, 0x34, + 0xf6, 0xe6, 0xfd, 0xf9, 0xf9, 0xfd, 0x00, 0x2a, 0xfc, 0xf9, 0xff, 0x0a, + 0xfe, 0x1b, 0xf5, 0x34, 0xdc, 0xf9, 0x15, 0x13, 0xe7, 0x1b, 0xf7, 0x25, + 0xfd, 0x09, 0x08, 0x0a, 0xf0, 0x17, 0x0f, 0x04, 0xf4, 0xe9, 0x06, 0x07, + 0xf5, 0x02, 0xfc, 0xf5, 0x09, 0xee, 0xf1, 0x07, 0x38, 0x03, 0x05, 0x0f, + 0x16, 0x0f, 0xed, 0xff, 0x21, 0xf8, 0x34, 0x07, 0xd1, 0xf9, 0x27, 0x00, + 0x0c, 0x21, 0x18, 0x42, 0xe6, 0x02, 0x1a, 0xf1, 0x2f, 0xf1, 0x0e, 0x3b, + 0xee, 0xf8, 0x08, 0xea, 0xfe, 0xf9, 0x03, 0x18, 0xf5, 0xf8, 0x0d, 0xeb, + 0x01, 0x10, 0x09, 0x02, 0x15, 0xfb, 0xf1, 0x0b, 0xf2, 0x06, 0x08, 0x09, + 0x2f, 0x19, 0x02, 0xfe, 0xe4, 0x06, 0x1f, 0x17, 0x49, 0xf2, 0xe2, 0x02, + 0xef, 0x04, 0x26, 0x16, 0x3f, 0x08, 0xf1, 0x0a, 0xfd, 0xf9, 0x28, 0x01, + 0x15, 0x0b, 0xf9, 0x10, 0xdc, 0x02, 0x20, 0xf7, 0x16, 0xe6, 0x09, 0x03, + 0xf1, 0xf5, 0x12, 0x1c, 0xfb, 0x2a, 0x08, 0xfa, 0x0a, 0x16, 0xf6, 0x15, + 0xf0, 0x06, 0x11, 0xfd, 0x0e, 0xf9, 0xf6, 0x12, 0xed, 0xf3, 0xfd, 0x1f, + 0x0b, 0xfa, 0x08, 0x30, 0xf8, 0xff, 0x0b, 0xeb, 0x10, 0xff, 0x07, 0x22, + 0x0d, 0x07, 0x09, 0x03, 0xf6, 0xf8, 0xfc, 0x26, 0xf8, 0xee, 0x11, 0x02, + 0x03, 0x0a, 0xef, 0x38, 0xfe, 0x13, 0x1b, 0x09, 0xfe, 0x06, 0x05, 0xf3, + 0x04, 0xdf, 0xfc, 0x00, 0xe7, 0x15, 0xec, 0xf1, 0xf8, 0xfc, 0xed, 0x05, + 0x0e, 0xf3, 0x15, 0x09, 0x01, 0x0d, 0xfd, 0x00, 0x24, 0xe2, 0x31, 0x13, + 0xd5, 0x1b, 0x2b, 0xe8, 0x03, 0x08, 0x1d, 0x33, 0xdc, 0xfd, 0x24, 0xe4, + 0x20, 0xfa, 0x07, 0x33, 0x01, 0x12, 0x06, 0xf5, 0xef, 0xf7, 0xfa, 0x13, + 0x01, 0xec, 0xee, 0xe0, 0xfd, 0x0d, 0xff, 0x09, 0xf6, 0x00, 0xed, 0x07, + 0xea, 0x0e, 0xff, 0x0e, 0x26, 0xfc, 0xf0, 0xe7, 0xe7, 0xfe, 0x30, 0xff, + 0x24, 0x04, 0x06, 0xf4, 0xf5, 0xf8, 0x23, 0x0e, 0x3d, 0xf2, 0xfd, 0x04, + 0xe8, 0xfb, 0x23, 0xfe, 0x33, 0xe1, 0x01, 0xfd, 0xdc, 0xfb, 0x0e, 0xfa, + 0x22, 0xfb, 0x11, 0xfa, 0xff, 0x08, 0x21, 0x30, 0x13, 0x03, 0xf2, 0x03, + 0xf8, 0x0f, 0xec, 0x0d, 0xef, 0x0f, 0x10, 0x10, 0x0f, 0xf6, 0xf9, 0x1e, + 0xf7, 0xe5, 0x08, 0xfa, 0x09, 0xff, 0x00, 0x15, 0x02, 0x00, 0x08, 0xfe, + 0xfb, 0x0e, 0x15, 0x28, 0xfa, 0xfb, 0x13, 0x06, 0xfb, 0x05, 0xf6, 0x11, + 0xf6, 0x0b, 0x06, 0x15, 0xe1, 0x00, 0xe9, 0x0f, 0xe1, 0x1d, 0x18, 0xfd, + 0x0b, 0x0f, 0xff, 0xf2, 0xf5, 0xfd, 0x14, 0xff, 0xf4, 0xfe, 0xe2, 0xf8, + 0x14, 0x0b, 0xeb, 0x07, 0x35, 0xe2, 0xeb, 0x0b, 0x04, 0x22, 0xfe, 0x0e, + 0x1d, 0xf2, 0x24, 0x11, 0xcc, 0xec, 0x25, 0xf7, 0xff, 0xf9, 0x06, 0x29, + 0xe4, 0x07, 0x1c, 0xdb, 0xf8, 0x1d, 0xfa, 0x44, 0xf2, 0x01, 0x0f, 0xe6, + 0x11, 0x03, 0xee, 0x17, 0x06, 0xe0, 0x0c, 0xd8, 0xe9, 0xfd, 0x11, 0xfe, + 0x07, 0xdd, 0xea, 0xff, 0xde, 0xdd, 0x0a, 0x09, 0x30, 0xf2, 0x01, 0xe4, + 0xe0, 0xeb, 0x2d, 0x12, 0x2d, 0xeb, 0xfc, 0xf0, 0xe8, 0xf9, 0x1f, 0x08, + 0x3f, 0xeb, 0x0e, 0x13, 0xf9, 0x0c, 0x1c, 0x02, 0x25, 0xec, 0xf6, 0x05, + 0xf3, 0xf4, 0x18, 0x08, 0x12, 0xe9, 0xfb, 0xfd, 0xf9, 0x08, 0x13, 0x1c, + 0x08, 0xec, 0xfe, 0x02, 0xf1, 0x19, 0xf3, 0x1d, 0xf1, 0x07, 0x11, 0x12, + 0xfa, 0xf2, 0xf6, 0x0d, 0xff, 0x17, 0x0a, 0xfb, 0x1f, 0xf8, 0x11, 0x24, + 0xf6, 0xfc, 0xfe, 0x07, 0xed, 0x05, 0x1c, 0x21, 0xfe, 0xfe, 0x16, 0x0d, + 0x08, 0x0f, 0x09, 0x33, 0xf4, 0x1f, 0x14, 0x0c, 0xfe, 0xf5, 0xeb, 0x2a, + 0xee, 0xf3, 0x12, 0x19, 0xec, 0x01, 0x06, 0xf7, 0x05, 0x22, 0x0b, 0xeb, + 0xeb, 0x06, 0xe1, 0xf5, 0x0d, 0xee, 0xfb, 0x0a, 0x31, 0xff, 0xe3, 0xea, + 0x18, 0x09, 0xe3, 0x07, 0x1a, 0xf8, 0x15, 0xfc, 0xcc, 0xf2, 0x2a, 0xe5, + 0x01, 0xea, 0x10, 0x1f, 0xd9, 0x02, 0x13, 0xf6, 0x16, 0x01, 0x0e, 0x3c, + 0x02, 0x17, 0x04, 0xf1, 0xf7, 0x02, 0x07, 0x0c, 0x02, 0x1f, 0xf4, 0xe6, + 0xf0, 0xe9, 0x05, 0xf4, 0xfd, 0xe4, 0xf7, 0xe9, 0xfc, 0xef, 0x06, 0x02, + 0x26, 0xf1, 0xf1, 0xeb, 0xe9, 0xe6, 0x30, 0x1c, 0x38, 0x0f, 0x03, 0xf1, + 0x10, 0x04, 0x30, 0x19, 0x1f, 0xfb, 0xfc, 0x05, 0xe2, 0xfe, 0x18, 0xf2, + 0x1c, 0xf2, 0xf5, 0x0e, 0xf2, 0x05, 0x1d, 0x28, 0x12, 0xf0, 0xf0, 0x0f, + 0x0a, 0x03, 0x1a, 0x1a, 0xf3, 0x08, 0x13, 0xef, 0xf5, 0x1c, 0x06, 0x00, + 0xee, 0x12, 0x1d, 0x03, 0x18, 0x06, 0x0a, 0x0e, 0xf0, 0xeb, 0xfa, 0x0d, + 0x08, 0xff, 0x06, 0x24, 0x0f, 0x03, 0x0a, 0x0f, 0x0e, 0xff, 0x08, 0x33, + 0xfc, 0x00, 0x0e, 0xfb, 0xfb, 0x05, 0x07, 0x19, 0xe8, 0xe7, 0x12, 0x11, + 0x15, 0xf7, 0x0c, 0x1a, 0xf6, 0x28, 0x08, 0xeb, 0xf2, 0x25, 0xee, 0x01, + 0x03, 0xec, 0xed, 0xfa, 0xf0, 0xf2, 0xef, 0xf1, 0x02, 0x23, 0xef, 0x01, + 0x41, 0xfa, 0xf4, 0xf4, 0x15, 0xf5, 0xf5, 0xf9, 0x28, 0xde, 0x20, 0xf6, + 0xc7, 0xde, 0x21, 0xe4, 0xfe, 0xec, 0x0d, 0x2c, 0xee, 0x24, 0x10, 0xf0, + 0x1d, 0x12, 0x0e, 0x2b, 0x06, 0xf8, 0xfd, 0x01, 0x08, 0xef, 0xfd, 0x0f, + 0xeb, 0xed, 0xe1, 0xdf, 0xf1, 0xe5, 0x16, 0xe3, 0x08, 0xfc, 0xf6, 0xf6, + 0xd8, 0xf0, 0x23, 0xfc, 0x2b, 0xf5, 0xff, 0xe7, 0xf4, 0xe9, 0x29, 0x09, + 0x2b, 0x0c, 0xff, 0x08, 0x0b, 0xed, 0x29, 0x14, 0x3c, 0xf5, 0xeb, 0x18, + 0xf6, 0x10, 0x22, 0xf9, 0x17, 0x23, 0x02, 0x0c, 0xf6, 0xfa, 0x2f, 0xfe, + 0x1e, 0xeb, 0xfd, 0x03, 0xf0, 0x07, 0x1c, 0x09, 0xfa, 0xe1, 0x0d, 0x0f, + 0x18, 0x03, 0xfe, 0xf0, 0xec, 0x0b, 0x10, 0x02, 0x14, 0x06, 0xef, 0xf7, + 0xea, 0x0b, 0x05, 0xfe, 0x1f, 0x06, 0x0e, 0x07, 0x00, 0xe1, 0x01, 0x01, + 0x07, 0x05, 0x09, 0xf7, 0xef, 0x15, 0xf7, 0x12, 0x05, 0x03, 0x04, 0x1d, + 0x04, 0x10, 0x12, 0x06, 0x05, 0x00, 0x08, 0x18, 0xd6, 0xf2, 0xfa, 0x07, + 0xf8, 0x12, 0x07, 0xfd, 0xdd, 0x00, 0x04, 0xfb, 0xf8, 0x09, 0xf3, 0x09, + 0xfb, 0xf0, 0xe8, 0x09, 0x27, 0xf5, 0xf8, 0x06, 0x01, 0x02, 0x0e, 0xf6, + 0x1f, 0xfa, 0x29, 0xf8, 0xd6, 0x01, 0x22, 0xf8, 0x1d, 0xe3, 0x1a, 0x39, + 0x0a, 0x0d, 0x19, 0xf5, 0x12, 0xfb, 0x1d, 0x2a, 0x03, 0xf6, 0x0c, 0xf2, + 0xfd, 0xec, 0x18, 0x13, 0xfe, 0x1a, 0xe8, 0xdd, 0x01, 0xf8, 0x30, 0x01, + 0xf8, 0xfe, 0xe4, 0xe7, 0xff, 0xeb, 0x23, 0xfa, 0x2c, 0xf0, 0xfc, 0xe7, + 0x0a, 0xf8, 0x18, 0x10, 0x23, 0x01, 0xfa, 0xe8, 0xf1, 0xfa, 0x1d, 0x0e, + 0x17, 0xe7, 0xe4, 0xf5, 0xf9, 0x0c, 0x17, 0x0c, 0x13, 0xe8, 0xe1, 0x17, + 0x19, 0x05, 0x0b, 0x0f, 0x23, 0xed, 0xff, 0xfe, 0xe0, 0x14, 0x16, 0x00, + 0x0d, 0x1c, 0x0b, 0xf5, 0xfb, 0x18, 0xee, 0xff, 0xff, 0xf3, 0x18, 0x0c, + 0x05, 0xfa, 0xf6, 0xfe, 0xfe, 0xf8, 0xf8, 0x09, 0xef, 0xf8, 0x0e, 0xf0, + 0x00, 0xf8, 0x0c, 0xf8, 0xf6, 0x07, 0x16, 0x11, 0xf8, 0xea, 0xff, 0xff, + 0x01, 0x20, 0x07, 0x08, 0xfd, 0x1c, 0xfc, 0x06, 0xed, 0x0d, 0x08, 0x15, + 0xf0, 0x25, 0x01, 0x1b, 0x00, 0x02, 0xfe, 0x01, 0x05, 0x01, 0xfd, 0xf1, + 0xe5, 0x0c, 0xe4, 0xe1, 0xf0, 0xfa, 0xee, 0x0e, 0x35, 0xee, 0x15, 0xef, + 0x0a, 0xf9, 0x01, 0xf5, 0x1f, 0x05, 0x1f, 0x0d, 0xe1, 0xf4, 0xff, 0xf5, + 0x23, 0x02, 0x18, 0x30, 0xfc, 0xf0, 0x0d, 0x04, 0x0d, 0x06, 0x29, 0x1d, + 0xf9, 0x08, 0x06, 0xe5, 0x13, 0xfd, 0x0d, 0x26, 0xef, 0x09, 0xdc, 0xf2, + 0x05, 0xdf, 0x0c, 0xf6, 0xf3, 0xd9, 0xf8, 0x08, 0xef, 0xeb, 0x0f, 0xf9, + 0x3a, 0x03, 0xff, 0xe0, 0xf7, 0xf0, 0x15, 0x12, 0x41, 0x0b, 0xf1, 0x04, + 0x04, 0xe2, 0x0e, 0x0b, 0x2c, 0x03, 0xea, 0x02, 0xfb, 0xe7, 0x08, 0xe9, + 0x22, 0xf3, 0xf2, 0x1c, 0xfa, 0xf3, 0x11, 0x04, 0x1f, 0xf5, 0x02, 0x0f, + 0x1a, 0x1f, 0x24, 0x0b, 0x06, 0x1f, 0xf3, 0x06, 0x00, 0x02, 0xe8, 0xf6, + 0xf4, 0xe8, 0x07, 0x2e, 0xfb, 0xf8, 0x10, 0x09, 0xf0, 0x0e, 0xff, 0xfe, + 0x1c, 0x14, 0x17, 0x06, 0xe2, 0xf1, 0xfa, 0x01, 0x11, 0x13, 0x12, 0x29, + 0xf1, 0x0f, 0x1f, 0xfa, 0xfd, 0xfd, 0x02, 0x07, 0x0e, 0xfb, 0x0e, 0x04, + 0x01, 0x01, 0xed, 0xfe, 0xde, 0xfd, 0x08, 0xef, 0xf6, 0x0a, 0xff, 0x0f, + 0xe7, 0xf2, 0x0f, 0x02, 0xea, 0x10, 0xf9, 0xec, 0xfd, 0x09, 0xea, 0x1f, + 0x46, 0xdd, 0xe2, 0xf7, 0x08, 0xf5, 0xf7, 0xe9, 0x33, 0xfb, 0x2f, 0xf6, + 0xb5, 0x1d, 0x15, 0xeb, 0x11, 0xf7, 0x2a, 0x2e, 0x08, 0x1d, 0xf4, 0xfb, + 0x15, 0xfa, 0x22, 0x34, 0xff, 0x06, 0xf6, 0xfd, 0xfa, 0xf9, 0x03, 0xf5, + 0xf4, 0xf4, 0xd5, 0xea, 0x01, 0x08, 0x22, 0xf1, 0xf2, 0x06, 0xd1, 0xe5, + 0x0c, 0xef, 0x12, 0x03, 0x08, 0x02, 0xf7, 0x05, 0x1b, 0x07, 0x39, 0x34, + 0x21, 0xe2, 0xe3, 0x0b, 0x0c, 0xf6, 0x29, 0xf7, 0x24, 0x0a, 0xfc, 0xff, + 0x1a, 0xfd, 0x05, 0xff, 0xff, 0x0e, 0x0a, 0x1a, 0x09, 0xfb, 0x15, 0x04, + 0x03, 0xf7, 0xfe, 0x00, 0xfc, 0xfb, 0x11, 0xfa, 0x1d, 0x0e, 0x06, 0xed, + 0xfc, 0x23, 0xd8, 0xf2, 0x04, 0xe5, 0x0f, 0x16, 0x29, 0xfe, 0xf5, 0xec, + 0xe2, 0x0e, 0xeb, 0x09, 0x1d, 0x11, 0x05, 0x11, 0xe4, 0x29, 0x12, 0x02, + 0x12, 0x19, 0x0e, 0x1a, 0xee, 0xf9, 0x05, 0x09, 0xf5, 0xfd, 0x05, 0x04, + 0xe4, 0xf1, 0x17, 0x01, 0xf2, 0xfe, 0x0b, 0xf4, 0x0d, 0x04, 0x06, 0xfe, + 0xff, 0xec, 0xe9, 0x00, 0xff, 0x03, 0x03, 0xfd, 0xf1, 0x15, 0xfc, 0xf3, + 0xff, 0xfe, 0x09, 0xee, 0x3c, 0x01, 0xec, 0x02, 0xf0, 0xf6, 0x20, 0xeb, + 0x16, 0x07, 0x32, 0xf3, 0xce, 0xf0, 0x02, 0xd4, 0x11, 0xe6, 0x28, 0x0e, + 0xe3, 0x21, 0xee, 0xce, 0x1e, 0xd9, 0x23, 0x26, 0x06, 0xfa, 0xf9, 0xf1, + 0x01, 0xe6, 0x0b, 0x07, 0xdc, 0x21, 0xbc, 0xe3, 0xef, 0xf8, 0x12, 0xfc, + 0xe6, 0xfe, 0xf5, 0xd4, 0x15, 0x0a, 0x00, 0x13, 0xfc, 0xec, 0xf3, 0xd6, + 0x1a, 0xe3, 0x21, 0x36, 0x2a, 0x03, 0xe9, 0xe3, 0xff, 0x00, 0x13, 0x1c, + 0x0e, 0x20, 0xe5, 0xf5, 0x24, 0x0b, 0x20, 0x14, 0x13, 0xf8, 0x04, 0x1b, + 0x2f, 0x0a, 0x15, 0x00, 0xf4, 0x1a, 0x11, 0x0d, 0x03, 0x18, 0x0f, 0x18, + 0x04, 0x1f, 0xfb, 0xf2, 0x1f, 0x15, 0x03, 0xfb, 0x0b, 0x17, 0xfb, 0x0b, + 0x1b, 0x1f, 0xf4, 0x07, 0xf9, 0xf9, 0xf8, 0xf4, 0x14, 0x0f, 0xf6, 0xfe, + 0xdd, 0x0b, 0xff, 0x01, 0x18, 0x04, 0x1b, 0x0a, 0xed, 0xe7, 0xf9, 0x16, + 0x02, 0x01, 0x00, 0xf7, 0xf1, 0x07, 0xf0, 0x06, 0xf8, 0x0b, 0x02, 0xf3, + 0xff, 0x20, 0xfd, 0x01, 0x04, 0xf5, 0xd9, 0xf4, 0xf4, 0xf2, 0xe8, 0xff, + 0x04, 0x00, 0xf0, 0xe2, 0xfe, 0xed, 0x1b, 0xef, 0x20, 0xfa, 0xfb, 0xf4, + 0x02, 0x18, 0x07, 0xfb, 0xef, 0xe4, 0x08, 0x0d, 0xe1, 0x0e, 0x25, 0xc6, + 0xfd, 0x0c, 0x1c, 0x0b, 0xf0, 0x01, 0x1c, 0xd4, 0x11, 0xf5, 0x1b, 0x09, + 0xfb, 0xda, 0x13, 0xe3, 0xf9, 0x10, 0x14, 0xf0, 0xf0, 0xfd, 0x1f, 0xcf, + 0xf4, 0xe4, 0xfb, 0x0e, 0x0a, 0x11, 0xed, 0xdc, 0xfc, 0xe6, 0xf7, 0xfc, + 0x13, 0xe1, 0x0b, 0xe4, 0x04, 0x11, 0xee, 0x21, 0x14, 0xe1, 0x07, 0xe4, + 0xfb, 0x08, 0x03, 0x2b, 0x27, 0xf6, 0x0d, 0x02, 0x1b, 0x09, 0x09, 0xf8, + 0x14, 0x19, 0x0f, 0x0b, 0x01, 0x10, 0x09, 0x12, 0x03, 0xf5, 0x18, 0xf3, + 0xfb, 0xf5, 0x02, 0x0e, 0x0d, 0x00, 0x07, 0xfc, 0x18, 0x25, 0x0b, 0xf0, + 0xf9, 0xe6, 0x08, 0x01, 0x24, 0x14, 0xfa, 0xed, 0xe5, 0x1f, 0x09, 0xfe, + 0x08, 0xee, 0x1a, 0x1a, 0x05, 0x00, 0xff, 0x0c, 0xfe, 0xf9, 0x11, 0x11, + 0xea, 0xfe, 0x08, 0xf9, 0xf0, 0xe4, 0x01, 0x0d, 0xf1, 0x00, 0x0b, 0xea, + 0x19, 0xea, 0xf3, 0xf8, 0x08, 0x12, 0x1c, 0x1f, 0xfb, 0xef, 0xf0, 0xf2, + 0x14, 0xe1, 0x03, 0xfa, 0xf9, 0xda, 0xe9, 0xfc, 0xf3, 0xff, 0x12, 0x04, + 0xf7, 0xfc, 0x17, 0x0f, 0xfc, 0x29, 0x03, 0xe5, 0xf2, 0xee, 0x1e, 0xfa, + 0x04, 0xed, 0x25, 0xf4, 0xe1, 0x15, 0x10, 0x1e, 0xef, 0x1c, 0x04, 0xde, + 0xe5, 0x08, 0x21, 0xfd, 0xfd, 0xea, 0x03, 0xca, 0xda, 0x26, 0x00, 0x0a, + 0xfd, 0x05, 0xf0, 0xd4, 0xe1, 0x1a, 0xe4, 0xf5, 0x07, 0xe7, 0xfa, 0xdf, + 0xd4, 0x03, 0xf0, 0x10, 0x15, 0x0c, 0xf4, 0xed, 0xe3, 0xfb, 0x0f, 0x1e, + 0x16, 0x09, 0x00, 0xec, 0xea, 0x13, 0x16, 0x0b, 0x01, 0xfb, 0xff, 0x00, + 0xfb, 0x07, 0x13, 0x08, 0xf4, 0xe4, 0x12, 0x00, 0xfb, 0xfa, 0xfc, 0x08, + 0xeb, 0x19, 0x02, 0x1c, 0xe8, 0x26, 0xf3, 0x10, 0x09, 0x0f, 0x19, 0x02, + 0xfb, 0xec, 0xf7, 0xe2, 0xfb, 0xfa, 0x11, 0xf3, 0x0b, 0x08, 0xff, 0xd9, + 0xf8, 0x12, 0x18, 0x06, 0x07, 0x22, 0xff, 0x19, 0xf5, 0x0b, 0x0a, 0x13, + 0xf2, 0xfa, 0x02, 0x21, 0xeb, 0x11, 0x17, 0x17, 0xec, 0xe1, 0x0e, 0xf7, + 0xe8, 0xd8, 0x0e, 0x01, 0xf1, 0xed, 0xed, 0xf0, 0x09, 0xf7, 0xe7, 0xfd, + 0xf0, 0xf9, 0xdb, 0xee, 0xdc, 0xfb, 0xf8, 0x0a, 0xf5, 0x0b, 0xd4, 0xd7, + 0x08, 0x06, 0x18, 0x06, 0x0c, 0x13, 0xfd, 0x09, 0x13, 0x26, 0x12, 0xf4, + 0xef, 0x00, 0xf5, 0x28, 0x18, 0xfe, 0x04, 0x0e, 0x21, 0x1a, 0x0a, 0x1e, + 0x09, 0xf0, 0x0d, 0x0f, 0xec, 0xf3, 0x17, 0x22, 0x00, 0xec, 0x0e, 0x01, + 0xe9, 0x08, 0x09, 0xf2, 0xf2, 0x08, 0xf0, 0x0b, 0xd9, 0x09, 0x14, 0xf5, + 0xf6, 0x04, 0x19, 0xf4, 0x11, 0xe9, 0xf2, 0x0d, 0x20, 0x17, 0x0a, 0x05, + 0x0c, 0x04, 0x01, 0xfd, 0xf4, 0xfb, 0x1b, 0x0c, 0xf2, 0x0b, 0xff, 0xfe, + 0x01, 0xd8, 0xfa, 0x0e, 0xf5, 0x14, 0xf9, 0x01, 0x04, 0xf8, 0xfa, 0x02, + 0xe8, 0xf9, 0xf9, 0xea, 0xf1, 0x07, 0xff, 0x1e, 0x01, 0x0b, 0xf7, 0x0a, + 0xf7, 0x0c, 0xfd, 0xec, 0xf3, 0x05, 0xf8, 0xda, 0x0b, 0x15, 0xf6, 0xee, + 0xf9, 0x10, 0xfa, 0xfe, 0x08, 0xf0, 0xe6, 0xec, 0x05, 0xff, 0x15, 0x19, + 0x1f, 0x11, 0xfc, 0x09, 0x08, 0x01, 0x06, 0xfe, 0x04, 0x08, 0xfb, 0xfb, + 0x08, 0xf4, 0xf6, 0x28, 0x10, 0xf9, 0x28, 0x0b, 0xf8, 0x0d, 0x01, 0x00, + 0xff, 0x02, 0x05, 0x08, 0xea, 0xe9, 0xf4, 0xf6, 0x01, 0xea, 0xdf, 0x1f, + 0xfe, 0x0a, 0xf9, 0xf7, 0x0c, 0x1b, 0x06, 0xed, 0xf6, 0xf2, 0x03, 0x03, + 0xfd, 0x04, 0xf5, 0x10, 0x0a, 0x0b, 0xf4, 0xf8, 0xf1, 0xe7, 0x05, 0xfe, + 0xe7, 0x0b, 0xf1, 0xec, 0xf4, 0xec, 0x06, 0xee, 0xde, 0x05, 0x1b, 0xfe, + 0x13, 0xf3, 0xd9, 0xea, 0x04, 0x10, 0x05, 0xed, 0x15, 0x02, 0x0b, 0x10, + 0xfa, 0x02, 0x05, 0x0b, 0x02, 0x07, 0xfc, 0xf5, 0x15, 0x14, 0x05, 0xf7, + 0x0c, 0xfe, 0xf6, 0xf4, 0xfa, 0x06, 0xfc, 0x13, 0xdc, 0xe4, 0x09, 0xfa, + 0x02, 0x23, 0xec, 0x06, 0x11, 0x13, 0xf8, 0xfa, 0x27, 0x28, 0x0b, 0x23, + 0xec, 0xf1, 0x09, 0x17, 0x0f, 0x13, 0xff, 0xf2, 0xfc, 0x0a, 0xf5, 0x0d, + 0x03, 0x26, 0x01, 0x0f, 0xfe, 0xf1, 0xfb, 0xe6, 0xf0, 0x02, 0xf2, 0xff, + 0x02, 0x11, 0xff, 0xfd, 0x1c, 0x02, 0x0b, 0xf6, 0x14, 0x0c, 0x0b, 0x21, + 0x28, 0xf0, 0x11, 0x05, 0x06, 0xed, 0xf9, 0x0a, 0xf2, 0xef, 0xf8, 0xf1, + 0xfe, 0x0d, 0xf9, 0xf7, 0xea, 0x00, 0x08, 0xdb, 0x02, 0x0f, 0xfe, 0x04, + 0xef, 0x20, 0x16, 0x01, 0xe8, 0xed, 0xe4, 0x22, 0xf6, 0x19, 0x00, 0x04, + 0x01, 0x13, 0xeb, 0x0d, 0xec, 0x01, 0x08, 0x05, 0x0c, 0x0e, 0xfe, 0x02, + 0x12, 0xf7, 0x27, 0xf9, 0xfd, 0x18, 0xfe, 0x24, 0xf7, 0x13, 0xed, 0x1e, + 0x09, 0xff, 0xd8, 0xf4, 0x12, 0xf8, 0x04, 0x0c, 0x1c, 0x11, 0xfd, 0x17, + 0x1d, 0x01, 0x13, 0xee, 0x11, 0xf3, 0xf8, 0x06, 0xf6, 0x16, 0xfe, 0x15, + 0x16, 0xdc, 0x1f, 0x00, 0x25, 0xee, 0xff, 0xf7, 0xf6, 0x02, 0xdd, 0x15, + 0xf1, 0x14, 0x08, 0xe8, 0xe5, 0x21, 0xea, 0xf0, 0x1a, 0x07, 0xea, 0x08, + 0xea, 0xe4, 0x1e, 0x00, 0x13, 0x17, 0xec, 0x11, 0xd6, 0x11, 0x18, 0x17, + 0x04, 0x15, 0x03, 0x3a, 0xd6, 0x02, 0x07, 0x04, 0xe6, 0xe5, 0xfe, 0x0e, + 0xff, 0xed, 0xfc, 0xfb, 0xff, 0x1c, 0x06, 0x0a, 0xfb, 0xf9, 0xea, 0x1a, + 0x21, 0xf5, 0x04, 0x06, 0x0a, 0xe3, 0x16, 0xea, 0x04, 0xe2, 0xf9, 0xf9, + 0xe6, 0xfb, 0x0f, 0xfc, 0x06, 0xfb, 0x10, 0x07, 0x07, 0x13, 0x07, 0xfc, + 0x16, 0xef, 0x07, 0xdc, 0x12, 0x1f, 0x08, 0xf4, 0xe9, 0x14, 0x06, 0xf7, + 0xf1, 0x0c, 0x01, 0x0c, 0xe6, 0x04, 0xf3, 0xf2, 0xe5, 0xf3, 0xef, 0x1d, + 0xf6, 0x20, 0x07, 0xfe, 0xf4, 0x05, 0xee, 0x10, 0xfd, 0x0e, 0x0b, 0x02, + 0x0d, 0xd8, 0x07, 0xfb, 0x26, 0x0a, 0x1c, 0x21, 0x06, 0x1f, 0xf4, 0x06, + 0x37, 0x18, 0xfa, 0x16, 0x1e, 0x24, 0xfb, 0xf0, 0x12, 0xf9, 0x02, 0x09, + 0x17, 0x16, 0xf3, 0xf9, 0x17, 0xf2, 0x02, 0x0a, 0x2d, 0xe7, 0xe3, 0x25, + 0xf0, 0xf9, 0x0f, 0xdd, 0x15, 0xe6, 0x04, 0xfc, 0xf1, 0x17, 0x0a, 0xea, + 0x24, 0x07, 0xf1, 0x11, 0x13, 0x29, 0xf4, 0xc5, 0xfb, 0x07, 0xef, 0x13, + 0x0b, 0xe1, 0xf1, 0xeb, 0xf8, 0x1b, 0x09, 0x08, 0x1f, 0x15, 0xf2, 0x05, + 0x02, 0xdd, 0x09, 0x0f, 0x16, 0x10, 0x01, 0x30, 0xf2, 0xe0, 0x27, 0xfe, + 0xf1, 0x0e, 0x0e, 0x07, 0xe6, 0x07, 0x0b, 0x18, 0xfe, 0x0f, 0x01, 0x07, + 0xf4, 0x07, 0x10, 0xe7, 0xfb, 0xf3, 0xf7, 0x0b, 0xf9, 0x15, 0x18, 0x25, + 0x0c, 0x14, 0x02, 0x08, 0x0a, 0x0f, 0x10, 0xec, 0xee, 0x1a, 0x03, 0x14, + 0x0f, 0xfa, 0x25, 0xff, 0x18, 0x0d, 0x0b, 0xea, 0x1f, 0x28, 0x10, 0x0c, + 0xe7, 0xee, 0xf7, 0xfa, 0x03, 0x15, 0x0c, 0x1d, 0x01, 0x00, 0x12, 0xee, + 0x01, 0xf1, 0xf8, 0x0b, 0xf3, 0xfd, 0x04, 0xf8, 0x02, 0x1e, 0x0e, 0xf3, + 0x02, 0x10, 0xfd, 0x07, 0x0b, 0x09, 0x03, 0x10, 0x3e, 0x08, 0x0e, 0x0c, + 0xf4, 0xe7, 0xfd, 0x1c, 0x27, 0x1a, 0xed, 0xe1, 0x08, 0xdc, 0xd9, 0xf1, + 0x1e, 0x07, 0x12, 0xf1, 0x10, 0xfb, 0xc8, 0x08, 0x0f, 0x03, 0x1d, 0xdc, + 0x23, 0x04, 0xf9, 0x0a, 0xff, 0x08, 0x0e, 0xc9, 0x39, 0x0a, 0x01, 0x07, + 0xec, 0xe0, 0x05, 0xe8, 0x14, 0xd8, 0xe1, 0xfa, 0xd6, 0xf8, 0xed, 0xdb, + 0xff, 0x1d, 0xf5, 0x17, 0x0f, 0x1c, 0xdc, 0xed, 0xff, 0xff, 0x04, 0x13, + 0xf5, 0xe7, 0xd2, 0x12, 0xdb, 0xe1, 0x13, 0x11, 0x23, 0x0e, 0xf9, 0x31, + 0xdc, 0xef, 0x07, 0x0a, 0x20, 0xf2, 0xf9, 0x13, 0xff, 0x1c, 0x2a, 0xdf, + 0xdb, 0xe7, 0x11, 0xf2, 0xfd, 0xfb, 0x28, 0x00, 0x15, 0x03, 0x02, 0x20, + 0x07, 0xf7, 0x19, 0x13, 0x13, 0xf6, 0x09, 0xfe, 0xfd, 0x20, 0x14, 0xf5, + 0xf5, 0xfc, 0x14, 0x0e, 0x17, 0xfe, 0x15, 0x04, 0xf9, 0xf6, 0x1d, 0xf6, + 0x1b, 0xe4, 0xee, 0xfd, 0x00, 0xe9, 0xee, 0xce, 0x0f, 0x20, 0x05, 0x02, + 0x0d, 0x06, 0x05, 0xf8, 0xef, 0xdf, 0x16, 0x17, 0xe6, 0xf1, 0x10, 0xf3, + 0x06, 0x04, 0xdb, 0xfb, 0xe7, 0xf8, 0x02, 0x11, 0xff, 0x0d, 0x0a, 0xfa, + 0x27, 0x0a, 0xfc, 0xe8, 0x11, 0x17, 0xf0, 0x0d, 0x0d, 0xee, 0xdf, 0xdd, + 0xf1, 0x15, 0xd6, 0xf7, 0x00, 0xef, 0x2e, 0xe6, 0x24, 0xfd, 0xd5, 0x04, + 0xf0, 0x08, 0x08, 0xed, 0x22, 0x07, 0xe1, 0x09, 0xd0, 0x0b, 0x18, 0xe6, + 0x3f, 0x0a, 0xe5, 0xe2, 0xf9, 0x08, 0x02, 0xd6, 0x13, 0x15, 0xbd, 0x00, + 0x0e, 0xf8, 0xe2, 0xca, 0xec, 0x0e, 0xe6, 0xef, 0x15, 0x11, 0xcb, 0xdf, + 0xf9, 0x03, 0x22, 0x10, 0xfb, 0xf9, 0xe5, 0x08, 0xe1, 0x11, 0x10, 0xfc, + 0xfa, 0x00, 0xf8, 0x30, 0xe5, 0x08, 0x14, 0xe8, 0x12, 0xe2, 0x04, 0x19, + 0x0b, 0xfa, 0x33, 0xf3, 0xec, 0xfe, 0xf8, 0x25, 0xf8, 0x21, 0x28, 0xef, + 0x00, 0xde, 0xff, 0x2b, 0x03, 0xfc, 0x10, 0x0c, 0xcf, 0xfd, 0x19, 0x0a, + 0x0c, 0xf2, 0xf7, 0x0c, 0xfd, 0x02, 0x1c, 0xdf, 0x26, 0x0d, 0xf0, 0x0b, + 0xce, 0x15, 0xfb, 0xec, 0x27, 0xf6, 0xf9, 0xe5, 0xe2, 0xfb, 0xfd, 0xd8, + 0x28, 0xec, 0xe9, 0xf2, 0xca, 0x09, 0x02, 0x06, 0x0c, 0xfa, 0x05, 0x01, + 0xd5, 0x0a, 0x02, 0xfb, 0x04, 0x17, 0xdd, 0xfe, 0xeb, 0xf1, 0x09, 0x10, + 0x12, 0xff, 0x00, 0xe0, 0x26, 0xf7, 0xed, 0xf4, 0x00, 0xf2, 0xfa, 0x07, + 0x02, 0xf5, 0x06, 0xe8, 0x03, 0xfd, 0xdc, 0xf2, 0xc2, 0xff, 0x0b, 0xd6, + 0x25, 0x04, 0xe9, 0xf0, 0xd9, 0x08, 0x09, 0xc5, 0x23, 0x12, 0xf6, 0x13, + 0x11, 0xf3, 0x18, 0xf0, 0x34, 0xfe, 0xfe, 0xed, 0xea, 0x02, 0x17, 0xdc, + 0x1b, 0x1b, 0xea, 0xfe, 0xea, 0xfe, 0xf2, 0xc4, 0xfd, 0x04, 0xe9, 0x0d, + 0x0d, 0x09, 0xca, 0xd4, 0xe1, 0x04, 0x1e, 0xff, 0x0f, 0xef, 0xd6, 0x0f, + 0xd5, 0xf8, 0x26, 0xd6, 0x33, 0xe8, 0xf5, 0x3b, 0xf1, 0xe8, 0x39, 0xe8, + 0x08, 0xe5, 0x01, 0x02, 0x04, 0xf6, 0x19, 0x0a, 0xd0, 0xeb, 0x0b, 0x15, + 0xf7, 0x0e, 0x23, 0xf6, 0xf4, 0xd8, 0xf4, 0x17, 0x23, 0x25, 0x14, 0x01, + 0xd7, 0xfd, 0xf9, 0x1f, 0x1b, 0x11, 0x0a, 0x18, 0xf5, 0xf5, 0x0f, 0xe0, + 0x2e, 0x01, 0xe5, 0xdb, 0xe2, 0xf2, 0x14, 0xfa, 0x2a, 0x00, 0xe2, 0xea, + 0xfd, 0x0e, 0xfc, 0xc1, 0x35, 0x08, 0xf6, 0xf9, 0xec, 0x00, 0x06, 0x00, + 0x0b, 0xf6, 0x01, 0xfe, 0xea, 0x0b, 0x08, 0x05, 0xe4, 0xea, 0xd7, 0xfd, + 0xee, 0xf3, 0x0c, 0x0c, 0x0d, 0x02, 0xfd, 0xee, 0x17, 0x10, 0x13, 0xfd, + 0x07, 0x03, 0xf8, 0x0c, 0xd4, 0xed, 0xfe, 0x07, 0xf4, 0xee, 0xf4, 0x03, + 0xc2, 0x18, 0x2c, 0xd1, 0x33, 0xd8, 0xdb, 0xfa, 0xed, 0x10, 0x1c, 0xe3, + 0x37, 0x0a, 0xea, 0xfe, 0xf6, 0xef, 0x20, 0xed, 0x32, 0xf7, 0xf5, 0xf3, + 0xca, 0xfd, 0x0a, 0xcf, 0x0d, 0x10, 0xde, 0x07, 0x18, 0x10, 0xf0, 0xd6, + 0x0c, 0x04, 0xeb, 0x1a, 0xf9, 0x08, 0xc4, 0xcb, 0xe4, 0x0b, 0x19, 0xfc, + 0x29, 0xf6, 0xec, 0x07, 0xf3, 0xed, 0x2b, 0xe9, 0xfa, 0x02, 0xec, 0x2b, + 0xf0, 0xf2, 0x2d, 0xe8, 0xed, 0x00, 0x12, 0x13, 0xed, 0x1a, 0x3d, 0xf0, + 0x05, 0x04, 0xfc, 0x13, 0x10, 0x01, 0x40, 0xf2, 0x06, 0x02, 0xf9, 0x22, + 0x24, 0xff, 0x18, 0x00, 0xeb, 0xe8, 0x14, 0xf9, 0x25, 0xe0, 0xff, 0x03, + 0xe5, 0xfd, 0x08, 0xea, 0x2e, 0x0b, 0x05, 0xe7, 0xde, 0xe4, 0xf5, 0xea, + 0x3a, 0xf4, 0xf4, 0xe7, 0xed, 0xec, 0xf8, 0xee, 0x30, 0x0a, 0xdb, 0x05, + 0xf7, 0x16, 0xff, 0xf7, 0xfa, 0x1f, 0xef, 0xe4, 0xce, 0xf8, 0x13, 0x04, + 0xf9, 0x01, 0xe1, 0x03, 0xf9, 0xf9, 0x08, 0x04, 0xfa, 0xe4, 0xe7, 0xf7, + 0x28, 0xfd, 0xfd, 0x00, 0xfc, 0xfb, 0xef, 0x0a, 0xec, 0x0c, 0x0a, 0xd2, + 0x05, 0xfb, 0xcd, 0xfb, 0x9d, 0xea, 0x1c, 0xe5, 0x25, 0xe8, 0xea, 0x0b, + 0xf0, 0xf3, 0x0d, 0xab, 0x49, 0x0e, 0xeb, 0x00, 0xe2, 0x03, 0x29, 0xe0, + 0x3d, 0x06, 0xf7, 0xf8, 0xcf, 0x0c, 0x1a, 0xd6, 0x1f, 0xef, 0xfd, 0xff, + 0xef, 0x0c, 0xdb, 0xe0, 0x20, 0x06, 0xdf, 0x1a, 0xe7, 0xfc, 0xb2, 0xd1, + 0xdf, 0x13, 0x07, 0x1f, 0x0c, 0xf7, 0xde, 0x0a, 0xdb, 0xdf, 0x1a, 0xf5, + 0x29, 0x0d, 0xeb, 0x2c, 0xcf, 0x0e, 0x26, 0xfe, 0xef, 0x04, 0xf5, 0x14, + 0x09, 0x13, 0x34, 0xff, 0xfe, 0x0e, 0x06, 0x0e, 0x10, 0xf9, 0x2a, 0x0b, + 0xe6, 0xfe, 0xf1, 0x1a, 0x36, 0x29, 0x29, 0x05, 0x05, 0xd8, 0x14, 0x12, + 0x26, 0x0b, 0x18, 0xff, 0xd7, 0xdf, 0x0f, 0xed, 0x31, 0xf7, 0xfc, 0xec, + 0x0b, 0xef, 0x0c, 0xd2, 0x30, 0xf9, 0x04, 0xfe, 0xef, 0xe4, 0xfb, 0xd1, + 0x32, 0xe5, 0xee, 0xf0, 0x0c, 0xe6, 0x13, 0xed, 0x1e, 0x0b, 0xe4, 0xe0, + 0xfa, 0xf4, 0x14, 0xf4, 0x18, 0xf7, 0xd9, 0xf6, 0xed, 0xea, 0xfc, 0x06, + 0xfc, 0xf5, 0xed, 0xeb, 0x05, 0x03, 0x1b, 0x0b, 0xff, 0x0b, 0xef, 0x01, + 0xf1, 0x16, 0x05, 0x00, 0xee, 0x0a, 0xdb, 0x10, 0xb4, 0x14, 0x0f, 0xe1, + 0x1c, 0xfd, 0xf0, 0xf8, 0xc3, 0x11, 0x17, 0xba, 0x47, 0x15, 0xe6, 0x01, + 0xea, 0xf1, 0x0c, 0x08, 0x4a, 0x15, 0xf0, 0xf7, 0xea, 0x00, 0xf5, 0xd4, + 0xf1, 0xff, 0xe0, 0x0c, 0xf4, 0x17, 0xd8, 0xea, 0x03, 0xff, 0xd5, 0x18, + 0xfb, 0x07, 0xc7, 0xc9, 0xdd, 0xf3, 0x15, 0x0d, 0x22, 0xea, 0xdb, 0x0a, + 0xd6, 0x09, 0x1d, 0xe5, 0x2d, 0x04, 0xfc, 0x35, 0xc6, 0x0e, 0x33, 0xf1, + 0xd7, 0xea, 0x01, 0x1b, 0x0e, 0x01, 0x2a, 0xff, 0xef, 0xf1, 0xf7, 0x0f, + 0xff, 0x00, 0x3b, 0xe8, 0x0a, 0xff, 0xf4, 0x0d, 0x1f, 0x04, 0x17, 0xf7, + 0xdf, 0xec, 0x12, 0x26, 0x36, 0x07, 0x0c, 0x06, 0xe7, 0xd6, 0x13, 0xe3, + 0x30, 0x09, 0x00, 0xf5, 0xe0, 0xf3, 0x11, 0xe2, 0x38, 0x0d, 0xf6, 0x05, + 0xec, 0x05, 0x00, 0xe5, 0x24, 0xef, 0xfe, 0xf8, 0x00, 0xd8, 0x18, 0xf1, + 0x26, 0x0b, 0xf2, 0xfc, 0xe0, 0xe4, 0x06, 0x0b, 0x1a, 0x05, 0xc6, 0xf6, + 0xe8, 0xde, 0xfe, 0x0c, 0x03, 0x09, 0xfe, 0xe2, 0x18, 0x1b, 0xfb, 0xf7, + 0x06, 0xf1, 0xfe, 0xf6, 0xef, 0x1b, 0x07, 0x0d, 0x01, 0x0a, 0xed, 0xf0, + 0xad, 0x1a, 0x17, 0xd6, 0x37, 0xfd, 0xd8, 0xec, 0xca, 0xf1, 0x15, 0xc4, + 0x33, 0xf1, 0xed, 0xf0, 0xe9, 0x15, 0x0d, 0xf2, 0x36, 0xde, 0xfd, 0x0e, + 0xfb, 0x10, 0x0f, 0xf6, 0xf9, 0x0c, 0xea, 0xf0, 0xe5, 0x0b, 0xee, 0xc1, + 0x10, 0xf4, 0xe8, 0x1f, 0xee, 0x00, 0xd0, 0xe4, 0xe7, 0x13, 0x07, 0x27, + 0x12, 0xea, 0xea, 0x0f, 0xea, 0xf4, 0x14, 0xee, 0xfe, 0x09, 0xfb, 0x31, + 0xdb, 0x1b, 0x1c, 0xe7, 0xef, 0xf5, 0xf7, 0x1a, 0x06, 0x01, 0x2c, 0xed, + 0xfb, 0x04, 0xfa, 0x07, 0x19, 0xec, 0x2b, 0x0d, 0xfc, 0xd8, 0xfc, 0x0f, + 0x1f, 0xfc, 0x2d, 0xf3, 0xc9, 0xda, 0x0a, 0xfe, 0x29, 0x00, 0xfa, 0x09, + 0xe8, 0xf6, 0x21, 0xf3, 0x4a, 0x1a, 0xf8, 0x00, 0xe7, 0xf0, 0x21, 0x01, + 0x22, 0xf3, 0x00, 0xe9, 0x06, 0xe3, 0x15, 0xd7, 0x3d, 0x0c, 0x07, 0xf1, + 0xf3, 0xec, 0x17, 0xdf, 0x29, 0x1b, 0xfd, 0xfe, 0xeb, 0xed, 0x17, 0xf6, + 0x23, 0x0a, 0xea, 0xee, 0xf9, 0xf3, 0x0f, 0x0c, 0xf8, 0xf5, 0xed, 0xe8, + 0x1c, 0x14, 0x07, 0x17, 0x0b, 0x0d, 0xed, 0xf7, 0xed, 0x10, 0x07, 0xd5, + 0xf2, 0x09, 0xd6, 0xf7, 0xb5, 0xf6, 0x19, 0xc9, 0x25, 0x15, 0xe8, 0xf5, + 0xc4, 0xf9, 0x2a, 0xb0, 0x39, 0x0e, 0x02, 0x11, 0xf0, 0xf7, 0x1d, 0xeb, + 0x39, 0x10, 0x02, 0x15, 0xe0, 0x08, 0x01, 0xee, 0x1c, 0x1e, 0x08, 0x04, + 0xf2, 0x02, 0xe8, 0xda, 0xfa, 0xfb, 0xe0, 0xfe, 0x05, 0x02, 0xd3, 0xca, + 0xf4, 0xec, 0x10, 0x16, 0x05, 0x0d, 0xd7, 0x09, 0xdc, 0xf6, 0x1e, 0xf8, + 0x10, 0xed, 0xf7, 0x27, 0xf5, 0x08, 0x28, 0xee, 0xec, 0xe0, 0xf8, 0x17, + 0xfb, 0x23, 0x2e, 0xf1, 0xfa, 0xf5, 0xfc, 0x1a, 0x10, 0xf7, 0x32, 0xfb, + 0xfb, 0xe8, 0xf1, 0x03, 0x24, 0xeb, 0x25, 0xf9, 0xca, 0xf1, 0xfe, 0x01, + 0x2e, 0x07, 0x18, 0x03, 0xe5, 0xea, 0x10, 0xfa, 0x3b, 0x07, 0x0f, 0x11, + 0x04, 0xf7, 0x1d, 0xf1, 0x24, 0xd9, 0x08, 0xef, 0x02, 0xdd, 0x07, 0xc8, + 0x2c, 0x0d, 0x06, 0xec, 0x17, 0xda, 0x21, 0xdf, 0x34, 0xd9, 0xfb, 0xf2, + 0xf4, 0xec, 0x0e, 0x0a, 0x0f, 0x0f, 0xdb, 0xf0, 0xfb, 0xe6, 0x0f, 0x00, + 0x04, 0xf9, 0x01, 0x05, 0x05, 0xfe, 0x08, 0xf3, 0x0e, 0xf2, 0xfb, 0x01, + 0xfd, 0x18, 0x1d, 0xf6, 0xee, 0x06, 0xcf, 0xfc, 0xae, 0x27, 0x21, 0xd2, + 0x33, 0x03, 0xe0, 0xe0, 0xc9, 0xfb, 0x3a, 0xbd, 0x4d, 0x04, 0xe8, 0xf5, + 0xe6, 0xeb, 0x19, 0xf2, 0x4b, 0x1d, 0xfc, 0xf7, 0xd9, 0xff, 0xfe, 0xea, + 0x0f, 0x04, 0x0e, 0x00, 0xed, 0x19, 0xe9, 0xe9, 0xff, 0x11, 0xef, 0x14, + 0x01, 0x17, 0xbc, 0xb5, 0xef, 0x0c, 0x22, 0x27, 0x0f, 0x01, 0xd4, 0x03, + 0xce, 0x01, 0x25, 0xff, 0xf9, 0xf0, 0x0a, 0x1c, 0xe5, 0x0f, 0x1c, 0xee, + 0xf4, 0xf1, 0xf4, 0x0c, 0x00, 0x08, 0x1c, 0xf4, 0xd5, 0xf1, 0xfc, 0x1f, + 0x11, 0x00, 0x18, 0x03, 0xf7, 0xe4, 0xff, 0x07, 0x09, 0x1a, 0x18, 0xff, + 0xea, 0xec, 0xfd, 0x13, 0x2b, 0xf8, 0x0c, 0xfa, 0xdf, 0xf6, 0x11, 0xda, + 0x2a, 0xdc, 0xfc, 0xff, 0xff, 0xec, 0x12, 0xe1, 0x37, 0xfd, 0xeb, 0xfe, + 0xea, 0xd1, 0x12, 0xfa, 0x28, 0x1a, 0x0d, 0xf0, 0xf7, 0xe0, 0x0c, 0xeb, + 0x35, 0x14, 0xeb, 0x00, 0xeb, 0xe7, 0x1b, 0xfc, 0x09, 0x00, 0xf2, 0x04, + 0xf9, 0xe5, 0x1a, 0x0e, 0x08, 0x12, 0xf8, 0xfe, 0x09, 0x0f, 0x0d, 0xea, + 0x03, 0xe1, 0xfe, 0xf2, 0xec, 0x0d, 0x02, 0xdb, 0x04, 0x1d, 0xd4, 0x01, + 0xca, 0x13, 0x29, 0xca, 0x28, 0x04, 0xe2, 0xf1, 0xdb, 0x0b, 0x2c, 0xcd, + 0x44, 0x00, 0xe7, 0xf4, 0xd0, 0x12, 0x15, 0xff, 0x42, 0x11, 0x05, 0xfd, + 0xd9, 0x11, 0x1c, 0xf4, 0x15, 0xec, 0xf2, 0x24, 0xd6, 0x1d, 0xec, 0xda, + 0xf5, 0xec, 0xe5, 0x22, 0xf2, 0x0b, 0xbd, 0xd0, 0xeb, 0x05, 0x07, 0x1b, + 0x01, 0xed, 0xf5, 0x02, 0xcf, 0x08, 0x15, 0xfd, 0x1c, 0xe5, 0x04, 0x19, + 0xc7, 0x25, 0x22, 0xf3, 0xde, 0xfb, 0xfb, 0x20, 0xf6, 0xeb, 0x25, 0xfe, + 0xf5, 0x08, 0xf5, 0x17, 0x0e, 0x04, 0x1c, 0xf9, 0xee, 0xec, 0xe1, 0x06, + 0x12, 0xff, 0x2a, 0x13, 0xed, 0xfe, 0x05, 0x18, 0x25, 0x20, 0x09, 0x13, + 0xea, 0xd7, 0x05, 0x06, 0x33, 0x25, 0xff, 0x0a, 0xf0, 0xea, 0x17, 0xe1, + 0x30, 0xfa, 0x0d, 0x0a, 0x04, 0x00, 0x0e, 0xe9, 0x16, 0x20, 0x0d, 0x02, + 0xe8, 0xed, 0x07, 0xe8, 0x3c, 0xf1, 0xd9, 0xfa, 0xe1, 0xed, 0x18, 0xfc, + 0xf0, 0x09, 0xe3, 0x05, 0xfe, 0xd1, 0x0b, 0x0e, 0xf5, 0x25, 0xfd, 0xfb, + 0x30, 0x1e, 0x08, 0xfc, 0x0c, 0x21, 0xea, 0xfc, 0xe5, 0x1e, 0x16, 0xf5, + 0xf4, 0xfc, 0xf0, 0xea, 0xc4, 0x21, 0x27, 0xe9, 0x2b, 0xdb, 0xdb, 0xec, + 0xe5, 0xfe, 0x37, 0xe2, 0x46, 0x25, 0xfa, 0xec, 0xe4, 0xf3, 0x19, 0xf2, + 0x4c, 0x06, 0x00, 0xfb, 0xeb, 0x10, 0x10, 0xf7, 0x2a, 0xf8, 0xe9, 0x18, + 0xee, 0x21, 0xe8, 0xd5, 0xf4, 0x0a, 0xed, 0x24, 0xfe, 0xf9, 0xb2, 0xbc, + 0xf3, 0x1d, 0x00, 0x2f, 0x07, 0x08, 0xe1, 0xf1, 0xed, 0x27, 0x27, 0xfe, + 0x22, 0xfd, 0x02, 0x20, 0xd8, 0x05, 0x25, 0xec, 0xf1, 0xff, 0x0a, 0x0f, + 0xe6, 0xfe, 0x46, 0xfd, 0xe1, 0xca, 0xf7, 0x22, 0x03, 0x08, 0x21, 0xf5, + 0x0f, 0xf7, 0xfb, 0x0c, 0xfb, 0x14, 0x2d, 0x03, 0xe5, 0xe4, 0x09, 0x0b, + 0x1a, 0xe6, 0x01, 0x28, 0xe9, 0xd6, 0x0b, 0xf7, 0x2c, 0xfb, 0x11, 0xee, + 0x0b, 0xed, 0x17, 0xf0, 0x3c, 0xf5, 0x08, 0xfa, 0xf8, 0xcd, 0x17, 0xfa, + 0x39, 0xea, 0x11, 0xf5, 0xed, 0xee, 0x0a, 0xec, 0x41, 0xd6, 0xe7, 0xf9, + 0xfa, 0xc8, 0x15, 0xf7, 0x08, 0x0e, 0xe3, 0x08, 0xe8, 0xec, 0xfd, 0xfe, + 0xf1, 0x00, 0xe9, 0xf4, 0x09, 0x26, 0x02, 0x16, 0xf0, 0x01, 0xef, 0x01, + 0xff, 0x03, 0x22, 0xdb, 0xfc, 0xf5, 0xde, 0xe5, 0xc4, 0x01, 0x28, 0xd4, + 0x38, 0x08, 0xd0, 0xec, 0xd5, 0x04, 0x2f, 0xce, 0x4e, 0xeb, 0xf9, 0xe7, + 0xdf, 0xf0, 0x1b, 0xf5, 0x42, 0xf1, 0xf6, 0x09, 0xd5, 0x0a, 0x0d, 0x08, + 0x04, 0x05, 0xe2, 0x0e, 0xd7, 0x19, 0xdb, 0xda, 0xe1, 0x25, 0xde, 0x15, + 0x0e, 0x14, 0xbd, 0xb0, 0xe3, 0xe5, 0x24, 0x1e, 0xf8, 0x0d, 0xd8, 0xf7, + 0xf2, 0xff, 0x18, 0xf5, 0x07, 0xf0, 0x02, 0x25, 0xd5, 0x1e, 0x2e, 0xdf, + 0xe7, 0x05, 0xef, 0x11, 0xe8, 0xe7, 0x47, 0xf4, 0xe1, 0xde, 0x09, 0x36, + 0x1a, 0x11, 0x11, 0xf5, 0x12, 0xe5, 0xe7, 0x18, 0x01, 0x17, 0x2a, 0x03, + 0x05, 0xea, 0x09, 0x0b, 0x12, 0x04, 0x17, 0xf0, 0xee, 0xd7, 0x11, 0xed, + 0x3c, 0x17, 0x16, 0xff, 0x02, 0xdc, 0x21, 0xf3, 0x2e, 0xe5, 0x13, 0xef, + 0xec, 0xe2, 0x10, 0xd0, 0x2e, 0xee, 0xff, 0x01, 0xe0, 0xe5, 0x0b, 0xda, + 0x1f, 0xf8, 0xf6, 0xfb, 0x07, 0xdb, 0x05, 0xf6, 0x0c, 0xf3, 0xf0, 0x10, + 0xf9, 0xf5, 0xf2, 0x0d, 0x10, 0xf7, 0xf6, 0xff, 0x2b, 0x0d, 0x06, 0x1e, + 0xf3, 0x0c, 0xe9, 0x01, 0xf2, 0x23, 0xfe, 0xe9, 0xdd, 0x12, 0xdd, 0xf7, + 0xbb, 0x22, 0x1b, 0xd4, 0x38, 0x29, 0xd4, 0xcf, 0xf5, 0xf9, 0x27, 0xdd, + 0x47, 0x00, 0xf2, 0xe5, 0x09, 0xfc, 0x0e, 0xf9, 0x34, 0x0a, 0x02, 0xfd, + 0xec, 0x25, 0x1d, 0x03, 0x15, 0x09, 0xf1, 0x1b, 0xd0, 0x17, 0xda, 0xda, + 0xe7, 0x07, 0xe3, 0x15, 0xf1, 0x02, 0xb9, 0xce, 0xe6, 0x0c, 0x10, 0x31, + 0xfe, 0xf7, 0xd9, 0xfa, 0xed, 0xed, 0x33, 0xf4, 0x19, 0xe7, 0xfe, 0x3f, + 0xe5, 0x06, 0x2e, 0xe6, 0xf2, 0xdc, 0xf5, 0x18, 0xe6, 0x01, 0x2f, 0xee, + 0xe7, 0xe4, 0xfe, 0x2c, 0x03, 0xf7, 0x20, 0x05, 0x07, 0xe2, 0x06, 0x1e, + 0x05, 0xed, 0x2f, 0x03, 0xea, 0xf8, 0x0e, 0x0c, 0x1f, 0xff, 0x20, 0xf4, + 0xe8, 0xe1, 0x1c, 0xec, 0x22, 0x1e, 0x05, 0xfd, 0xf5, 0xca, 0x30, 0xe9, + 0x30, 0xe4, 0x14, 0xff, 0xf2, 0xdc, 0x17, 0xf8, 0x26, 0xe1, 0x0b, 0x01, + 0x11, 0xc2, 0x02, 0xf1, 0x36, 0x10, 0x02, 0x05, 0xed, 0xf1, 0x15, 0xfa, + 0x17, 0xf8, 0xf7, 0xf1, 0xe8, 0xd3, 0xfd, 0x08, 0xfb, 0x27, 0xf5, 0xf5, + 0x13, 0x06, 0x0b, 0xf0, 0x01, 0xf9, 0xd7, 0x0e, 0xec, 0x12, 0xfe, 0xfd, + 0xee, 0x25, 0xd8, 0xf1, 0xb2, 0x09, 0x1c, 0xbf, 0x34, 0xea, 0xc8, 0xea, + 0xdb, 0x0e, 0x24, 0xde, 0x47, 0xfe, 0xdc, 0xe0, 0xf3, 0x06, 0x20, 0xfe, + 0x2b, 0xf6, 0x18, 0x14, 0xcd, 0x19, 0x16, 0xfe, 0x1a, 0x15, 0xf8, 0x11, + 0xf4, 0x22, 0xd7, 0xcc, 0xdd, 0x15, 0xdc, 0x14, 0xf9, 0x02, 0xbb, 0xca, + 0xe3, 0xf3, 0x0d, 0x1e, 0x2a, 0x0c, 0xe4, 0x05, 0xe0, 0x18, 0x2a, 0x07, + 0x20, 0xed, 0xf6, 0x17, 0xcf, 0xf4, 0x2a, 0xd6, 0xfb, 0xce, 0x03, 0x37, + 0xe2, 0xfd, 0x1d, 0xfb, 0xe5, 0xe0, 0x05, 0x29, 0xef, 0x16, 0x23, 0xf7, + 0x01, 0xf4, 0x0c, 0x14, 0xff, 0xee, 0x31, 0xf9, 0x12, 0xf9, 0x14, 0xf6, + 0x0c, 0xf6, 0x0b, 0x0f, 0xd8, 0xdc, 0xfe, 0x0f, 0x37, 0xfa, 0x01, 0x09, + 0x04, 0xd1, 0x0b, 0x0c, 0x29, 0xf3, 0x0a, 0xf9, 0xed, 0xc2, 0x18, 0xf4, + 0x25, 0x18, 0x0f, 0x08, 0xf7, 0xed, 0x1f, 0xf7, 0x4f, 0x0e, 0xf0, 0xe4, + 0x00, 0xeb, 0xfa, 0x1a, 0x0c, 0x03, 0xe9, 0xfc, 0xf0, 0xcc, 0x06, 0x05, + 0xf2, 0x12, 0x04, 0xe2, 0x16, 0x0a, 0x0a, 0xf3, 0x0b, 0xf3, 0xdc, 0xfd, + 0x10, 0xfc, 0x0e, 0xe2, 0xe0, 0xfe, 0xf0, 0xff, 0xb1, 0x06, 0x1b, 0xe4, + 0x30, 0x13, 0xc6, 0xc3, 0xfa, 0x0c, 0x1e, 0xd9, 0x57, 0x11, 0xe1, 0xd6, + 0xfa, 0xee, 0x1d, 0xf7, 0x37, 0xea, 0xf0, 0x05, 0xef, 0x24, 0x1e, 0xf1, + 0x10, 0xe8, 0xeb, 0x19, 0xd1, 0x18, 0xf5, 0xc8, 0xf8, 0xec, 0xf5, 0x1f, + 0xf2, 0xff, 0xb3, 0xd2, 0xe6, 0x0e, 0x06, 0x2e, 0x07, 0x17, 0xe0, 0xf5, + 0x02, 0xf9, 0x20, 0x07, 0x16, 0x08, 0xe8, 0x1d, 0xd3, 0x08, 0x34, 0xda, + 0xf2, 0xce, 0xfb, 0x1f, 0xe1, 0x00, 0x2d, 0xdb, 0xdf, 0xcc, 0x05, 0xfb, + 0xf7, 0x00, 0x33, 0xf9, 0x0b, 0x01, 0x13, 0x28, 0xf8, 0x07, 0x24, 0xf8, + 0x0f, 0x03, 0x0d, 0xe9, 0x06, 0xfe, 0x18, 0xf9, 0xed, 0xf5, 0x0c, 0xe0, + 0x2c, 0x0e, 0xf9, 0x06, 0xfb, 0xce, 0x27, 0xe8, 0x29, 0x19, 0xf9, 0x01, + 0x0e, 0xc8, 0x25, 0xed, 0x30, 0xeb, 0x01, 0xfe, 0x10, 0xdc, 0x1e, 0x00, + 0x1e, 0x10, 0xf9, 0x00, 0xfc, 0xc8, 0x0e, 0x04, 0x13, 0x04, 0xf0, 0x02, + 0xfe, 0xd8, 0x0f, 0x1b, 0xf7, 0xe1, 0xf8, 0xde, 0x12, 0xe2, 0xef, 0x0a, + 0x02, 0xe0, 0xdd, 0xf1, 0x0e, 0x2a, 0x25, 0x15, 0xeb, 0x02, 0xf4, 0xf0, + 0xbf, 0xfc, 0x27, 0xdc, 0x42, 0x0f, 0xe9, 0xbf, 0xe8, 0x20, 0x33, 0xc9, + 0x3f, 0x10, 0xec, 0xf3, 0x03, 0x02, 0x2c, 0x04, 0x38, 0x06, 0x0a, 0xf9, + 0xe5, 0x1c, 0x3f, 0x0f, 0x0c, 0x25, 0xe2, 0x06, 0xe6, 0x03, 0xf4, 0xd7, + 0xfe, 0xf6, 0xe7, 0x2f, 0xfa, 0x03, 0xb6, 0xcb, 0xf1, 0x11, 0x0a, 0x2c, + 0xfc, 0x1e, 0xe0, 0xff, 0xc2, 0xdd, 0x1d, 0xf3, 0x10, 0xfa, 0x07, 0x1e, + 0xf6, 0x20, 0x07, 0xe6, 0xf1, 0x0a, 0xe8, 0x27, 0xf1, 0xf5, 0x24, 0xed, + 0xfd, 0xee, 0x13, 0x15, 0xe9, 0xe2, 0x22, 0xe5, 0xf9, 0xdd, 0x1d, 0x32, + 0x04, 0xfa, 0x25, 0x00, 0xee, 0xfd, 0x0b, 0x0e, 0x23, 0xfa, 0x0f, 0x01, + 0xf8, 0xf0, 0x15, 0xe4, 0x21, 0xf7, 0x10, 0xf9, 0xe7, 0xc3, 0x19, 0xe1, + 0x34, 0xff, 0xed, 0xf4, 0xef, 0xd7, 0x21, 0x01, 0x31, 0xee, 0xf7, 0xf2, + 0xf3, 0xe5, 0x0a, 0xee, 0x2e, 0x1e, 0xf2, 0x0c, 0x07, 0xc2, 0x08, 0x0a, + 0x14, 0x14, 0x00, 0xfc, 0xf9, 0xd6, 0xfb, 0xf8, 0xe5, 0xf1, 0xfa, 0xe0, + 0x15, 0x21, 0xef, 0x06, 0xf9, 0x00, 0xf5, 0xf4, 0x0b, 0x0b, 0x18, 0x02, + 0xf5, 0x04, 0xdb, 0xfd, 0xcc, 0x32, 0x1d, 0xc9, 0x3b, 0x12, 0xd9, 0xaf, + 0xcf, 0x0f, 0x26, 0xde, 0x35, 0xe4, 0xdb, 0xd3, 0x22, 0x11, 0x2e, 0xfb, + 0x36, 0xfa, 0xfd, 0x02, 0xeb, 0x0f, 0x37, 0x0b, 0x14, 0x1d, 0xdd, 0x18, + 0xe0, 0x10, 0xe0, 0xdf, 0x14, 0xf9, 0xf0, 0x19, 0xf7, 0xfb, 0xc4, 0xe5, + 0xe7, 0x11, 0x01, 0x31, 0x1a, 0xf7, 0xd8, 0xf1, 0xe9, 0xf3, 0x21, 0xf9, + 0xfe, 0xe4, 0xe9, 0x02, 0xd0, 0x06, 0x14, 0xd7, 0xfc, 0xec, 0x06, 0x10, + 0xfc, 0xf0, 0x1c, 0xe7, 0xec, 0xe3, 0x03, 0x21, 0xe4, 0x04, 0x12, 0xf0, + 0xf3, 0xed, 0x16, 0x36, 0x02, 0xfd, 0x13, 0x11, 0xdf, 0xeb, 0x19, 0x07, + 0x10, 0x0c, 0xf9, 0x08, 0xf8, 0xf4, 0x1d, 0xfd, 0x1d, 0x16, 0xf4, 0x0a, + 0x08, 0xec, 0x0c, 0x09, 0x3d, 0xe0, 0x0b, 0xee, 0x10, 0xd1, 0x1e, 0x15, + 0x43, 0xeb, 0xfa, 0xf3, 0x05, 0xc7, 0xf2, 0xd9, 0x25, 0x20, 0xee, 0xe9, + 0xfd, 0xce, 0x16, 0x0c, 0x27, 0x06, 0x0a, 0x06, 0xf9, 0xd6, 0x0b, 0x05, + 0xe8, 0x02, 0xe8, 0xd2, 0x10, 0x01, 0xf2, 0x15, 0x09, 0x04, 0xd3, 0xe2, + 0xfe, 0xf0, 0x32, 0x1b, 0xd9, 0xf5, 0xea, 0xcc, 0xcb, 0x10, 0x1c, 0xf1, + 0x3b, 0x02, 0xd4, 0xbf, 0xca, 0xfe, 0x12, 0xdb, 0x3b, 0xf8, 0xd5, 0xe7, + 0x13, 0x10, 0x1a, 0xf4, 0x38, 0x09, 0x08, 0xee, 0xf4, 0xf4, 0x3c, 0xf7, + 0x15, 0x04, 0xe4, 0xfa, 0xf4, 0x04, 0xee, 0xf4, 0x07, 0xf8, 0xe9, 0x3b, + 0xe2, 0x1f, 0xd5, 0xed, 0xe6, 0xfd, 0x18, 0x49, 0x21, 0x06, 0xd8, 0xde, + 0xfa, 0xf0, 0x1b, 0xfe, 0xde, 0x08, 0xf7, 0x14, 0xc7, 0x0f, 0x1d, 0xcf, + 0x00, 0xea, 0xff, 0x1b, 0xd5, 0x08, 0x0d, 0xd9, 0xf1, 0xf4, 0x16, 0x23, + 0xd8, 0x0c, 0x29, 0xdc, 0xf1, 0xf2, 0x21, 0x49, 0xfc, 0xe2, 0x08, 0x01, + 0xf0, 0xf8, 0x17, 0xf9, 0x0f, 0xf5, 0xfa, 0x1a, 0xef, 0xec, 0x09, 0xeb, + 0x1a, 0x0c, 0x17, 0x09, 0x11, 0xe9, 0x1a, 0xf7, 0x29, 0xf9, 0xfd, 0x07, + 0x01, 0xdd, 0x0a, 0xec, 0x22, 0x15, 0x03, 0xfd, 0xe2, 0xd2, 0x15, 0xec, + 0x4d, 0xd7, 0xfc, 0xf6, 0x0b, 0xcc, 0x0e, 0x04, 0x03, 0xf7, 0xfb, 0xfb, + 0x0d, 0xeb, 0x19, 0x07, 0xf4, 0xf4, 0xe5, 0xde, 0x22, 0x07, 0xea, 0xf7, + 0xeb, 0x23, 0xc8, 0xee, 0x03, 0x04, 0x0f, 0x19, 0xc3, 0xf8, 0x06, 0xd0, + 0xf7, 0xfe, 0x0e, 0xe7, 0x0a, 0x02, 0xb0, 0xb8, 0x00, 0xfb, 0x18, 0x0f, + 0x22, 0xf7, 0xe9, 0xdc, 0x09, 0x15, 0x23, 0x0d, 0x22, 0x13, 0xe2, 0xed, + 0xeb, 0x18, 0x20, 0x0b, 0x12, 0xfc, 0x02, 0xf1, 0xdb, 0x0e, 0xe1, 0x04, + 0xdb, 0x0f, 0xf3, 0x1a, 0x06, 0xef, 0xdb, 0xdc, 0xdd, 0xfb, 0x00, 0x2a, + 0x20, 0xfd, 0xc1, 0xe3, 0xef, 0x01, 0x14, 0xf2, 0x14, 0x00, 0x0f, 0x28, + 0xd9, 0xff, 0xf4, 0xdc, 0x09, 0xfa, 0x1c, 0x08, 0xd1, 0x03, 0x0a, 0xf4, + 0xe4, 0xdb, 0x20, 0x30, 0xea, 0x06, 0x11, 0xe2, 0x26, 0xf7, 0x16, 0x22, + 0xf9, 0x07, 0x02, 0xf5, 0xf6, 0xfb, 0x1d, 0x0c, 0x16, 0x0a, 0x07, 0xf9, + 0x11, 0xde, 0x20, 0x08, 0x19, 0x04, 0x0a, 0x0b, 0x0c, 0xf7, 0xf4, 0xfc, + 0x41, 0xf1, 0xf8, 0x16, 0x09, 0xdc, 0x0e, 0x1a, 0x2b, 0x1f, 0xe7, 0xfe, + 0x01, 0xe0, 0xfd, 0xe2, 0x34, 0xec, 0xf3, 0xf5, 0x03, 0xec, 0x0b, 0xfb, + 0x04, 0xf6, 0xdd, 0xfd, 0x06, 0x14, 0x0d, 0xfa, 0xfc, 0xf1, 0x0a, 0xca, + 0x01, 0xec, 0x0e, 0x0e, 0xec, 0xd7, 0xee, 0xd4, 0xf2, 0xfe, 0x16, 0xfa, + 0xbd, 0x0d, 0xef, 0xcb, 0xc4, 0xee, 0xed, 0x13, 0x10, 0x19, 0xf8, 0xb1, + 0xf1, 0xe3, 0x00, 0xf3, 0x0c, 0xf6, 0xde, 0xc6, 0x15, 0x27, 0x14, 0x29, + 0x15, 0xf6, 0xf4, 0xf5, 0xe7, 0x00, 0x0b, 0x2f, 0x0c, 0xef, 0x03, 0x0f, + 0xfd, 0x08, 0xf3, 0xf9, 0xf9, 0x05, 0x0d, 0x34, 0x15, 0x1b, 0xc8, 0xd1, + 0xf2, 0x1b, 0x0a, 0x22, 0x12, 0x11, 0xe9, 0xf4, 0xe1, 0x2a, 0x20, 0x03, + 0xf2, 0xf8, 0x14, 0x0b, 0xd0, 0xf4, 0x0e, 0xbf, 0xc6, 0xd8, 0x04, 0x05, + 0xf8, 0xf4, 0x04, 0xc9, 0xea, 0xfd, 0xf7, 0xfa, 0xe3, 0x1b, 0x11, 0xde, + 0x0c, 0x11, 0x25, 0x29, 0xe5, 0x02, 0xef, 0xef, 0x02, 0xfa, 0x1a, 0x21, + 0x19, 0x09, 0x08, 0x05, 0x04, 0xe5, 0xfa, 0xed, 0x2d, 0x26, 0xfa, 0x17, + 0xf6, 0xe8, 0x12, 0x12, 0x31, 0xfc, 0x0d, 0x00, 0xf7, 0xeb, 0x19, 0xf1, + 0x2a, 0x06, 0x14, 0xec, 0x08, 0xd3, 0x21, 0x07, 0x32, 0xe3, 0x02, 0x0b, + 0xfb, 0xd8, 0x27, 0x07, 0x05, 0xe6, 0xf5, 0xf5, 0x0a, 0xf7, 0x2c, 0x2a, + 0xd8, 0x1b, 0xda, 0xf7, 0xea, 0xf6, 0xf9, 0x0e, 0xf8, 0x0c, 0x05, 0xc7, + 0xd6, 0x06, 0x12, 0xe3, 0xe1, 0xe1, 0xd8, 0xdb, 0xc6, 0xf8, 0xe6, 0xfa, + 0x0c, 0x07, 0xf8, 0xe7, 0xe1, 0x0f, 0x00, 0xf3, 0x03, 0xf0, 0xde, 0xcc, + 0xf5, 0xfc, 0xef, 0x1e, 0x16, 0x13, 0xfb, 0xf4, 0x03, 0xe9, 0xfc, 0xfa, + 0x15, 0xe8, 0x15, 0x09, 0xf1, 0x0d, 0xdb, 0x0a, 0xe8, 0x09, 0xf5, 0x1a, + 0x04, 0xf8, 0xd8, 0xd4, 0x04, 0xee, 0x25, 0x29, 0x09, 0xfe, 0xf3, 0xf5, + 0xd4, 0x0a, 0x15, 0x19, 0xf5, 0x12, 0xfe, 0x04, 0xe7, 0x01, 0xeb, 0xde, + 0xbe, 0xfe, 0x09, 0x12, 0xdf, 0x13, 0xe0, 0xef, 0xc7, 0xff, 0x03, 0x08, + 0xfe, 0xf2, 0x19, 0xe0, 0xe4, 0x0c, 0x22, 0x1e, 0x05, 0xf7, 0x16, 0xf2, + 0xf9, 0x06, 0x17, 0xf6, 0x0c, 0x1e, 0x23, 0x08, 0xfe, 0xdc, 0xfd, 0x17, + 0x11, 0xdf, 0xf5, 0x0f, 0x01, 0x03, 0x08, 0xee, 0x1b, 0x02, 0x0b, 0x1b, + 0x0c, 0x16, 0x1a, 0x00, 0x0f, 0x26, 0x14, 0xf8, 0xf4, 0xf3, 0x19, 0x16, + 0x22, 0x0a, 0xd0, 0xf9, 0xf1, 0x05, 0x2b, 0x1e, 0x1e, 0xef, 0xf5, 0x06, + 0x05, 0xe7, 0x3f, 0x2a, 0x06, 0xf0, 0x15, 0x14, 0x13, 0x20, 0x1b, 0xde, + 0x10, 0x05, 0x33, 0xf8, 0x08, 0x04, 0x17, 0x0d, 0x0f, 0xf6, 0x01, 0xed, + 0x28, 0x25, 0x1c, 0x13, 0xfb, 0xea, 0xfb, 0xf3, 0x1c, 0xf9, 0x1f, 0xf0, + 0xfb, 0x17, 0xf8, 0xff, 0x10, 0xf7, 0x0b, 0x24, 0x04, 0x00, 0x0d, 0x0c, + 0xf7, 0x0a, 0x16, 0x13, 0xf8, 0x05, 0x0a, 0xf1, 0xf5, 0xee, 0xf8, 0x14, + 0x0e, 0xed, 0xfe, 0x1b, 0xfe, 0x17, 0x13, 0x10, 0x12, 0x21, 0x1c, 0xfa, + 0xe5, 0x0b, 0x08, 0x0c, 0x10, 0x1b, 0x03, 0xef, 0x0d, 0x05, 0x0a, 0xf0, + 0x04, 0x11, 0x15, 0x00, 0xfd, 0xef, 0x02, 0x18, 0xf4, 0x09, 0xfa, 0xf6, + 0x02, 0xf7, 0xfd, 0x13, 0xef, 0x13, 0xf7, 0xf9, 0x17, 0x0f, 0xfa, 0xf8, + 0x15, 0xff, 0x04, 0xef, 0xf0, 0x15, 0xfa, 0xfe, 0xf0, 0xf4, 0xed, 0x06, + 0x1c, 0x02, 0xfb, 0xf7, 0x05, 0xfb, 0x0c, 0xef, 0xf4, 0xf0, 0xf6, 0xec, + 0x17, 0xf3, 0xf5, 0xef, 0x02, 0xfd, 0xe5, 0x21, 0x0c, 0xf1, 0x1e, 0x08, + 0xf1, 0x0b, 0xf7, 0x09, 0x1d, 0xf2, 0xf9, 0xf2, 0xfb, 0x0e, 0xed, 0xf8, + 0xfa, 0xdd, 0xf0, 0xfd, 0xdb, 0x1a, 0xf4, 0xef, 0x0c, 0x06, 0x0f, 0xdf, + 0xe2, 0x06, 0x06, 0xee, 0xfa, 0x0d, 0x17, 0xfc, 0xf9, 0x15, 0x1a, 0xe4, + 0xfb, 0x0c, 0x1a, 0xfc, 0x1b, 0x04, 0x07, 0x20, 0xff, 0x09, 0x0f, 0xf2, + 0x26, 0x19, 0x1f, 0x0d, 0x02, 0x16, 0x03, 0x03, 0xfd, 0x05, 0x01, 0x1b, + 0x0a, 0x11, 0xfa, 0x21, 0x13, 0xfb, 0x0c, 0x05, 0xf3, 0xdd, 0xe4, 0xdc, + 0x22, 0x1b, 0x15, 0x14, 0x0e, 0xe8, 0x00, 0xf7, 0xf8, 0xf4, 0x0b, 0x0b, + 0xfd, 0x21, 0xe3, 0x0f, 0xe1, 0x22, 0x01, 0x21, 0x0b, 0x1f, 0x09, 0x10, + 0xe2, 0x18, 0x11, 0x0e, 0xed, 0x01, 0x14, 0x12, 0xfd, 0x11, 0xf6, 0xe9, + 0x20, 0xe1, 0xf5, 0x1b, 0x27, 0x22, 0xfa, 0xf7, 0xfe, 0x13, 0xf6, 0xdc, + 0x06, 0x0d, 0xf4, 0x05, 0x20, 0x0d, 0x0b, 0xe4, 0x15, 0x28, 0x0c, 0x00, + 0xf5, 0x07, 0x0c, 0x0a, 0x06, 0x0e, 0xf3, 0xfb, 0xfe, 0x04, 0x08, 0xf4, + 0xef, 0x03, 0xe4, 0xeb, 0x06, 0xee, 0xed, 0xdb, 0xeb, 0x1d, 0xf4, 0xfa, + 0x0c, 0xfc, 0xfe, 0x11, 0xf7, 0xf8, 0xf5, 0xef, 0xe7, 0xfc, 0x1b, 0xdc, + 0x17, 0xfd, 0xfe, 0x00, 0xea, 0xf4, 0xf1, 0xf7, 0x0f, 0x21, 0x04, 0xfd, + 0x0d, 0x0c, 0x0a, 0x14, 0xfd, 0x19, 0x09, 0x01, 0xfd, 0xe2, 0x0c, 0x0c, + 0xe0, 0x25, 0xfb, 0xff, 0x0d, 0x18, 0xf6, 0x0b, 0x19, 0x12, 0x10, 0x09, + 0x0b, 0x06, 0x12, 0x1c, 0x10, 0x03, 0x13, 0x0a, 0x05, 0x0f, 0x09, 0x01, + 0x21, 0xe4, 0x01, 0x26, 0xf9, 0xf4, 0x05, 0x19, 0x00, 0xff, 0x0b, 0xff, + 0x16, 0x09, 0xe7, 0xee, 0xed, 0xf5, 0x0f, 0x2f, 0xee, 0x19, 0x03, 0x0a, + 0x10, 0xee, 0xf7, 0x2e, 0xf4, 0x08, 0xf7, 0xee, 0x07, 0x00, 0xfc, 0x0e, + 0xf0, 0x12, 0x08, 0x05, 0xed, 0x11, 0xfc, 0xfb, 0xf7, 0x25, 0xf1, 0x05, + 0x0c, 0xf9, 0xfa, 0x03, 0x0c, 0x16, 0x04, 0x25, 0xf8, 0xe7, 0xfc, 0x11, + 0x0d, 0x19, 0xd8, 0xfa, 0x0b, 0x06, 0xfd, 0xef, 0x13, 0xf6, 0xff, 0x0e, + 0xf9, 0x04, 0xf1, 0xdc, 0xfb, 0xe1, 0xf6, 0x0b, 0x15, 0x07, 0xf7, 0x02, + 0x0e, 0xf1, 0xfd, 0xe3, 0xeb, 0x07, 0xf1, 0xef, 0x03, 0xfe, 0xf8, 0x07, + 0x10, 0xf7, 0x00, 0xf9, 0xf2, 0x0e, 0xf9, 0xf2, 0x1d, 0xf5, 0xd8, 0xff, + 0xe6, 0x18, 0x2a, 0x1b, 0x03, 0x16, 0xfe, 0xf4, 0xf5, 0xfd, 0x04, 0x01, + 0xfe, 0xfe, 0x07, 0xfc, 0x0e, 0xfa, 0x15, 0xeb, 0x02, 0x15, 0xea, 0xfd, + 0x04, 0xe5, 0xfe, 0xed, 0xfe, 0x1a, 0x09, 0x2a, 0x1b, 0xdf, 0xfb, 0xf8, + 0xf1, 0x04, 0x1a, 0x34, 0x07, 0xf9, 0x0d, 0xf5, 0xef, 0xec, 0x10, 0x1a, + 0x0b, 0x0f, 0x13, 0xfe, 0x10, 0x22, 0x1e, 0x02, 0xe6, 0xf7, 0x11, 0xfa, + 0x11, 0xfc, 0x1b, 0x21, 0x12, 0xf4, 0x18, 0x16, 0x29, 0xe4, 0x0c, 0x2e, + 0x12, 0x07, 0x20, 0xf6, 0x1d, 0xf4, 0x12, 0x33, 0xf4, 0xee, 0xfe, 0x05, + 0x06, 0xfb, 0x13, 0x0c, 0x0e, 0xf0, 0x00, 0xf8, 0xee, 0xf3, 0x17, 0x00, + 0xf7, 0xfb, 0xfc, 0x0f, 0xf4, 0xd5, 0x0a, 0xed, 0xeb, 0xf5, 0xe9, 0xef, + 0xd8, 0xf0, 0xf8, 0xe2, 0x19, 0xf7, 0xf8, 0x0a, 0x0b, 0x09, 0xfa, 0xe7, + 0x0f, 0xfc, 0xe8, 0x02, 0x00, 0x1a, 0xfe, 0xfd, 0x1b, 0xe6, 0xef, 0x0f, + 0xe3, 0x10, 0xf1, 0xe2, 0x0b, 0x0e, 0x06, 0x29, 0x00, 0x01, 0xf3, 0x00, + 0x11, 0x04, 0xf2, 0xf7, 0xea, 0xf8, 0xe0, 0x09, 0x0e, 0x13, 0xf4, 0x00, + 0x09, 0xfa, 0xf5, 0x0c, 0xff, 0x18, 0x08, 0x0d, 0xfa, 0xde, 0xfa, 0x03, + 0xf2, 0xf3, 0x1b, 0xeb, 0x06, 0xea, 0xfb, 0xff, 0x0d, 0xf5, 0x10, 0x17, + 0xf8, 0xe8, 0xf1, 0xf1, 0xf5, 0x00, 0x03, 0x0a, 0x09, 0x0a, 0xf3, 0xfb, + 0x33, 0x26, 0xe7, 0x17, 0xe3, 0xfa, 0x1f, 0x24, 0xfc, 0x07, 0x02, 0xe2, + 0xeb, 0x08, 0x2c, 0xf8, 0x02, 0x1f, 0x04, 0xeb, 0x0b, 0x04, 0x17, 0xf7, + 0xff, 0x1c, 0xed, 0x00, 0x3f, 0xd5, 0x17, 0x1d, 0xfe, 0x03, 0xf1, 0x1c, + 0x17, 0xec, 0x0e, 0x54, 0xee, 0xf5, 0x25, 0xfa, 0x08, 0xee, 0x13, 0x32, + 0x0e, 0xd8, 0x09, 0x0f, 0xee, 0xe5, 0x06, 0x10, 0xf4, 0xfb, 0xe4, 0xfb, + 0x09, 0xde, 0x13, 0xff, 0x02, 0xf9, 0xec, 0x0a, 0x00, 0xe9, 0xfd, 0xdc, + 0x06, 0x04, 0xdb, 0x06, 0x01, 0xf8, 0x09, 0xe2, 0x0c, 0x14, 0xda, 0xfe, + 0x20, 0xe3, 0x09, 0xda, 0x14, 0x12, 0xe1, 0x05, 0xff, 0xf3, 0x00, 0x08, + 0xfb, 0xf1, 0xfd, 0xf3, 0x04, 0xfa, 0x08, 0xff, 0x01, 0x1d, 0x0b, 0xfd, + 0x0a, 0xf4, 0xfb, 0xfc, 0xf9, 0x19, 0xed, 0xfc, 0xf2, 0x06, 0xe7, 0x02, + 0xf6, 0x0c, 0xfc, 0xfb, 0x01, 0x0c, 0xeb, 0x1b, 0xff, 0xff, 0x08, 0x1d, + 0xf7, 0xe8, 0xfc, 0xf4, 0x0c, 0xfa, 0xf1, 0xee, 0xed, 0xdd, 0xfc, 0x06, + 0x05, 0xdc, 0x1a, 0xfc, 0xf9, 0x07, 0xdf, 0x1b, 0x14, 0x0c, 0xfc, 0x01, + 0x16, 0xe1, 0xed, 0x09, 0x34, 0xee, 0xe4, 0x1c, 0x1b, 0xfc, 0x3b, 0x03, + 0x15, 0xf2, 0xeb, 0x14, 0x00, 0xdd, 0x24, 0x04, 0xf1, 0xed, 0xfd, 0xe6, + 0x32, 0xf9, 0x24, 0x04, 0x0e, 0x22, 0x03, 0x14, 0x2f, 0xf5, 0x1a, 0x37, + 0xf4, 0x18, 0x03, 0x0f, 0x4b, 0xe6, 0x0d, 0x5c, 0xf7, 0x1f, 0x1c, 0xe6, + 0x23, 0x0c, 0x15, 0x4e, 0xe0, 0x05, 0x1c, 0xec, 0xff, 0x04, 0x13, 0x15, + 0xee, 0x07, 0xec, 0x0c, 0xdd, 0xf8, 0x0e, 0x03, 0x0c, 0x1f, 0xe8, 0x0e, + 0xf5, 0xec, 0xfc, 0xe2, 0xe8, 0xfb, 0xf6, 0x00, 0xe5, 0xea, 0xf3, 0xd3, + 0xf5, 0xfd, 0xd2, 0xfd, 0x1b, 0xed, 0x09, 0xd1, 0x23, 0xfa, 0xd4, 0xf7, + 0xe9, 0xf0, 0x0a, 0xd6, 0x14, 0x03, 0xe6, 0x10, 0xf4, 0x18, 0xfe, 0xe1, + 0x0b, 0x25, 0xf5, 0xfc, 0xe9, 0xf2, 0xe9, 0xf4, 0x0d, 0xf5, 0x00, 0xf9, + 0x17, 0x02, 0xfd, 0x03, 0x04, 0xf8, 0xf5, 0x14, 0xe3, 0xd3, 0xeb, 0xe7, + 0x09, 0xf3, 0x14, 0x17, 0xee, 0xe6, 0xf6, 0xff, 0x11, 0x26, 0xf4, 0xf7, + 0x02, 0xfa, 0x05, 0x08, 0x16, 0xff, 0x0d, 0xf7, 0xf1, 0xf7, 0xe6, 0xfb, + 0x04, 0x04, 0x07, 0x02, 0x04, 0x09, 0xf5, 0xfc, 0x5f, 0xd6, 0xe7, 0x2a, + 0x23, 0xf4, 0x1b, 0x06, 0x01, 0xea, 0xe7, 0x05, 0x25, 0xe3, 0x25, 0x07, + 0xea, 0xfb, 0xfb, 0x09, 0x25, 0xde, 0x37, 0x04, 0x07, 0xe5, 0xff, 0x14, + 0x2f, 0x0a, 0x30, 0x23, 0x04, 0xf0, 0x23, 0xfe, 0x1c, 0xd2, 0x2b, 0x55, + 0x01, 0xe5, 0x26, 0xfe, 0x14, 0xed, 0x24, 0x46, 0xe6, 0xee, 0x0f, 0xfd, + 0xed, 0xef, 0x0e, 0x1e, 0x05, 0x0a, 0x12, 0xff, 0xe4, 0xf5, 0x0c, 0xed, + 0xfd, 0xea, 0x0d, 0x13, 0x1a, 0xe5, 0xfc, 0xc2, 0xef, 0x0a, 0xe2, 0x0f, + 0xfe, 0xff, 0x0c, 0xf0, 0xff, 0xdf, 0xea, 0x00, 0xf6, 0xe1, 0x04, 0xd8, + 0x26, 0x20, 0xdc, 0xf4, 0x19, 0x06, 0xe8, 0xd2, 0x10, 0x04, 0xf1, 0x02, + 0x0c, 0x06, 0xf0, 0xf0, 0x04, 0x1f, 0xf4, 0xf5, 0xed, 0xf1, 0xfa, 0xf1, + 0x04, 0x02, 0xf8, 0xfb, 0x04, 0xf1, 0xe5, 0xe4, 0x0a, 0xf0, 0xfe, 0xef, + 0x1c, 0xe3, 0xeb, 0xf3, 0x00, 0x17, 0x01, 0x13, 0x19, 0xda, 0xf8, 0x06, + 0xde, 0x11, 0xea, 0xf7, 0xf4, 0xef, 0x03, 0x04, 0x0b, 0xe8, 0x08, 0x0e, + 0xe2, 0xee, 0xde, 0x06, 0x0e, 0x29, 0xfb, 0xfa, 0x00, 0x02, 0xec, 0x1b, + 0x52, 0xff, 0xde, 0x3a, 0x2f, 0x13, 0x30, 0xe9, 0xff, 0xf6, 0xe7, 0x15, + 0x1d, 0xd9, 0x3c, 0x0f, 0xe6, 0x14, 0xee, 0x13, 0x1f, 0xe7, 0x33, 0x08, + 0xfc, 0x06, 0x0c, 0x08, 0x19, 0xd9, 0x2b, 0x1f, 0x07, 0x10, 0x24, 0x16, + 0x29, 0xfc, 0x31, 0x4d, 0xf0, 0xd9, 0x3f, 0xf2, 0x20, 0xe2, 0x25, 0x49, + 0xe5, 0xec, 0x0a, 0xf5, 0xf2, 0xd9, 0x22, 0x1f, 0xed, 0x22, 0x02, 0x0a, + 0x16, 0x08, 0xf7, 0xfb, 0x0e, 0xfb, 0xfb, 0x1d, 0xf3, 0x1c, 0xf6, 0xe1, + 0xcf, 0x19, 0xf4, 0x0f, 0xee, 0xf9, 0x04, 0xd1, 0xf9, 0xe2, 0xda, 0xf1, + 0x24, 0xf5, 0x07, 0xdf, 0x1d, 0xf9, 0xdb, 0x18, 0x0b, 0xea, 0x08, 0xca, + 0xf2, 0xfa, 0xec, 0x04, 0x0e, 0x17, 0xed, 0xf1, 0x06, 0x15, 0xfc, 0xfd, + 0x08, 0xfa, 0xe3, 0xe4, 0x0a, 0xfc, 0xee, 0x08, 0xf5, 0x09, 0xef, 0xee, + 0x06, 0xef, 0xe1, 0x19, 0x07, 0xe8, 0xe6, 0xdf, 0xea, 0x0d, 0xf1, 0x16, + 0xee, 0xed, 0xf8, 0x09, 0xfa, 0xfb, 0x0c, 0xf8, 0xeb, 0xda, 0x00, 0xfc, + 0x04, 0xfe, 0xf5, 0xff, 0xf6, 0xe1, 0x0c, 0x0a, 0x13, 0x0d, 0xf6, 0xf5, + 0x15, 0x07, 0xca, 0xec, 0x50, 0x0e, 0xd0, 0x26, 0x4c, 0xf8, 0x23, 0xeb, + 0xff, 0x08, 0xe3, 0x11, 0x2c, 0xf9, 0x2a, 0xf1, 0xe9, 0x0b, 0xe9, 0x0f, + 0x15, 0xec, 0x33, 0x11, 0x0c, 0x0d, 0x01, 0x01, 0x32, 0xe3, 0x41, 0x27, + 0x11, 0x02, 0x2e, 0x07, 0x09, 0xe3, 0x22, 0x4d, 0xf1, 0x05, 0x27, 0x03, + 0x25, 0xf5, 0x2c, 0x3b, 0xf4, 0x00, 0x16, 0x0b, 0xec, 0xfe, 0x17, 0x0d, + 0xff, 0xe7, 0xfe, 0x24, 0x06, 0xee, 0xf0, 0xe9, 0xfa, 0x1c, 0xf2, 0x19, + 0x08, 0xfa, 0xff, 0xd2, 0x01, 0x02, 0xea, 0x05, 0xf2, 0xf4, 0x0b, 0xd2, + 0xf9, 0x0d, 0xcd, 0x0d, 0x12, 0xf2, 0x0e, 0xe1, 0x1f, 0x00, 0xe7, 0x14, + 0x04, 0xff, 0x09, 0xdb, 0xfc, 0xd9, 0x06, 0xf9, 0xeb, 0x01, 0xef, 0xfa, + 0xfb, 0xf5, 0xfc, 0xfb, 0x14, 0xe2, 0xf9, 0xf5, 0x02, 0xfd, 0xfc, 0x01, + 0xf7, 0xf3, 0x00, 0xec, 0xe7, 0xf2, 0x00, 0xf1, 0x11, 0xec, 0xf0, 0xe9, + 0x11, 0x0a, 0x07, 0x04, 0x01, 0xee, 0xfb, 0xf2, 0x14, 0x01, 0x12, 0xf0, + 0xf2, 0xf1, 0xf0, 0xfb, 0x08, 0x03, 0xf8, 0x01, 0xe8, 0xf9, 0x17, 0x26, + 0x0f, 0xea, 0xf7, 0xf8, 0x1e, 0xfe, 0xf2, 0xf8, 0x3f, 0x00, 0xd4, 0x1c, + 0x53, 0xfe, 0x1e, 0x0f, 0xef, 0xdd, 0xed, 0x10, 0x19, 0xe7, 0x34, 0x0e, + 0xde, 0xdf, 0xfa, 0x0e, 0x29, 0xe3, 0x16, 0x09, 0x06, 0x12, 0xeb, 0xf9, + 0x32, 0xe0, 0x1a, 0x1d, 0xf3, 0xed, 0x10, 0x07, 0x31, 0xf2, 0x12, 0x52, + 0xeb, 0xf7, 0x1e, 0xf7, 0x1a, 0xdc, 0x3e, 0x33, 0xe3, 0xfb, 0x1f, 0x0b, + 0x08, 0xfe, 0x13, 0x1a, 0xf4, 0xf8, 0xfe, 0x08, 0xfc, 0xe9, 0xfe, 0xeb, + 0xe6, 0xf6, 0x02, 0x18, 0x02, 0xe8, 0xfb, 0xf3, 0x01, 0x08, 0xd7, 0x13, + 0x04, 0xe6, 0x02, 0xe6, 0xd7, 0x01, 0xd4, 0xf0, 0x0e, 0x05, 0x18, 0xe5, + 0x08, 0xe5, 0xd2, 0x16, 0x12, 0xfe, 0x0e, 0xd3, 0xfc, 0x1f, 0xe9, 0xf8, + 0x11, 0x06, 0xf3, 0xd5, 0xf8, 0xff, 0xf0, 0x04, 0x0a, 0xd9, 0xf8, 0xfd, + 0xf5, 0x12, 0xff, 0x06, 0x1b, 0xe6, 0xfe, 0xfe, 0xde, 0xee, 0xf6, 0x18, + 0xf1, 0xf8, 0x06, 0xf3, 0x02, 0xea, 0x04, 0x14, 0xfc, 0xee, 0xe6, 0x09, + 0xf9, 0xee, 0xe3, 0xe7, 0xfc, 0xd9, 0xef, 0xfc, 0x0a, 0x0c, 0x03, 0xf6, + 0xe2, 0x11, 0x0f, 0x19, 0x18, 0x10, 0xef, 0xe5, 0x22, 0xf5, 0xe5, 0xe9, + 0x4b, 0xf7, 0xdb, 0x0c, 0x4f, 0xde, 0x22, 0x16, 0x09, 0x16, 0xd1, 0xf8, + 0x19, 0xe0, 0x24, 0xfe, 0xb8, 0xfb, 0xe5, 0x12, 0x1c, 0xe3, 0x22, 0x09, + 0x05, 0x29, 0xf7, 0x10, 0x31, 0xe1, 0x33, 0x3f, 0xfd, 0xed, 0x04, 0x03, + 0x2e, 0xed, 0x30, 0x36, 0xee, 0x16, 0x2f, 0xf5, 0x1b, 0xdc, 0x3a, 0x56, + 0xe5, 0xef, 0x26, 0xff, 0x03, 0xd7, 0x31, 0x16, 0xef, 0xf1, 0x08, 0x13, + 0x01, 0x02, 0x03, 0xf1, 0xf2, 0x08, 0xff, 0x05, 0x12, 0xf2, 0xee, 0xda, + 0xed, 0xec, 0xea, 0xf7, 0x0c, 0xf1, 0x09, 0xe6, 0xe6, 0x00, 0xcc, 0x10, + 0x0d, 0x0d, 0x20, 0xf4, 0x18, 0x23, 0xec, 0xf9, 0x00, 0xe4, 0x07, 0xd4, + 0xfb, 0x16, 0xd2, 0x01, 0xe6, 0x01, 0x06, 0xf0, 0xfe, 0x03, 0xf3, 0x09, + 0x01, 0x0d, 0x05, 0xf7, 0xd4, 0x02, 0xfb, 0xfb, 0x08, 0xf0, 0x1f, 0xf3, + 0xfe, 0xeb, 0x02, 0x0e, 0x1b, 0x0f, 0x04, 0xf5, 0xf0, 0x1f, 0x14, 0xf7, + 0x06, 0xdc, 0xf9, 0xe9, 0x01, 0xff, 0x08, 0xf2, 0x06, 0xff, 0xff, 0xf3, + 0x05, 0x1a, 0xfc, 0xfa, 0xeb, 0xfb, 0xfa, 0x12, 0x20, 0xf6, 0xe0, 0xe8, + 0x1c, 0xfa, 0xd6, 0x0d, 0x2c, 0x04, 0xe1, 0x09, 0x3b, 0xd3, 0x2a, 0xee, + 0xf7, 0xed, 0xf1, 0xf7, 0x0d, 0xf0, 0x32, 0x0f, 0xc9, 0x0e, 0x00, 0x10, + 0x24, 0xfb, 0x31, 0xf0, 0xf4, 0xdd, 0xf5, 0x04, 0x25, 0xc7, 0x27, 0x25, + 0x16, 0x11, 0x2e, 0x09, 0x30, 0xd1, 0x2c, 0x34, 0xe6, 0xf0, 0x21, 0xf5, + 0x21, 0xc8, 0x40, 0x39, 0xde, 0xf0, 0x12, 0xf3, 0x10, 0xe8, 0x1f, 0x18, + 0xfa, 0xea, 0x07, 0x11, 0xdf, 0xed, 0xfa, 0xf0, 0x07, 0xef, 0xf3, 0x05, + 0x10, 0xe5, 0xf3, 0xe9, 0xe9, 0xe8, 0xd6, 0x01, 0xf9, 0x05, 0x0b, 0xee, + 0xf9, 0x12, 0xe3, 0x05, 0xfd, 0xe6, 0x16, 0xe2, 0x1b, 0x12, 0xc5, 0x00, + 0xfd, 0x02, 0x04, 0xd2, 0xff, 0xec, 0xf6, 0xfd, 0x00, 0xe4, 0xf7, 0xf3, + 0xeb, 0xfa, 0xf8, 0x0d, 0x03, 0xfa, 0xfe, 0xe4, 0xdb, 0xe3, 0x06, 0xff, + 0xf4, 0xf2, 0x1b, 0xf1, 0xf7, 0x02, 0x01, 0x04, 0x13, 0xe5, 0x0c, 0x05, + 0xf7, 0x0a, 0x03, 0x03, 0x0b, 0x03, 0xee, 0xf7, 0x21, 0x20, 0xff, 0xf3, + 0x09, 0xe5, 0xff, 0xec, 0x17, 0x00, 0x06, 0x14, 0xeb, 0xf2, 0x18, 0x16, + 0x1f, 0xec, 0xee, 0xe1, 0x1e, 0x03, 0xfa, 0xfe, 0x28, 0x03, 0xc9, 0x0c, + 0x3f, 0xd8, 0x30, 0x16, 0x03, 0xf8, 0xe9, 0xfb, 0x28, 0xe1, 0x36, 0x0a, + 0xdf, 0xe5, 0xeb, 0x08, 0x1c, 0xcd, 0x29, 0xf2, 0xfc, 0x0a, 0xed, 0x01, + 0x29, 0xf1, 0x20, 0x13, 0x04, 0xec, 0x17, 0x0a, 0x35, 0xc3, 0x1a, 0x46, + 0xe0, 0xd7, 0x3c, 0x09, 0x28, 0xd1, 0x22, 0x20, 0xd5, 0xfa, 0x28, 0xfa, + 0xff, 0xea, 0x1d, 0x23, 0xe0, 0x07, 0x07, 0x0f, 0xf1, 0xf1, 0x08, 0xf0, + 0xf8, 0xff, 0x05, 0x1b, 0x05, 0xfa, 0xf0, 0xfb, 0xe3, 0xe4, 0xcc, 0x1a, + 0xf9, 0x09, 0x06, 0xee, 0xf4, 0x03, 0xd0, 0x14, 0xf4, 0xff, 0x1d, 0xe8, + 0x11, 0xf4, 0xd1, 0xf4, 0x04, 0x0b, 0xfb, 0xdc, 0x0a, 0x0c, 0xeb, 0xed, + 0x06, 0xf3, 0x04, 0xdd, 0xdf, 0xf9, 0xea, 0xfc, 0xf5, 0xf2, 0xfb, 0xea, + 0xe3, 0x03, 0xee, 0x0e, 0xff, 0xdb, 0x1e, 0x04, 0xf7, 0x1a, 0x04, 0x0c, + 0x0d, 0xda, 0x04, 0xe9, 0xff, 0x04, 0x00, 0x0c, 0xf9, 0xe4, 0xfb, 0xf6, + 0x14, 0xde, 0x1b, 0x00, 0x0b, 0xfe, 0x06, 0xf8, 0x0f, 0xdc, 0x01, 0xef, + 0xef, 0x0d, 0xf8, 0xf1, 0x0f, 0xf9, 0xf9, 0xdf, 0x0d, 0xe4, 0xd9, 0xf9, + 0x2b, 0xee, 0xe8, 0x09, 0x40, 0xf9, 0x2f, 0x0a, 0xfa, 0xe8, 0xe9, 0x01, + 0x0e, 0xe7, 0x23, 0x0a, 0xd0, 0x19, 0xd3, 0x0e, 0x04, 0xda, 0x2b, 0x0f, + 0xe7, 0xe6, 0xf3, 0xfb, 0x2c, 0xd3, 0x36, 0x19, 0x0e, 0xfe, 0x03, 0x1a, + 0x2e, 0xd0, 0x23, 0x32, 0xf1, 0xe1, 0x2a, 0x09, 0x1b, 0xf6, 0x29, 0x3e, + 0xce, 0x15, 0x0a, 0xe8, 0xec, 0xdf, 0x44, 0x28, 0xd9, 0xfd, 0xfa, 0x09, + 0xff, 0xe7, 0x08, 0xec, 0xf4, 0xef, 0x01, 0x19, 0x11, 0xf3, 0xeb, 0xeb, + 0xed, 0x1a, 0xdd, 0x15, 0x0f, 0x07, 0xfe, 0xeb, 0xff, 0xd6, 0xd5, 0x04, + 0xf5, 0x07, 0x10, 0xe6, 0x0c, 0xe4, 0xda, 0x0c, 0x08, 0xee, 0x06, 0xd8, + 0xf8, 0xf1, 0xe0, 0x01, 0x08, 0xfe, 0xf9, 0xf3, 0xdf, 0x03, 0xe6, 0xf4, + 0x0a, 0xff, 0xf2, 0xe0, 0xd9, 0xeb, 0x01, 0x10, 0x02, 0xfc, 0x0d, 0x14, + 0xea, 0xf8, 0x03, 0x18, 0xf3, 0x09, 0xfc, 0x0c, 0x0b, 0x1f, 0xf5, 0x05, + 0xf7, 0xf9, 0x00, 0xfd, 0x04, 0xfc, 0x16, 0x07, 0x00, 0xdf, 0xf9, 0xfa, + 0x0c, 0xfb, 0xf4, 0xf7, 0xf0, 0xeb, 0x07, 0x17, 0x20, 0xfb, 0xf0, 0xec, + 0x04, 0x00, 0xf8, 0xf2, 0x2d, 0xf9, 0xd9, 0x0b, 0x55, 0xec, 0x33, 0x26, + 0xf8, 0x0a, 0xf2, 0x0b, 0x25, 0xdf, 0x29, 0x05, 0xd1, 0x14, 0xe2, 0xf2, + 0x12, 0xdd, 0x28, 0xfc, 0xec, 0x08, 0xfd, 0x02, 0x3a, 0xe6, 0x29, 0x25, + 0x0d, 0x10, 0x09, 0x0a, 0x32, 0xf5, 0x17, 0x2d, 0xea, 0xfb, 0x35, 0xfc, + 0x28, 0xd0, 0x29, 0x2f, 0xcb, 0x06, 0x0f, 0x04, 0xf2, 0xf3, 0x34, 0x1c, + 0xf4, 0x08, 0x05, 0xfc, 0xfd, 0xed, 0x0f, 0xf8, 0xe9, 0xf0, 0x09, 0x16, + 0xfe, 0x02, 0xff, 0xd4, 0xea, 0x0a, 0xeb, 0x0c, 0xf8, 0xf4, 0x09, 0xf4, + 0xf2, 0x07, 0xd9, 0x0b, 0xfd, 0xe4, 0x1a, 0xef, 0x14, 0x08, 0xd8, 0xfc, + 0xf5, 0xe1, 0x03, 0xcf, 0xf1, 0x11, 0xdb, 0x15, 0x07, 0x10, 0xf8, 0xfc, + 0xe2, 0xf1, 0xf5, 0xde, 0xff, 0xe7, 0x01, 0xea, 0xee, 0xe9, 0x02, 0x0a, + 0x18, 0xec, 0xfe, 0xf9, 0x09, 0xf3, 0x0e, 0x02, 0xf1, 0xfc, 0xf9, 0x16, + 0x05, 0x07, 0x09, 0x0d, 0x0e, 0xf7, 0x04, 0xed, 0x04, 0xdb, 0x04, 0x04, + 0xf6, 0xdc, 0xee, 0xec, 0xf5, 0xfe, 0xf4, 0x02, 0xe4, 0x0b, 0xe0, 0x17, + 0x0a, 0xe0, 0xf7, 0xdc, 0x11, 0xd6, 0xfe, 0xfa, 0x35, 0xde, 0xe6, 0x06, + 0x44, 0xf9, 0x35, 0x0a, 0xfb, 0xff, 0xec, 0xfb, 0x16, 0xd9, 0x23, 0x0f, + 0xd4, 0xef, 0xdf, 0x06, 0x0b, 0xd9, 0x25, 0xff, 0xf8, 0xeb, 0xf4, 0x0a, + 0x20, 0xe5, 0x22, 0x1c, 0xeb, 0xf4, 0x0d, 0x0c, 0x19, 0xe1, 0x1e, 0x31, + 0xe9, 0xfb, 0x20, 0xf0, 0x23, 0xfe, 0x35, 0x28, 0xb4, 0x06, 0x28, 0xe7, + 0xfb, 0xe9, 0x2a, 0x1a, 0xef, 0x15, 0x0c, 0xed, 0xf1, 0x04, 0x0e, 0x0a, + 0xff, 0x16, 0x01, 0x04, 0x17, 0xea, 0xec, 0xdc, 0xf4, 0xf7, 0x04, 0x16, + 0x1f, 0x0a, 0x11, 0xef, 0x12, 0xdf, 0xd9, 0x0c, 0xf5, 0x10, 0x02, 0xf3, + 0x10, 0x03, 0xd3, 0xf5, 0x0b, 0x02, 0x00, 0xcb, 0xf6, 0x23, 0xf6, 0xf1, + 0x1f, 0xf9, 0xfc, 0xf0, 0xf6, 0xfe, 0xfa, 0xf8, 0xf9, 0xf4, 0xfb, 0x0a, + 0xd6, 0x29, 0x09, 0x02, 0x00, 0xfc, 0xfc, 0xee, 0xf5, 0x05, 0xfb, 0x1e, + 0xf1, 0xf1, 0xf3, 0x02, 0xec, 0x1c, 0x0c, 0x0e, 0x0b, 0x04, 0xf6, 0xe7, + 0x14, 0x08, 0x27, 0x01, 0xfe, 0xe5, 0xe7, 0x01, 0x1b, 0xf0, 0xf6, 0xff, + 0xf4, 0xe7, 0xee, 0x18, 0x0d, 0x08, 0xf8, 0xd6, 0x07, 0xf4, 0x08, 0xff, + 0x1d, 0x13, 0xe7, 0x0b, 0x42, 0xef, 0x28, 0x00, 0xf9, 0xf0, 0xf3, 0x00, + 0x15, 0xfd, 0x1a, 0x22, 0xc1, 0xf5, 0xe0, 0xf8, 0x09, 0xe6, 0x0e, 0x05, + 0xf9, 0xf6, 0x01, 0x01, 0x13, 0xdc, 0x1f, 0x0d, 0xfb, 0x04, 0x08, 0x0b, + 0x15, 0xdb, 0x28, 0x34, 0xed, 0x0b, 0x3a, 0xed, 0x16, 0xe3, 0x39, 0x32, + 0xc4, 0x0b, 0x20, 0xe7, 0xf7, 0x02, 0x35, 0x24, 0xfc, 0xe8, 0x1c, 0xf8, + 0xf1, 0xfa, 0x0c, 0x1d, 0xf2, 0x05, 0xff, 0x12, 0x0f, 0x01, 0xec, 0xea, + 0xf0, 0x03, 0xe7, 0x15, 0xfd, 0x05, 0x08, 0xe0, 0x1b, 0xf8, 0xe1, 0x1e, + 0xed, 0xdc, 0x11, 0xeb, 0xfd, 0x1a, 0xeb, 0x09, 0xf9, 0xf3, 0x00, 0xe8, + 0xe6, 0x08, 0xf7, 0xde, 0x1e, 0x00, 0x00, 0x00, 0xe4, 0x09, 0xf2, 0xf8, + 0xe7, 0xf2, 0x0d, 0xfa, 0xe2, 0x0f, 0x04, 0x08, 0xf2, 0x13, 0xf8, 0xf9, + 0xf1, 0xff, 0x03, 0x11, 0x12, 0xe9, 0xf4, 0x13, 0x07, 0x0c, 0x13, 0x2b, + 0xf7, 0xdd, 0xf9, 0xe9, 0xfa, 0xdb, 0x1d, 0xf6, 0xf6, 0xf9, 0xe4, 0xf6, + 0x0d, 0xeb, 0x0d, 0x08, 0xe7, 0xe7, 0xf2, 0x03, 0x1d, 0xd9, 0xd8, 0xe4, + 0xf7, 0xea, 0xdc, 0xdc, 0x26, 0x02, 0xee, 0xfa, 0x38, 0xfc, 0x1a, 0xef, + 0xda, 0xf1, 0xdf, 0x0b, 0x1a, 0xe0, 0x16, 0x16, 0xdc, 0x04, 0xfa, 0xf7, + 0xee, 0x02, 0x25, 0x02, 0xf5, 0xfb, 0x08, 0xf6, 0x11, 0xf5, 0x12, 0x08, + 0xf4, 0xe3, 0x1b, 0xf5, 0x3a, 0xdc, 0x20, 0x2e, 0xe0, 0xf5, 0x30, 0xe4, + 0x09, 0xf8, 0x3c, 0x45, 0xd3, 0x08, 0x23, 0xd8, 0x09, 0xe4, 0x35, 0x30, + 0xe4, 0xfe, 0x07, 0xf6, 0x05, 0x01, 0x05, 0xff, 0xf6, 0x0d, 0x02, 0xfd, + 0x03, 0x05, 0x0d, 0x00, 0xf5, 0xd6, 0xcf, 0x19, 0x06, 0xee, 0x0d, 0xf2, + 0x01, 0x18, 0xef, 0x12, 0x04, 0x02, 0x21, 0xd9, 0x02, 0x0d, 0xeb, 0xe9, + 0x13, 0x08, 0x15, 0xf0, 0xee, 0x03, 0xec, 0x06, 0x17, 0xed, 0x00, 0x1a, + 0xee, 0xf2, 0xfc, 0x09, 0xec, 0xf8, 0xf8, 0x18, 0xf4, 0x13, 0x04, 0xf6, + 0x02, 0xf0, 0xfc, 0xfe, 0xe3, 0x01, 0x0a, 0x1c, 0x1b, 0xec, 0x0e, 0x01, + 0xfb, 0x08, 0x11, 0xf5, 0x00, 0x14, 0xe6, 0x12, 0x07, 0xf4, 0x15, 0x07, + 0xfc, 0xfb, 0xf5, 0xf1, 0x01, 0x21, 0x01, 0xe9, 0xe8, 0xef, 0xdb, 0xdf, + 0x1f, 0x0a, 0xdd, 0xd1, 0x16, 0x04, 0xfd, 0xe1, 0x24, 0xf0, 0xec, 0xf4, + 0x38, 0xe1, 0x16, 0xfd, 0xe0, 0xec, 0xe7, 0x0c, 0x2a, 0x04, 0x0c, 0x17, + 0xdc, 0xe8, 0xf2, 0x03, 0xec, 0xfd, 0x19, 0xfe, 0xf3, 0xf0, 0xf3, 0xfb, + 0x18, 0xdf, 0x1c, 0x00, 0x09, 0xf4, 0x18, 0x0b, 0x1f, 0xf6, 0x34, 0x22, + 0xf4, 0x22, 0x45, 0xeb, 0x23, 0xcf, 0x32, 0x34, 0xf2, 0xf9, 0x29, 0xd4, + 0xf7, 0x0b, 0x38, 0x2a, 0x09, 0xe6, 0x05, 0x01, 0x0b, 0xfe, 0x17, 0xfb, + 0x00, 0xeb, 0x08, 0xfd, 0x0c, 0x02, 0x1d, 0xea, 0xfa, 0x0b, 0xeb, 0x09, + 0xfe, 0xfe, 0x10, 0xe0, 0xf6, 0x06, 0xf0, 0x15, 0xf3, 0x09, 0x11, 0xe4, + 0xf9, 0x07, 0xe1, 0xed, 0x17, 0x05, 0x0c, 0xe1, 0xdb, 0xf2, 0xf8, 0xea, + 0x22, 0xe9, 0x02, 0x00, 0xfd, 0xe7, 0xf2, 0xf8, 0xf9, 0xfc, 0xfa, 0xe8, + 0xe8, 0xeb, 0xe9, 0x0d, 0x04, 0xf8, 0xf8, 0xf7, 0xf8, 0x0d, 0x03, 0x0c, + 0x13, 0xf2, 0x0f, 0xf9, 0xe6, 0xfd, 0x0f, 0x19, 0x08, 0xf7, 0xfa, 0x01, + 0xf3, 0x12, 0x1e, 0x05, 0x0a, 0x09, 0xfd, 0x0b, 0x07, 0x08, 0x02, 0xfc, + 0xd6, 0xe8, 0x14, 0x01, 0x13, 0x19, 0xef, 0xda, 0x0e, 0x0a, 0x07, 0xef, + 0x34, 0xe0, 0x05, 0x1e, 0x4e, 0xe9, 0x19, 0xff, 0xe1, 0x04, 0xfb, 0x0e, + 0x11, 0x05, 0x1f, 0x15, 0xd4, 0xec, 0xf9, 0xe7, 0xf9, 0xfc, 0x25, 0xff, + 0x06, 0xf2, 0x01, 0xf6, 0x2a, 0x17, 0x24, 0x11, 0xf3, 0x1a, 0x1f, 0xfb, + 0x32, 0xeb, 0x33, 0x2f, 0x00, 0x08, 0x2c, 0xf0, 0x26, 0xf4, 0x25, 0x36, + 0xd9, 0xf1, 0x1a, 0xd5, 0xec, 0xf9, 0x32, 0x27, 0xfc, 0xf4, 0xf0, 0xe3, + 0xfa, 0x0c, 0x16, 0x17, 0xfa, 0xf9, 0xe5, 0x1f, 0x1f, 0xfa, 0xff, 0xfd, + 0x0d, 0x02, 0xe9, 0x0e, 0xf0, 0x12, 0x09, 0xda, 0x02, 0xea, 0xe5, 0x0a, + 0xff, 0x03, 0x13, 0xf0, 0x0a, 0xf9, 0xe9, 0xff, 0x10, 0xfc, 0x1a, 0xf3, + 0xf7, 0x0f, 0xf4, 0xfa, 0xf4, 0x05, 0x10, 0x0a, 0xdd, 0x09, 0xf7, 0xf0, + 0xe5, 0x07, 0x07, 0xfa, 0x02, 0xd7, 0xf8, 0xf7, 0x01, 0xfb, 0x0e, 0xf8, + 0x07, 0x0f, 0xfe, 0x03, 0x12, 0x05, 0x09, 0x13, 0xf8, 0xdc, 0xfd, 0x27, + 0x0f, 0xec, 0xf7, 0x07, 0x00, 0xfc, 0x12, 0xf8, 0xfb, 0xea, 0xe4, 0xe9, + 0xe9, 0xe0, 0xff, 0xdc, 0xd6, 0xeb, 0xf2, 0xf7, 0x0d, 0x1b, 0xe9, 0xc4, + 0x06, 0x00, 0xfd, 0x04, 0x46, 0xf9, 0xe9, 0x13, 0x2d, 0x0c, 0x1f, 0xf8, + 0xd3, 0x0c, 0x14, 0x11, 0x05, 0xe5, 0x27, 0x08, 0xc5, 0xef, 0xdf, 0xdd, + 0x04, 0xf8, 0x11, 0x10, 0xf0, 0xe7, 0xfb, 0x03, 0x3c, 0xe7, 0x14, 0x0c, + 0xf4, 0xf6, 0x1b, 0x0a, 0x23, 0xf2, 0x2d, 0x1a, 0x08, 0xff, 0x32, 0xe7, + 0x1a, 0x05, 0x2b, 0x34, 0xf1, 0x0a, 0x00, 0xe8, 0x02, 0xdf, 0x2c, 0x2a, + 0x03, 0xe6, 0xfc, 0xef, 0xfc, 0xe4, 0x03, 0x01, 0x03, 0xee, 0xe9, 0x15, + 0x05, 0x03, 0x13, 0x11, 0x0e, 0xee, 0xf5, 0x22, 0x1b, 0x0e, 0xfd, 0xf3, + 0x0a, 0x02, 0xdd, 0x20, 0xeb, 0x06, 0xf8, 0xe2, 0x06, 0x0e, 0xde, 0x0d, + 0xf9, 0x16, 0x1c, 0x0c, 0xe0, 0xf0, 0xec, 0x0c, 0x0f, 0xf2, 0x27, 0x1d, + 0xde, 0xe6, 0xf0, 0xf9, 0xf0, 0x02, 0x0a, 0x07, 0x06, 0xf9, 0x0f, 0xfa, + 0xf0, 0xee, 0xf1, 0xf7, 0xff, 0x02, 0x0b, 0x0d, 0x1b, 0xee, 0xf6, 0x05, + 0xff, 0x1c, 0x17, 0x04, 0x05, 0x17, 0x00, 0xff, 0x0d, 0xf3, 0x23, 0x10, + 0xfd, 0x05, 0xfb, 0xea, 0x03, 0x10, 0x07, 0xd7, 0xf7, 0xff, 0xf3, 0xf1, + 0x17, 0xed, 0xd3, 0xcb, 0x14, 0x1c, 0xf5, 0x03, 0x47, 0xf6, 0xf7, 0xf2, + 0x3e, 0xf2, 0x22, 0xf4, 0xed, 0xfc, 0xee, 0x0b, 0xf4, 0xf1, 0x25, 0x10, + 0xd0, 0xf6, 0x00, 0xef, 0x10, 0xfc, 0x15, 0xe5, 0xdb, 0xf3, 0xea, 0x10, + 0x22, 0xf2, 0x2b, 0x11, 0xf9, 0x0a, 0xfc, 0xf5, 0x53, 0x16, 0x25, 0x43, + 0xe0, 0x0e, 0x13, 0xfc, 0x2d, 0xe2, 0x55, 0x65, 0xf4, 0x08, 0x01, 0xdf, + 0x0a, 0x00, 0x49, 0x1c, 0xfe, 0xdf, 0xef, 0xf2, 0xf9, 0xf6, 0xfd, 0xff, + 0xf3, 0x02, 0xf6, 0x14, 0x0b, 0xe8, 0x09, 0xfc, 0xfc, 0xe2, 0xe5, 0x11, + 0x03, 0x09, 0xfb, 0x06, 0x10, 0x1a, 0xf3, 0x0d, 0xfa, 0x0a, 0xd5, 0xf5, + 0x1a, 0x11, 0xf2, 0xfc, 0x1f, 0xfe, 0x0e, 0xe4, 0xef, 0xd7, 0xee, 0x06, + 0x1e, 0x04, 0x12, 0x28, 0xf7, 0x0e, 0x06, 0xf8, 0xee, 0xf0, 0x1a, 0x01, + 0xf7, 0xfd, 0x03, 0x11, 0x19, 0x10, 0x04, 0xfb, 0xd7, 0xfa, 0x16, 0x06, + 0x07, 0x23, 0xfa, 0x14, 0x11, 0xf1, 0x12, 0x10, 0x04, 0xe1, 0xee, 0xf7, + 0x21, 0x0e, 0x0a, 0x0a, 0xf8, 0x07, 0x0a, 0xee, 0x03, 0x1f, 0xfa, 0xc4, + 0xec, 0x12, 0x01, 0x1e, 0xfd, 0xf1, 0xe8, 0xcc, 0xf4, 0x17, 0xff, 0xdd, + 0x45, 0x10, 0xee, 0xfa, 0x3d, 0xe7, 0x27, 0xdd, 0xd7, 0xf9, 0xf4, 0xf6, + 0x06, 0xf8, 0x1e, 0x13, 0xe7, 0xe2, 0xf1, 0xe3, 0xf3, 0xf7, 0x18, 0x12, + 0xe4, 0x0a, 0xdb, 0xff, 0xff, 0xfe, 0x20, 0x09, 0x00, 0xf7, 0x23, 0xf6, + 0x2d, 0x14, 0x26, 0x28, 0xe5, 0xff, 0x0f, 0xe3, 0x1d, 0xe8, 0x56, 0x43, + 0xe7, 0xfb, 0xf9, 0xe6, 0xe9, 0xe2, 0x19, 0x19, 0x08, 0xfa, 0xf3, 0xe5, + 0x23, 0x07, 0x0f, 0xf8, 0xf8, 0xf3, 0xfc, 0x11, 0x2a, 0x05, 0xf4, 0xf1, + 0xfa, 0xfb, 0xf1, 0x1e, 0x13, 0x0f, 0xf9, 0xf5, 0xfa, 0x09, 0xf9, 0x03, + 0xf0, 0xf0, 0xe7, 0xec, 0xf1, 0x0c, 0xe6, 0xee, 0xf6, 0x20, 0x0f, 0xe9, + 0x00, 0xf4, 0xfe, 0xf0, 0x13, 0x0a, 0x17, 0x13, 0xee, 0x13, 0xfb, 0xff, + 0xf8, 0xfd, 0xf4, 0xe2, 0xe8, 0x06, 0xfc, 0x14, 0x03, 0x17, 0x00, 0x03, + 0xe6, 0xfd, 0xf2, 0x12, 0x12, 0x20, 0xeb, 0x10, 0x02, 0xf7, 0x13, 0x0d, + 0x11, 0xfd, 0xde, 0xf5, 0x07, 0xf3, 0x04, 0xff, 0x06, 0x05, 0xfb, 0xea, + 0xf0, 0x0a, 0x00, 0xb5, 0xe8, 0x1a, 0x03, 0xfe, 0x0d, 0x1a, 0xe7, 0xc0, + 0xd6, 0xdc, 0xf6, 0xf8, 0x39, 0xf5, 0xd5, 0xf8, 0x22, 0xfa, 0x22, 0x05, + 0xd0, 0xf4, 0x2d, 0xfc, 0x00, 0x0a, 0x1b, 0xfc, 0xe6, 0x09, 0x14, 0xfa, + 0x00, 0x1d, 0x1a, 0xfd, 0xf3, 0x18, 0xfc, 0xeb, 0x15, 0xf5, 0x0e, 0x0a, + 0xf3, 0xf1, 0x1b, 0x05, 0x14, 0x03, 0x2d, 0x27, 0xfb, 0x18, 0x22, 0xef, + 0xf6, 0x06, 0x28, 0x2b, 0xde, 0xec, 0xef, 0xe8, 0xd3, 0xfe, 0x17, 0x12, + 0x01, 0x13, 0x05, 0xf7, 0x00, 0xde, 0xf3, 0xe5, 0x03, 0xfb, 0x07, 0x0b, + 0xfd, 0xdc, 0xdf, 0x03, 0x0c, 0x00, 0xfa, 0x06, 0x0e, 0x02, 0x05, 0xfa, + 0xfd, 0xed, 0x09, 0x0c, 0xfd, 0xfb, 0x0c, 0xf0, 0xe4, 0x04, 0xd6, 0xf3, + 0x09, 0x0a, 0xf9, 0xf8, 0xe2, 0xef, 0xdf, 0xf0, 0xf8, 0x03, 0x0f, 0x20, + 0xf4, 0xe3, 0xf8, 0x02, 0xe2, 0xe5, 0x25, 0x0f, 0xeb, 0xf8, 0xe9, 0xfd, + 0x04, 0x0c, 0x0c, 0xfe, 0x01, 0x08, 0xfc, 0xfc, 0x1b, 0x01, 0xe5, 0x13, + 0xf9, 0xe8, 0x07, 0x20, 0xfe, 0x06, 0xec, 0xfe, 0x09, 0xef, 0x14, 0x04, + 0x0b, 0xf5, 0xe7, 0xff, 0x0a, 0x02, 0x09, 0xe9, 0xc4, 0x16, 0x0d, 0xe7, + 0x15, 0x14, 0xf1, 0xd0, 0xec, 0xe7, 0xf0, 0xf0, 0x33, 0x05, 0xda, 0xf2, + 0x0b, 0x08, 0x38, 0x01, 0x07, 0xfd, 0xd8, 0x06, 0xd9, 0xf0, 0x16, 0x1f, + 0xff, 0xf7, 0xe0, 0xd8, 0xf3, 0xf7, 0x12, 0x08, 0x0e, 0x05, 0xf6, 0x03, + 0xef, 0x1b, 0x12, 0xf4, 0xe8, 0x0f, 0x02, 0xfd, 0xf2, 0x16, 0x26, 0x22, + 0xe0, 0x07, 0xf7, 0xe6, 0xeb, 0x16, 0x22, 0x1a, 0x0b, 0x01, 0xf5, 0xea, + 0xd2, 0x22, 0x0f, 0x13, 0x15, 0x08, 0xf0, 0xfb, 0xed, 0x11, 0xf3, 0xe9, + 0xff, 0xde, 0x0a, 0x18, 0x0f, 0x02, 0xfb, 0xf9, 0xfb, 0xe8, 0x12, 0x18, + 0x01, 0xf4, 0xf6, 0xf8, 0xf0, 0x1f, 0x24, 0x15, 0xf5, 0x00, 0x1c, 0xf9, + 0x01, 0x0a, 0x11, 0xd5, 0x01, 0x12, 0x02, 0xec, 0xfd, 0x07, 0xf2, 0xea, + 0xf9, 0xff, 0xf7, 0xfb, 0x15, 0xec, 0xe5, 0x01, 0xeb, 0x05, 0xf9, 0x10, + 0xfe, 0x28, 0xe5, 0x0a, 0xeb, 0x1b, 0x0e, 0xf9, 0xde, 0x02, 0x15, 0x0a, + 0xff, 0xfe, 0x11, 0x24, 0x03, 0xf8, 0x00, 0x08, 0xfd, 0x0e, 0xeb, 0xf3, + 0xf6, 0xf7, 0x14, 0x0e, 0xfc, 0xf5, 0xde, 0xf5, 0x9e, 0xfe, 0xff, 0xff, + 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0xab, 0x01, 0x00, 0x00, + 0xfa, 0xfd, 0xff, 0xff, 0xa2, 0xff, 0xff, 0xff, 0xba, 0x00, 0x00, 0x00, + 0x24, 0xfc, 0xff, 0xff, 0x0f, 0x00, 0x00, 0x00, 0x54, 0x4f, 0x43, 0x4f, + 0x20, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x74, 0x65, 0x64, 0x2e, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x24, 0xfb, 0xff, 0xff, + 0x68, 0x01, 0x00, 0x00, 0x5c, 0x01, 0x00, 0x00, 0x50, 0x01, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf4, 0x00, 0x00, 0x00, + 0x90, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xce, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x09, 0x03, 0x00, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x1a, 0xff, 0xff, 0xff, 0x00, 0x00, 0x80, 0x3f, 0x01, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, + 0x07, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, + 0x01, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xc4, 0xfc, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x07, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, + 0x16, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x10, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x38, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x07, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, + 0x1a, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x07, 0x00, 0x14, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x02, 0x00, 0x00, 0x00, + 0x38, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, + 0x31, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x34, 0x04, 0x00, 0x00, 0xcc, 0x03, 0x00, 0x00, + 0x4c, 0x03, 0x00, 0x00, 0xdc, 0x02, 0x00, 0x00, 0x60, 0x02, 0x00, 0x00, + 0x20, 0x02, 0x00, 0x00, 0xb0, 0x01, 0x00, 0x00, 0x44, 0x01, 0x00, 0x00, + 0x70, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x02, 0xfc, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x09, 0x44, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf4, 0xfb, 0xff, 0xff, + 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x80, 0x3b, 0x0e, 0x00, 0x00, 0x00, 0x6c, 0x61, 0x62, 0x65, + 0x6c, 0x73, 0x5f, 0x73, 0x6f, 0x66, 0x74, 0x6d, 0x61, 0x78, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x0e, 0x00, 0x1a, 0x00, 0x08, 0x00, 0x07, 0x00, 0x0c, 0x00, + 0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, + 0xb4, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x94, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x00, + 0x12, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x11, 0x1e, 0x23, 0x3a, 0x9e, 0xa1, 0x15, 0x39, + 0x23, 0x69, 0x45, 0x3a, 0x09, 0xe4, 0xe4, 0x39, 0x65, 0xd7, 0x13, 0x3a, + 0xe0, 0xb2, 0xfd, 0x39, 0x1b, 0xc1, 0x53, 0x3a, 0xc2, 0x50, 0x2d, 0x3a, + 0x12, 0x00, 0x00, 0x00, 0x66, 0x69, 0x72, 0x73, 0x74, 0x5f, 0x77, 0x65, + 0x69, 0x67, 0x68, 0x74, 0x73, 0x2f, 0x72, 0x65, 0x61, 0x64, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x3a, 0xfd, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x09, 0x54, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x2c, 0xfd, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0xc6, 0xd0, 0xd0, 0x3d, 0x01, 0x00, 0x00, 0x00, 0xf5, 0xff, 0xcf, 0x41, - 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xbc, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x03, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, - 0x1c, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x52, 0x65, 0x6c, 0x75, - 0x00, 0x00, 0x00, 0x00, 0x04, 0xfb, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00, - 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x09, 0xf5, 0x83, 0x3d, 0x01, 0x00, 0x00, 0x00, - 0x14, 0x71, 0x83, 0x41, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x72, 0xbc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x02, 0x10, 0x00, 0x00, 0x00, - 0x06, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, - 0x43, 0x6f, 0x6e, 0x76, 0x32, 0x44, 0x5f, 0x62, 0x69, 0x61, 0x73, 0x00, - 0x64, 0xbc, 0xff, 0xff, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x2d, 0x95, 0x98, 0x38, - 0x20, 0x00, 0x00, 0x00, 0x27, 0xff, 0xff, 0xff, 0x97, 0xff, 0xff, 0xff, - 0x58, 0x00, 0x00, 0x00, 0x66, 0xff, 0xff, 0xff, 0x13, 0xff, 0xff, 0xff, - 0x72, 0xfe, 0xff, 0xff, 0x5d, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, - 0xea, 0xbc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, 0x10, 0x00, 0x00, 0x00, - 0x05, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x0e, 0x00, 0x00, 0x00, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x73, 0x5f, 0x73, - 0x6f, 0x66, 0x74, 0x6d, 0x61, 0x78, 0x00, 0x00, 0xec, 0xfb, 0xff, 0xff, - 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3b, - 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7f, 0x3f, 0x01, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x5a, 0xbd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, - 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x1c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x61, 0x64, 0x64, 0x5f, - 0x31, 0x00, 0x00, 0x00, 0x54, 0xfc, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00, - 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x8f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x9c, 0xd2, 0xb5, 0x3d, 0x01, 0x00, 0x00, 0x00, - 0x48, 0x18, 0x1f, 0x41, 0x01, 0x00, 0x00, 0x00, 0x4a, 0x21, 0x4b, 0xc1, - 0xc2, 0xbd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, 0x10, 0x00, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, - 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, - 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x73, 0x5f, 0x71, 0x75, 0x61, 0x6e, - 0x74, 0x2f, 0x46, 0x61, 0x6b, 0x65, 0x51, 0x75, 0x61, 0x6e, 0x74, 0x57, - 0x69, 0x74, 0x68, 0x4d, 0x69, 0x6e, 0x4d, 0x61, 0x78, 0x56, 0x61, 0x72, - 0x73, 0x00, 0x00, 0x00, 0xe4, 0xfc, 0xff, 0xff, 0x30, 0x00, 0x00, 0x00, + 0xb5, 0xfa, 0xfa, 0x39, 0x1f, 0x00, 0x00, 0x00, 0x66, 0x69, 0x6e, 0x61, + 0x6c, 0x5f, 0x66, 0x63, 0x5f, 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x73, + 0x2f, 0x72, 0x65, 0x61, 0x64, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, + 0x6f, 0x73, 0x65, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xa0, 0x0f, 0x00, 0x00, 0xa2, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x09, + 0x58, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x74, 0xfe, 0xff, 0xff, 0x30, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x8a, 0x0f, 0x3b, 0x3a, - 0x01, 0x00, 0x00, 0x00, 0xfc, 0x0b, 0xb4, 0x3d, 0x01, 0x00, 0x00, 0x00, - 0xd9, 0x26, 0xbf, 0xbd, 0x80, 0x02, 0x00, 0x00, 0x60, 0x38, 0xab, 0xcb, - 0xfa, 0x7e, 0xa2, 0x55, 0x6e, 0x87, 0xa5, 0x9b, 0xb4, 0x66, 0x5c, 0x6f, - 0xae, 0xdb, 0xcd, 0xb6, 0xc2, 0x60, 0xa9, 0x7d, 0xd4, 0xac, 0xa6, 0x90, - 0x87, 0x6b, 0x50, 0x95, 0xde, 0xcd, 0xaa, 0xa1, 0x9c, 0x65, 0xb5, 0x6d, - 0xb0, 0xa5, 0xa5, 0x7f, 0x73, 0x95, 0x63, 0x81, 0x7a, 0xc6, 0xaf, 0x82, - 0x69, 0x89, 0xc3, 0x3c, 0x47, 0x73, 0x89, 0x4f, 0x33, 0xbc, 0x85, 0x5d, - 0x69, 0x11, 0x5b, 0xb9, 0xf1, 0x95, 0x8f, 0x5c, 0x7c, 0x59, 0x6c, 0xa0, - 0xa5, 0x7c, 0x5a, 0x7c, 0xb5, 0xa9, 0x7e, 0xa1, 0xb8, 0x65, 0xb3, 0x86, - 0xc1, 0x9f, 0x5c, 0x86, 0x7f, 0x74, 0x52, 0xa8, 0xc9, 0xc5, 0x71, 0x96, - 0x7a, 0x65, 0xc7, 0x69, 0x94, 0xa7, 0x65, 0x68, 0x69, 0x8d, 0x6d, 0x9e, - 0x59, 0xd4, 0x75, 0x7a, 0x4f, 0x70, 0xca, 0x48, 0x25, 0x8a, 0x69, 0x4d, - 0x2a, 0xa6, 0x76, 0x69, 0x6a, 0x02, 0x3b, 0xa2, 0xea, 0xc2, 0x73, 0x6b, - 0x86, 0x4d, 0x3a, 0xa2, 0xa2, 0x88, 0x4e, 0x6c, 0xb3, 0x83, 0x39, 0x93, - 0xa6, 0x85, 0xb8, 0x7a, 0xa8, 0x7d, 0x2e, 0x7b, 0x7f, 0x69, 0x56, 0xb5, - 0xbb, 0xae, 0x23, 0x78, 0x67, 0x5c, 0xd2, 0x82, 0x7d, 0x96, 0x46, 0x74, - 0x70, 0x72, 0x6a, 0x90, 0x43, 0xce, 0x44, 0x75, 0x4a, 0x58, 0xc7, 0x5c, - 0x34, 0x84, 0x46, 0x4b, 0x41, 0x6c, 0x62, 0x83, 0x7e, 0x01, 0x9b, 0x9b, - 0xeb, 0xf7, 0x58, 0x6f, 0x8a, 0x43, 0xb3, 0x9f, 0x9c, 0x9e, 0x55, 0xa8, - 0xaa, 0x84, 0x8f, 0x8f, 0xb0, 0x9e, 0xc8, 0x81, 0xb6, 0x80, 0xa0, 0x81, - 0x86, 0x73, 0x5d, 0xdc, 0xb9, 0xae, 0xa2, 0x6c, 0x46, 0x67, 0xfa, 0x79, - 0x89, 0xaf, 0xa0, 0x74, 0x76, 0x85, 0x72, 0xb1, 0x2a, 0xbb, 0xa0, 0x6d, - 0x4f, 0x50, 0xc9, 0x5d, 0x2f, 0xaa, 0x9c, 0x63, 0x3f, 0x59, 0x63, 0x90, - 0x73, 0x1e, 0xb3, 0x94, 0xcd, 0xff, 0x3c, 0x63, 0x9b, 0x59, 0xc5, 0xa2, - 0x9f, 0x9a, 0x53, 0xab, 0xb0, 0x74, 0xb2, 0x6f, 0x8a, 0xa7, 0xd5, 0x8d, - 0xb8, 0x7e, 0x9e, 0x78, 0x84, 0x61, 0x66, 0xe7, 0xa7, 0x9f, 0xb7, 0x45, - 0x24, 0x61, 0xfd, 0x69, 0x87, 0xb8, 0xb2, 0x7a, 0x7c, 0x58, 0x64, 0xa3, - 0x07, 0xa9, 0xaf, 0x69, 0x49, 0x2f, 0xc2, 0x46, 0x3b, 0xaf, 0x9a, 0x70, - 0x6b, 0x25, 0x5f, 0x9d, 0x82, 0x33, 0xa1, 0x54, 0xae, 0xff, 0x31, 0x5d, - 0xaf, 0x51, 0xb2, 0x82, 0x9c, 0xa9, 0x5b, 0x8c, 0xab, 0x75, 0xb3, 0x32, - 0x42, 0xbd, 0xcd, 0x77, 0xb6, 0x67, 0x9a, 0x5f, 0x6c, 0x71, 0x6e, 0xc2, - 0xac, 0x97, 0x9f, 0x4b, 0x21, 0x6a, 0xfc, 0x77, 0x83, 0xa1, 0xa3, 0x6a, - 0x7a, 0x6d, 0x5e, 0x87, 0x02, 0xa6, 0x8f, 0x7f, 0x5c, 0x2e, 0xc1, 0x51, - 0x4a, 0xa7, 0x96, 0x79, 0x83, 0x2e, 0x5a, 0x84, 0x82, 0x5c, 0x61, 0x3a, - 0x4a, 0xff, 0x2a, 0x51, 0xa4, 0x6b, 0x82, 0x5e, 0x67, 0xb3, 0x71, 0x80, - 0xad, 0x62, 0x59, 0x40, 0x26, 0xd7, 0xcf, 0x68, 0xab, 0x7c, 0x6a, 0x69, - 0x5b, 0x7c, 0x84, 0xbc, 0x95, 0x68, 0x77, 0x63, 0x3f, 0x85, 0xed, 0x7b, - 0x71, 0xa0, 0x76, 0x90, 0x8c, 0x6c, 0x61, 0x81, 0x16, 0x74, 0x72, 0x94, - 0x74, 0x37, 0xb5, 0x3d, 0x55, 0x96, 0x86, 0xad, 0x87, 0x39, 0x59, 0x88, - 0x5b, 0x65, 0x60, 0x33, 0x33, 0xe6, 0x2b, 0x4a, 0xb6, 0x82, 0x50, 0x56, - 0x51, 0x97, 0x71, 0x83, 0xa6, 0x60, 0x57, 0x51, 0x58, 0xe4, 0xd0, 0x87, - 0xa1, 0x78, 0x4c, 0x67, 0x72, 0x74, 0x86, 0xc6, 0x60, 0x47, 0x50, 0x96, - 0x67, 0x96, 0xdd, 0x7d, 0x63, 0x85, 0x5e, 0x98, 0xa2, 0x64, 0x5f, 0x8a, - 0x3b, 0x40, 0x54, 0xcb, 0xa0, 0x61, 0xa7, 0x44, 0x5f, 0x6d, 0x57, 0xb3, - 0xb9, 0x2e, 0x61, 0x8e, 0x54, 0x78, 0x85, 0x58, 0x43, 0xb0, 0x27, 0x5d, - 0x8a, 0x7c, 0x8a, 0x58, 0x40, 0x83, 0x82, 0x9b, 0x6c, 0x60, 0x6b, 0x72, - 0x7f, 0xde, 0xc9, 0x7d, 0x6f, 0x5f, 0x90, 0x7e, 0x7e, 0x7e, 0x8b, 0xe5, - 0x51, 0x37, 0x7a, 0xa9, 0xa2, 0xc5, 0xd3, 0x81, 0x32, 0x4b, 0x80, 0xa9, - 0xc5, 0x76, 0x56, 0x99, 0x33, 0x19, 0x72, 0xe6, 0xdb, 0x90, 0xa8, 0x50, - 0x65, 0x44, 0x77, 0xdb, 0xc7, 0x48, 0x65, 0x8d, 0x3d, 0x7f, 0xa2, 0x7c, - 0x53, 0x55, 0x26, 0x49, 0x5d, 0x7d, 0xa2, 0x6d, 0x3b, 0x5b, 0x87, 0x64, - 0x3a, 0x5b, 0x8d, 0x93, 0x7a, 0xb4, 0xca, 0x6d, 0x16, 0x5a, 0x99, 0x82, - 0x8d, 0x6a, 0x92, 0xa0, 0x39, 0x2c, 0x95, 0xc8, 0xb8, 0xf5, 0xc8, 0x66, - 0x2a, 0x45, 0x84, 0x9c, 0xc7, 0x8e, 0x61, 0x7b, 0x43, 0x28, 0x86, 0xff, - 0xd2, 0xc8, 0x9c, 0x46, 0x65, 0x33, 0x82, 0xd8, 0xcb, 0x73, 0x63, 0x80, - 0xda, 0xc0, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, 0x10, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xa0, 0x0f, 0x00, 0x00, - 0x31, 0x00, 0x00, 0x00, 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x73, 0x5f, - 0x71, 0x75, 0x61, 0x6e, 0x74, 0x5f, 0x31, 0x2f, 0x46, 0x61, 0x6b, 0x65, - 0x51, 0x75, 0x61, 0x6e, 0x74, 0x57, 0x69, 0x74, 0x68, 0x4d, 0x69, 0x6e, - 0x4d, 0x61, 0x78, 0x56, 0x61, 0x72, 0x73, 0x2f, 0x74, 0x72, 0x61, 0x6e, - 0x73, 0x70, 0x6f, 0x73, 0x65, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, - 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x7e, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x87, 0xff, 0xdb, 0x39, - 0x01, 0x00, 0x00, 0x00, 0xd8, 0xb2, 0x5d, 0x3d, 0x01, 0x00, 0x00, 0x00, - 0x37, 0xdc, 0x56, 0xbd, 0x80, 0x3e, 0x00, 0x00, 0x67, 0x6d, 0x74, 0x77, - 0x35, 0x66, 0x87, 0x95, 0x8e, 0x82, 0x5e, 0x70, 0x6e, 0xa7, 0x60, 0x64, - 0x86, 0x5e, 0x93, 0x7a, 0x76, 0x74, 0x71, 0x8c, 0x61, 0x71, 0x60, 0x8b, - 0x83, 0x48, 0x8b, 0x5f, 0x95, 0x99, 0x5b, 0x59, 0x49, 0x44, 0x79, 0x62, - 0x8e, 0x77, 0x71, 0x89, 0x64, 0x46, 0x8f, 0x8e, 0x80, 0x73, 0x71, 0x81, - 0x85, 0x4a, 0x73, 0x57, 0x66, 0x58, 0x75, 0x93, 0x99, 0x58, 0x8a, 0x7b, - 0x87, 0x81, 0xa1, 0x46, 0x79, 0x6c, 0x83, 0x7a, 0x92, 0x74, 0x6f, 0x6b, - 0x79, 0x77, 0x97, 0x8a, 0x95, 0x75, 0xa2, 0x49, 0x80, 0x4e, 0x7f, 0x6d, - 0xaa, 0xac, 0x6c, 0x5d, 0x57, 0x82, 0x97, 0x77, 0x6f, 0x75, 0x95, 0x73, - 0x7e, 0x51, 0x9f, 0x5b, 0x54, 0x92, 0x60, 0x72, 0x80, 0x6a, 0x92, 0x83, - 0x9b, 0x85, 0x7b, 0x4d, 0x55, 0x4d, 0xb2, 0x7d, 0x65, 0x95, 0x76, 0x42, - 0x61, 0x49, 0xa2, 0x73, 0x9f, 0x7d, 0x7c, 0x54, 0x51, 0x76, 0xa1, 0x7f, - 0x86, 0x69, 0x98, 0x59, 0x6d, 0x84, 0x9f, 0x7b, 0x86, 0x79, 0x88, 0x55, - 0x9c, 0x72, 0x95, 0x8a, 0x91, 0x7a, 0x77, 0x95, 0x7b, 0x87, 0x87, 0x85, - 0x95, 0x72, 0x77, 0x59, 0x7c, 0x80, 0x90, 0x8f, 0x8a, 0x62, 0x76, 0x9f, - 0x64, 0x84, 0x71, 0x7e, 0x7c, 0x66, 0x8e, 0x94, 0x6e, 0xaa, 0x77, 0x5c, - 0x6b, 0x63, 0x68, 0x82, 0x89, 0x46, 0x61, 0x74, 0x8e, 0x85, 0x6b, 0x57, - 0x74, 0x50, 0x87, 0x66, 0x87, 0x98, 0x59, 0x7d, 0xa2, 0x59, 0x75, 0x64, - 0x72, 0x8c, 0x6a, 0x92, 0x8c, 0x56, 0x88, 0x7a, 0x6e, 0x77, 0x9c, 0x82, - 0x7e, 0x5a, 0x91, 0x80, 0x9c, 0x9e, 0x60, 0x8b, 0x6d, 0x76, 0x8d, 0x68, - 0x6c, 0x70, 0x6f, 0x8b, 0x61, 0x6e, 0x86, 0x78, 0x81, 0x81, 0x77, 0x79, - 0x76, 0x69, 0x7d, 0x7b, 0x96, 0x8b, 0x95, 0x91, 0xa2, 0x7b, 0x86, 0x8d, - 0x8b, 0x89, 0x86, 0x5a, 0x5c, 0x4d, 0x96, 0x80, 0x81, 0x55, 0x80, 0x80, - 0x7a, 0x76, 0x99, 0x98, 0x61, 0x95, 0x5a, 0x78, 0x5a, 0x6c, 0x89, 0x81, - 0x98, 0x77, 0x62, 0x77, 0x93, 0x4d, 0x9f, 0x77, 0x72, 0x87, 0x95, 0x71, - 0x65, 0x72, 0xac, 0x8c, 0xa2, 0x89, 0x90, 0x7b, 0x67, 0x60, 0x8a, 0xb3, - 0x72, 0x8f, 0x5c, 0x82, 0x74, 0x76, 0x7c, 0x85, 0x78, 0x6b, 0x97, 0x6d, - 0x86, 0x82, 0x76, 0x84, 0x89, 0x89, 0x7f, 0x6a, 0x7a, 0x7f, 0x6c, 0x77, - 0x80, 0x35, 0x7d, 0x66, 0x96, 0x7e, 0x88, 0x55, 0x6b, 0x55, 0x7c, 0xa7, - 0x7f, 0x9f, 0x64, 0x8b, 0xa0, 0x81, 0x80, 0x97, 0xaf, 0x7a, 0x7d, 0x61, - 0x7a, 0x77, 0x6f, 0x8c, 0x5e, 0x69, 0x6b, 0x94, 0x70, 0x6a, 0x66, 0x5d, - 0x78, 0x6e, 0x76, 0x64, 0xa0, 0x73, 0x8f, 0xa2, 0x9d, 0x50, 0x8e, 0x52, - 0x51, 0x85, 0x78, 0x83, 0x8f, 0x94, 0x83, 0x7c, 0x9c, 0x64, 0x59, 0x7d, - 0x66, 0x6a, 0x73, 0x80, 0x6a, 0x9b, 0x92, 0x7e, 0x7a, 0x78, 0x7d, 0xa0, - 0x8a, 0x9b, 0x61, 0x9e, 0x6c, 0x64, 0x6c, 0x8e, 0x86, 0x75, 0x8a, 0x95, - 0x8e, 0x89, 0x87, 0x8a, 0x5d, 0x8b, 0x82, 0x7c, 0x60, 0x63, 0x85, 0x85, - 0x63, 0x96, 0xa3, 0x7f, 0x93, 0x78, 0x8c, 0x86, 0x7b, 0x78, 0x8e, 0x71, - 0x72, 0x8b, 0x8a, 0x5e, 0x8d, 0x75, 0x78, 0xa3, 0x84, 0x67, 0xa7, 0x54, - 0x6c, 0x80, 0x8e, 0xa8, 0x83, 0x51, 0x6e, 0x9f, 0x8b, 0x86, 0x75, 0x95, - 0x7f, 0x7a, 0x80, 0x81, 0x8d, 0x9c, 0x83, 0x8a, 0x7b, 0x8a, 0x74, 0x6f, - 0x8d, 0x96, 0x5b, 0x9c, 0x8d, 0x7b, 0x83, 0x79, 0x7f, 0x65, 0x7e, 0x87, - 0x7c, 0x5d, 0x71, 0x97, 0x77, 0x44, 0x9a, 0x7f, 0xaa, 0x56, 0x75, 0x5f, - 0x7c, 0x51, 0x8c, 0x90, 0x84, 0x9a, 0x49, 0x5d, 0x86, 0x52, 0x94, 0x95, - 0x5b, 0x86, 0x66, 0x7d, 0x51, 0x4f, 0x7a, 0x91, 0x6d, 0x6e, 0x72, 0x70, - 0x83, 0x4f, 0x9b, 0x9a, 0x8a, 0x77, 0x6a, 0xa1, 0x71, 0x60, 0x61, 0x98, - 0x67, 0x4e, 0x7a, 0x8a, 0x53, 0x6b, 0x99, 0xa0, 0x91, 0x46, 0x8a, 0x8b, - 0x47, 0x78, 0xa9, 0x7b, 0x71, 0x6c, 0x81, 0x68, 0x53, 0x73, 0xaf, 0x70, - 0x62, 0x6d, 0x69, 0x97, 0x70, 0x83, 0x5f, 0x7f, 0x81, 0x87, 0x65, 0x93, - 0x67, 0x87, 0x70, 0x82, 0x79, 0x9e, 0x80, 0x77, 0x6c, 0x80, 0x92, 0x81, - 0x8d, 0x8c, 0x89, 0x8b, 0x4e, 0x91, 0x77, 0x84, 0x99, 0x8c, 0x71, 0x88, - 0x57, 0x7a, 0x9a, 0x8c, 0x82, 0x9b, 0x97, 0x72, 0x69, 0xac, 0x7c, 0x62, - 0x85, 0x7d, 0x76, 0x7f, 0x59, 0x85, 0x68, 0x63, 0x94, 0x8b, 0x7b, 0x92, - 0x7b, 0x6f, 0x77, 0x98, 0x66, 0x78, 0x74, 0x99, 0x85, 0x8c, 0x94, 0x89, - 0x6c, 0x77, 0x89, 0x80, 0x79, 0x8a, 0xa6, 0x95, 0xa9, 0x86, 0x6f, 0x95, - 0x90, 0x69, 0x98, 0x85, 0xa0, 0x7f, 0x56, 0xab, 0x6f, 0x5a, 0x94, 0x8b, - 0x5a, 0x72, 0x61, 0x83, 0x54, 0x70, 0x8d, 0x8d, 0x9c, 0x5e, 0x36, 0x9b, - 0x84, 0x32, 0x6e, 0x84, 0x79, 0x72, 0x64, 0x95, 0x83, 0x58, 0x67, 0x6c, - 0x9e, 0x8d, 0x6e, 0x9e, 0x4f, 0x78, 0x71, 0x85, 0x75, 0x60, 0x4d, 0x7d, - 0x64, 0x89, 0x8e, 0x89, 0x6e, 0x92, 0x53, 0x7c, 0x86, 0x8f, 0xa9, 0xb0, - 0x8e, 0x5e, 0x76, 0x96, 0x65, 0x7c, 0x8a, 0x89, 0x75, 0x8f, 0x65, 0x94, - 0x6c, 0x6c, 0x8d, 0x6d, 0x66, 0x6a, 0x62, 0x98, 0x53, 0x8f, 0x67, 0x76, - 0x80, 0x89, 0x66, 0x60, 0x55, 0x81, 0x85, 0x61, 0x75, 0x78, 0x80, 0x92, - 0x6f, 0x79, 0x66, 0x64, 0x99, 0xa7, 0x88, 0xa1, 0x86, 0x6b, 0x94, 0x88, - 0x77, 0x83, 0x8f, 0x61, 0x72, 0x7c, 0x6f, 0x8f, 0x61, 0x56, 0x8a, 0x7b, - 0x66, 0x8b, 0x98, 0x9d, 0x82, 0x65, 0x77, 0x98, 0x55, 0x83, 0x7a, 0x8c, - 0x74, 0x79, 0x6e, 0x85, 0x82, 0x9a, 0x7d, 0x8d, 0x76, 0x72, 0x64, 0x81, - 0x9a, 0x8d, 0x9f, 0x7b, 0x7c, 0x7b, 0x7b, 0x84, 0x90, 0x6b, 0xa4, 0x84, - 0x98, 0x6f, 0x81, 0xb8, 0x6f, 0x6c, 0x87, 0x6d, 0x8c, 0x72, 0x53, 0x85, - 0x59, 0x4d, 0x9c, 0x94, 0x7d, 0x6f, 0x4f, 0x82, 0x5d, 0x71, 0x6e, 0x78, - 0x61, 0x61, 0x34, 0x71, 0x6a, 0x5a, 0x73, 0xa3, 0x89, 0x65, 0x4d, 0x80, - 0x5c, 0x51, 0x81, 0x8e, 0x6c, 0x53, 0x4a, 0x95, 0x3b, 0x72, 0xa7, 0x86, - 0x7f, 0x75, 0x61, 0xa3, 0x85, 0x6c, 0x99, 0x88, 0x7c, 0x64, 0x7a, 0x8d, - 0x81, 0x7b, 0x6a, 0x7b, 0x8f, 0x74, 0x6d, 0xae, 0x42, 0x67, 0x88, 0xa1, - 0x90, 0x4d, 0x7c, 0x7b, 0x62, 0x55, 0x9a, 0x80, 0x4d, 0x76, 0x5c, 0x88, - 0x60, 0x86, 0x6f, 0x65, 0x67, 0x77, 0x8a, 0x97, 0x99, 0x7c, 0x89, 0x78, - 0x92, 0xa7, 0x6a, 0x7f, 0x8e, 0x88, 0x9d, 0xa1, 0x7b, 0xb0, 0x69, 0x8c, - 0x7e, 0x51, 0x76, 0x84, 0x7d, 0x91, 0x7a, 0x88, 0x7b, 0x88, 0x92, 0x79, - 0x6d, 0x82, 0x6c, 0x8a, 0x99, 0x62, 0x82, 0x9d, 0x99, 0x97, 0x78, 0x6a, - 0x6e, 0x83, 0x64, 0x7d, 0x8c, 0x78, 0x7c, 0x7a, 0x7d, 0x7b, 0x77, 0x84, - 0x76, 0x57, 0x63, 0x85, 0x97, 0x94, 0x80, 0x92, 0x88, 0x73, 0x91, 0x91, - 0x8f, 0x6d, 0x99, 0x86, 0x91, 0x7f, 0x8b, 0x87, 0x98, 0x62, 0x84, 0x70, - 0x97, 0x7b, 0x2e, 0x9b, 0x6e, 0x2a, 0xa4, 0x9c, 0x79, 0x88, 0x54, 0x81, - 0x4f, 0x41, 0xa0, 0x85, 0xaf, 0x9a, 0x47, 0x5a, 0x7d, 0x62, 0x7a, 0x84, - 0x81, 0x6e, 0x41, 0xb4, 0x60, 0x47, 0x8f, 0x98, 0x6c, 0x3c, 0x3b, 0x73, - 0x59, 0x55, 0x7c, 0xb0, 0x6e, 0x5f, 0x61, 0x97, 0x73, 0x59, 0x9f, 0x92, - 0x89, 0x5c, 0x70, 0x96, 0x5c, 0x7c, 0x7c, 0x64, 0x7e, 0x54, 0x5c, 0x94, - 0x56, 0x73, 0x8d, 0x95, 0x59, 0x83, 0x6c, 0x99, 0x6e, 0x5e, 0x7a, 0x99, - 0x83, 0x93, 0x88, 0x76, 0x5a, 0x5a, 0xa5, 0x95, 0x5d, 0x63, 0x8f, 0x6e, - 0x74, 0x65, 0x85, 0x86, 0x98, 0x83, 0x7b, 0x8a, 0x5c, 0x5e, 0x7f, 0x88, - 0x78, 0x68, 0x8f, 0x9f, 0x94, 0x8d, 0x74, 0x7b, 0x6a, 0x91, 0x7a, 0x9a, - 0x70, 0x67, 0xb2, 0x92, 0x75, 0x4e, 0x74, 0xa3, 0x68, 0x74, 0x91, 0x80, - 0x55, 0x8e, 0x88, 0x73, 0x70, 0x81, 0xa1, 0xb8, 0x96, 0x48, 0x67, 0xb2, - 0x76, 0xa1, 0x98, 0xa9, 0x61, 0x6c, 0x5f, 0x98, 0x84, 0x92, 0xa9, 0x83, - 0x9e, 0x74, 0x7b, 0xa2, 0x6f, 0x72, 0x95, 0xa3, 0xb9, 0x80, 0x81, 0x7b, - 0x65, 0x6b, 0x96, 0x8b, 0xae, 0x79, 0x2b, 0x86, 0x5c, 0x2c, 0x8b, 0xa3, - 0x84, 0x74, 0x53, 0x7c, 0x54, 0x4a, 0x65, 0x89, 0xa6, 0x89, 0x47, 0x77, - 0x50, 0x6d, 0x8b, 0x94, 0x8a, 0x61, 0x32, 0x7c, 0x6f, 0x47, 0x78, 0xa2, - 0x9f, 0x42, 0x42, 0x71, 0x78, 0x76, 0x9e, 0x88, 0x70, 0x70, 0x56, 0x8a, - 0x83, 0x95, 0xa7, 0x9d, 0x9d, 0x88, 0x9a, 0x92, 0x48, 0x63, 0xaf, 0x91, - 0x6c, 0x75, 0x5d, 0x5e, 0x83, 0x86, 0xaa, 0x6f, 0x79, 0x84, 0x67, 0x79, - 0x63, 0x69, 0x8e, 0x81, 0x6a, 0x96, 0x8d, 0x86, 0x7b, 0x9f, 0xaa, 0x8e, - 0x63, 0x89, 0x9a, 0x7a, 0x5e, 0x7c, 0x87, 0x83, 0x81, 0x64, 0x7e, 0x59, - 0x6d, 0x5c, 0xa4, 0x72, 0x78, 0x85, 0x9b, 0x79, 0x85, 0x7d, 0x9c, 0x7d, - 0x9c, 0x5c, 0x66, 0x75, 0x66, 0x72, 0xb4, 0x7c, 0x83, 0x9e, 0x90, 0xae, - 0x69, 0x71, 0xb0, 0x84, 0x86, 0x50, 0x66, 0xab, 0x75, 0x96, 0xa8, 0x6c, - 0x87, 0x7b, 0x7e, 0x7c, 0x60, 0x55, 0x96, 0xb0, 0x6a, 0x79, 0x42, 0x9c, - 0x97, 0xa8, 0xb2, 0x9a, 0xa0, 0x84, 0x68, 0x90, 0x90, 0x98, 0x67, 0x9c, - 0xa3, 0x81, 0x71, 0xaa, 0x93, 0x6a, 0x84, 0x8c, 0x77, 0x79, 0x4d, 0x82, - 0x45, 0x1e, 0x7b, 0x94, 0x86, 0x86, 0x26, 0x82, 0x41, 0x6f, 0x8b, 0x86, - 0xa4, 0x80, 0x38, 0x71, 0x5e, 0x5b, 0x9a, 0x73, 0x86, 0x60, 0x5a, 0x9d, - 0x7b, 0x53, 0x89, 0xa0, 0x99, 0x76, 0x57, 0x81, 0x76, 0x5a, 0x9e, 0x85, - 0x5a, 0x7b, 0x56, 0x74, 0x71, 0x6a, 0x9c, 0x68, 0x7e, 0x76, 0x7d, 0x7f, - 0x52, 0x71, 0x85, 0xa2, 0x96, 0x63, 0x73, 0x7c, 0x7a, 0x97, 0x9f, 0x7c, - 0x77, 0x77, 0x59, 0x6b, 0x62, 0x77, 0xbc, 0x6b, 0x7c, 0x79, 0x75, 0x90, - 0x67, 0x82, 0x92, 0x9c, 0x81, 0x92, 0x84, 0x7a, 0x72, 0x5b, 0x86, 0x82, - 0x87, 0x73, 0x87, 0x7c, 0x57, 0x76, 0xa6, 0x7d, 0x7d, 0x94, 0x6a, 0x67, - 0x76, 0x89, 0x9a, 0x6d, 0x7d, 0xa4, 0x6d, 0x7e, 0x74, 0x7e, 0x8f, 0xad, - 0x99, 0x55, 0x5c, 0x82, 0x75, 0x9e, 0xae, 0x76, 0x6b, 0x93, 0x5d, 0x92, - 0x6e, 0x54, 0x88, 0x8f, 0x6a, 0x72, 0x64, 0x93, 0x6e, 0x63, 0x8c, 0xa7, - 0xa6, 0x7a, 0x57, 0x9f, 0x94, 0x91, 0xbd, 0xa4, 0x92, 0x7a, 0x68, 0x9d, - 0x7d, 0x6b, 0x6b, 0xbc, 0xad, 0x7a, 0x73, 0x92, 0x7b, 0x6d, 0x91, 0x6a, - 0x66, 0x8d, 0x34, 0x9b, 0x75, 0x3b, 0x93, 0x78, 0x88, 0x58, 0x1a, 0x7f, - 0x52, 0x61, 0xa3, 0xb1, 0x9c, 0x60, 0x1d, 0x90, 0x7b, 0x37, 0x9f, 0x84, - 0xa3, 0x6c, 0x2e, 0xac, 0x73, 0x62, 0x92, 0x9a, 0x94, 0x6b, 0x5c, 0x82, - 0x5f, 0x4c, 0x9a, 0x8c, 0x76, 0x69, 0x77, 0x5f, 0x5d, 0x91, 0x80, 0x9a, - 0x60, 0x4c, 0x7b, 0x57, 0x67, 0x6b, 0x92, 0x93, 0x64, 0x91, 0x55, 0x75, - 0x41, 0x82, 0x78, 0x68, 0xa2, 0x55, 0x6a, 0x69, 0x59, 0x70, 0x8a, 0x7b, - 0x70, 0x6e, 0x63, 0x83, 0x7f, 0xa4, 0x80, 0x85, 0x86, 0x93, 0x7e, 0x6f, - 0x7b, 0x94, 0xa4, 0xa7, 0x97, 0x7a, 0x87, 0x64, 0x4a, 0x97, 0x94, 0x6a, - 0x96, 0x73, 0x5e, 0x79, 0x6a, 0x99, 0x86, 0xa0, 0x93, 0xac, 0x79, 0x76, - 0x7f, 0x7b, 0xa7, 0x75, 0x8a, 0x71, 0x53, 0x87, 0x93, 0x7f, 0x9e, 0x7b, - 0x81, 0x70, 0x68, 0x8b, 0x8c, 0x9c, 0xaf, 0xa7, 0x6a, 0x9b, 0x49, 0x6d, - 0x67, 0x80, 0x8b, 0x86, 0x9f, 0x80, 0x74, 0x7a, 0x96, 0x74, 0xc8, 0x9d, - 0xa4, 0x74, 0x71, 0x6c, 0x75, 0x6a, 0x9a, 0x95, 0x97, 0x8c, 0x6e, 0x8a, - 0x85, 0x62, 0x5f, 0x7e, 0x9e, 0x6b, 0x48, 0x93, 0x44, 0x37, 0x83, 0xa2, - 0x97, 0x72, 0x25, 0x79, 0x32, 0x39, 0x68, 0x8f, 0x93, 0x61, 0x2b, 0x96, - 0x94, 0x43, 0x82, 0x6e, 0x8f, 0x6d, 0x53, 0x9b, 0x65, 0x50, 0x70, 0x9d, - 0x7d, 0x53, 0x3b, 0x86, 0x77, 0x6c, 0xa6, 0x90, 0x6b, 0x3e, 0x7b, 0x7a, - 0x50, 0x81, 0xb4, 0x76, 0xa5, 0x74, 0x8b, 0x73, 0x79, 0x69, 0xa8, 0x9a, - 0x82, 0x4a, 0x5e, 0x6c, 0x8d, 0x66, 0xa3, 0x80, 0x8d, 0x74, 0x5b, 0x7c, - 0x77, 0xaa, 0x82, 0x69, 0x5e, 0x7d, 0x7f, 0x63, 0xa3, 0x8c, 0xb3, 0x9a, - 0x81, 0x8f, 0x7b, 0x77, 0x60, 0x89, 0x6a, 0x82, 0x5a, 0x7a, 0x71, 0x61, - 0x93, 0x73, 0x8b, 0xb0, 0xa2, 0x92, 0x7c, 0x84, 0x8b, 0x72, 0x91, 0x8d, - 0x91, 0x80, 0x6c, 0x75, 0x7a, 0xb3, 0x95, 0x5e, 0xa5, 0x5d, 0x54, 0x8b, - 0x63, 0x91, 0xa7, 0x68, 0x96, 0x4c, 0x5a, 0x86, 0x76, 0x82, 0xb6, 0xa0, - 0x68, 0x6b, 0x53, 0x76, 0x60, 0x65, 0x90, 0xaf, 0x82, 0x66, 0x80, 0x7b, - 0x84, 0xa0, 0xb0, 0xb8, 0x81, 0x6e, 0x81, 0x8a, 0x74, 0x6e, 0x97, 0xa8, - 0x89, 0x7b, 0x7b, 0x6e, 0x63, 0x74, 0x5a, 0x7b, 0x7e, 0x84, 0x40, 0x95, - 0x73, 0x3c, 0x7c, 0x72, 0x9b, 0x92, 0x27, 0x87, 0x69, 0x5b, 0x99, 0x8a, - 0xa8, 0x65, 0x36, 0x8f, 0x86, 0x3e, 0xa1, 0x79, 0x9f, 0x4d, 0x41, 0xc5, - 0x8c, 0x6a, 0x7e, 0x7f, 0x68, 0x49, 0x5c, 0x91, 0x50, 0x6a, 0x8c, 0x81, - 0x75, 0x4c, 0x6a, 0x74, 0x8a, 0x87, 0xa0, 0x93, 0x7e, 0x6d, 0x52, 0x79, - 0x86, 0x6a, 0x68, 0x6c, 0x83, 0x67, 0x79, 0x73, 0x6f, 0x72, 0x97, 0x84, - 0x8b, 0x78, 0x64, 0x69, 0x8f, 0x92, 0x86, 0x61, 0x5d, 0x85, 0x70, 0x64, - 0x7d, 0xa3, 0x92, 0xa0, 0x72, 0x71, 0x5d, 0x63, 0x7c, 0x70, 0xaf, 0x6f, - 0x93, 0x6a, 0x7e, 0x7f, 0x64, 0xab, 0x85, 0x73, 0x8f, 0x8a, 0x7e, 0x5f, - 0x7a, 0x6f, 0xaa, 0x71, 0x97, 0x7d, 0x60, 0x7c, 0x48, 0x69, 0xa9, 0xaa, - 0x98, 0x7c, 0x61, 0x85, 0x66, 0x97, 0xa2, 0x73, 0x74, 0x65, 0x52, 0x67, - 0x79, 0x8a, 0x79, 0x71, 0x85, 0x6e, 0x6d, 0x67, 0x5e, 0x7f, 0xb9, 0x93, - 0x96, 0x53, 0x69, 0x6e, 0x7f, 0x8f, 0xab, 0x93, 0xa9, 0x70, 0x6e, 0x71, - 0x7e, 0x87, 0x98, 0x7a, 0xae, 0x90, 0x64, 0x88, 0x8a, 0x4f, 0x6d, 0x9e, - 0xac, 0x7e, 0x31, 0x92, 0x50, 0x26, 0x95, 0xb2, 0x90, 0x99, 0x0c, 0x84, - 0x40, 0x4f, 0x8f, 0x76, 0xa4, 0x46, 0x4c, 0x9d, 0x8b, 0x57, 0x81, 0x79, - 0x7b, 0x47, 0x4d, 0x9c, 0x5f, 0x3b, 0x6f, 0x90, 0x7a, 0x3f, 0x66, 0x9d, - 0x6c, 0x45, 0x8b, 0x71, 0x79, 0x62, 0x72, 0x78, 0x93, 0x95, 0x7e, 0x86, - 0x7a, 0x6b, 0x77, 0x74, 0x6b, 0x86, 0xa4, 0x7e, 0x84, 0x48, 0x78, 0x75, - 0x6e, 0x8b, 0x8e, 0x56, 0x69, 0x7b, 0x59, 0x68, 0x5d, 0x77, 0x69, 0x66, - 0x67, 0x9f, 0x75, 0x7b, 0x76, 0x64, 0xc1, 0x78, 0x7d, 0x74, 0x82, 0x73, - 0x73, 0x90, 0xb8, 0x82, 0x7e, 0x70, 0x7b, 0x7a, 0x64, 0xa1, 0x7e, 0x85, - 0x83, 0x81, 0x60, 0x7b, 0x91, 0x82, 0x6f, 0x95, 0xa0, 0x86, 0x6d, 0x88, - 0x75, 0x8d, 0x94, 0x90, 0x76, 0x6d, 0x6e, 0x79, 0x64, 0x74, 0xa8, 0xb1, - 0x92, 0x6e, 0x61, 0x79, 0x74, 0x91, 0x95, 0x74, 0x65, 0x74, 0x5e, 0x7f, - 0x8b, 0x60, 0x9b, 0x9f, 0x74, 0x77, 0x4c, 0x66, 0x7c, 0x80, 0x97, 0x98, - 0x9d, 0x86, 0x55, 0x8a, 0x8a, 0x79, 0x8c, 0x82, 0xb0, 0x7d, 0x63, 0x8c, - 0x5d, 0x5b, 0x82, 0x58, 0x84, 0x56, 0x51, 0x92, 0x75, 0x24, 0x97, 0x92, - 0x75, 0x6e, 0x19, 0x8e, 0x47, 0x3e, 0x7b, 0x7b, 0x87, 0x6b, 0x3f, 0xa9, - 0x59, 0x40, 0x86, 0x74, 0x69, 0x4a, 0x2d, 0xad, 0x91, 0x62, 0xb2, 0xa9, - 0x74, 0x6c, 0x47, 0x94, 0x51, 0x75, 0xb2, 0x6f, 0x75, 0x4b, 0x60, 0xa2, - 0x8e, 0x6a, 0xa4, 0x79, 0x6f, 0x57, 0x80, 0x8c, 0x6c, 0x8e, 0x9e, 0x74, - 0x70, 0x5f, 0x66, 0x80, 0x80, 0x89, 0xb5, 0x8a, 0x7a, 0x96, 0x87, 0x7a, - 0x7b, 0x85, 0x90, 0x79, 0x59, 0x6d, 0x77, 0x8c, 0x8f, 0x82, 0xb3, 0x9c, - 0x6a, 0x6a, 0x6b, 0x70, 0x77, 0x89, 0x96, 0x86, 0x94, 0x72, 0x7e, 0x72, - 0xa9, 0x93, 0x8d, 0x7a, 0x6d, 0x8f, 0x66, 0x72, 0x9a, 0x91, 0x9e, 0x98, - 0xa0, 0x8b, 0x50, 0x76, 0x5c, 0x74, 0xbc, 0x9a, 0x98, 0x73, 0x80, 0x7d, - 0x73, 0x7c, 0xc0, 0x8b, 0x86, 0x7a, 0x66, 0x86, 0x83, 0x72, 0x8f, 0x96, - 0x98, 0x56, 0x45, 0x7b, 0x77, 0x92, 0xac, 0x8a, 0xae, 0x43, 0x33, 0x73, - 0x78, 0x83, 0x98, 0x84, 0x86, 0x78, 0x54, 0x7e, 0x70, 0x5f, 0xa6, 0xa1, - 0x94, 0x81, 0x73, 0x8d, 0x83, 0x5b, 0x88, 0x71, 0xb2, 0x91, 0x50, 0x99, - 0x6b, 0x47, 0x72, 0x92, 0x87, 0x6d, 0x07, 0x99, 0x57, 0x3d, 0x8d, 0x83, - 0x9d, 0x49, 0x40, 0x9d, 0x5c, 0x57, 0x95, 0x73, 0x6e, 0x4b, 0x49, 0xab, - 0x97, 0x58, 0x8b, 0x7a, 0x7a, 0x48, 0x47, 0x8b, 0x7e, 0x5d, 0xa9, 0x6d, - 0x8a, 0x3f, 0x60, 0x82, 0x86, 0x98, 0xa9, 0x7c, 0x74, 0x59, 0x9b, 0x80, - 0x4e, 0x75, 0x9c, 0x5e, 0x75, 0x8c, 0x67, 0x7e, 0x78, 0x75, 0x87, 0x6c, - 0x79, 0x73, 0x63, 0x77, 0x6e, 0x7a, 0x8d, 0x73, 0x4e, 0x72, 0x4a, 0x7c, - 0x8f, 0x79, 0x70, 0x7a, 0x70, 0x73, 0x7b, 0x7a, 0x62, 0xa1, 0x7b, 0x63, - 0x9a, 0x89, 0x76, 0x64, 0x84, 0x7d, 0x9c, 0x94, 0xb0, 0x7f, 0x6c, 0x7b, - 0x8d, 0x89, 0x89, 0x7b, 0x9d, 0x99, 0x64, 0x8b, 0x5c, 0x88, 0xa6, 0x8e, - 0x81, 0x86, 0x7e, 0x85, 0x73, 0x72, 0xad, 0x5d, 0x5f, 0x7e, 0x63, 0x74, - 0x64, 0xa1, 0x9c, 0x83, 0x7c, 0x83, 0x7b, 0x7b, 0x71, 0xa0, 0x9e, 0xaf, - 0x89, 0x79, 0x4c, 0x7c, 0x8c, 0x78, 0x91, 0x87, 0x8a, 0x87, 0x5e, 0x85, - 0x7b, 0x61, 0x9c, 0x88, 0xa5, 0x8d, 0x7c, 0x9c, 0x6b, 0x47, 0x95, 0x85, - 0x81, 0x80, 0x59, 0xb2, 0x4f, 0x3d, 0xae, 0x8c, 0x8d, 0x71, 0x11, 0x95, - 0x31, 0x65, 0x9d, 0xa0, 0x8e, 0x64, 0x42, 0xb9, 0x6a, 0x5c, 0x91, 0x82, - 0x91, 0x50, 0x33, 0xb2, 0x7a, 0x54, 0xac, 0x88, 0x92, 0x61, 0x4e, 0xad, - 0x65, 0x5c, 0x91, 0xb0, 0x72, 0x65, 0x4a, 0x79, 0x68, 0x77, 0x75, 0x5f, - 0x79, 0x6d, 0x6f, 0x7c, 0x4d, 0x71, 0xb8, 0x78, 0x8a, 0x87, 0x6e, 0x72, - 0x7d, 0x79, 0x87, 0x80, 0x5a, 0x78, 0x77, 0x78, 0x80, 0x8f, 0x8c, 0x56, - 0x7a, 0x8b, 0x62, 0x82, 0x5a, 0x96, 0x82, 0x68, 0x71, 0x5d, 0x75, 0x65, - 0x93, 0xb5, 0x71, 0x82, 0x82, 0x8a, 0x4b, 0x7c, 0x62, 0x6f, 0xc1, 0x86, - 0x9d, 0x90, 0x63, 0x71, 0x86, 0x9e, 0x9f, 0x77, 0x90, 0x97, 0x68, 0x81, - 0x5a, 0x8c, 0xab, 0x5e, 0x81, 0x76, 0x83, 0x79, 0x8f, 0xa1, 0x89, 0x79, - 0x81, 0x8a, 0x7e, 0x6c, 0x65, 0x79, 0xc7, 0x89, 0x92, 0x68, 0x78, 0x70, - 0x65, 0x96, 0x9e, 0x82, 0x7d, 0x5f, 0x7b, 0x77, 0x72, 0x84, 0x7e, 0x92, - 0x97, 0x7b, 0x6e, 0x67, 0x81, 0xa1, 0x9a, 0xab, 0x8d, 0x78, 0x61, 0x78, - 0x52, 0x66, 0xaa, 0x77, 0x75, 0xa3, 0x5e, 0xa0, 0x51, 0x40, 0x68, 0xb0, - 0x9a, 0x93, 0x11, 0x82, 0x69, 0x48, 0x9c, 0x77, 0x8d, 0x62, 0x36, 0xac, - 0x6c, 0x4c, 0xa3, 0xab, 0x8f, 0x32, 0x4f, 0xa9, 0x80, 0x68, 0xab, 0x7a, - 0x90, 0x61, 0x5c, 0xa5, 0x84, 0x4c, 0x8c, 0x7a, 0x95, 0x54, 0x72, 0xa0, - 0x66, 0x85, 0xb3, 0x91, 0x69, 0x64, 0x68, 0x56, 0x66, 0x8d, 0xa0, 0x9f, - 0x7a, 0x88, 0x5d, 0x7d, 0x48, 0x80, 0x7f, 0x7c, 0x7c, 0x99, 0x65, 0x81, - 0x73, 0x8b, 0x8c, 0x61, 0x44, 0x60, 0x53, 0x8e, 0x64, 0x80, 0x9c, 0x74, - 0x5d, 0x70, 0x8f, 0x5a, 0x68, 0x7a, 0x82, 0xa1, 0x75, 0x7b, 0x83, 0x60, - 0x75, 0x5e, 0xa2, 0x94, 0x6a, 0x88, 0x78, 0x71, 0x95, 0x70, 0x8b, 0x86, - 0x7e, 0x94, 0x5f, 0x65, 0x5f, 0xb1, 0x97, 0x99, 0x94, 0x84, 0x88, 0x7d, - 0x50, 0x8c, 0xaa, 0x81, 0x7b, 0x7c, 0x77, 0x65, 0x5e, 0x91, 0x9c, 0x89, - 0x8c, 0x85, 0x75, 0x62, 0x7b, 0x78, 0xc3, 0x7a, 0x62, 0x8c, 0x66, 0x6f, - 0x79, 0x7a, 0x9c, 0x6d, 0x7c, 0x6b, 0x5c, 0x7d, 0x6d, 0x54, 0x93, 0x87, - 0x7a, 0x7a, 0x50, 0x85, 0x60, 0x56, 0x5e, 0x6b, 0x90, 0x7c, 0x52, 0xa5, - 0x54, 0x42, 0x7b, 0x75, 0x83, 0x8c, 0x2c, 0xa6, 0x6f, 0x62, 0x78, 0x78, - 0x86, 0x36, 0x4b, 0xaa, 0x86, 0x54, 0x92, 0x8d, 0x7f, 0x53, 0x37, 0xbe, - 0x86, 0x7a, 0x90, 0x7e, 0x8e, 0x50, 0x58, 0xa6, 0x82, 0x58, 0x73, 0x74, - 0x66, 0x5c, 0x6a, 0x7f, 0xa2, 0x69, 0xbd, 0xa9, 0x74, 0x76, 0x75, 0x6f, - 0x45, 0x6c, 0xa5, 0x79, 0x82, 0x67, 0x56, 0x7c, 0x7f, 0x81, 0x67, 0x6d, - 0x81, 0x87, 0x71, 0x69, 0x69, 0x81, 0x85, 0x84, 0x5a, 0x8c, 0x5f, 0x73, - 0x80, 0x9c, 0x9e, 0x90, 0x77, 0xa0, 0x9c, 0x6c, 0x73, 0x8a, 0x84, 0x72, - 0x87, 0xa1, 0x67, 0x64, 0x5d, 0x9b, 0x9d, 0x9b, 0x97, 0x83, 0x5f, 0x61, - 0x77, 0x91, 0xa0, 0x8f, 0x8a, 0x6c, 0x45, 0x5f, 0x6d, 0xa6, 0x9b, 0x76, - 0x86, 0x93, 0x91, 0x7d, 0x54, 0x61, 0xa4, 0x6a, 0x5b, 0x69, 0x5f, 0x6d, - 0x83, 0xaf, 0xa0, 0x78, 0x9d, 0x62, 0x65, 0x69, 0x5f, 0x78, 0xbf, 0x91, - 0x7b, 0x7b, 0x52, 0x5d, 0x70, 0x78, 0xa9, 0x87, 0x93, 0x74, 0x61, 0x74, - 0x8c, 0x61, 0x97, 0x86, 0x9b, 0x7c, 0x7d, 0x75, 0x4b, 0x64, 0xa7, 0x81, - 0x8a, 0x9c, 0x29, 0xa2, 0x5f, 0x38, 0x6a, 0xb0, 0x82, 0x53, 0x1a, 0xa7, - 0x38, 0x47, 0x97, 0x90, 0x8d, 0x41, 0x25, 0xa7, 0x65, 0x63, 0x8b, 0x79, - 0x8f, 0x3e, 0x21, 0xd0, 0x5e, 0x5d, 0x9d, 0x68, 0x75, 0x3e, 0x68, 0xb6, - 0x6a, 0x50, 0x9a, 0x71, 0x81, 0x45, 0x6d, 0x9a, 0x7f, 0x86, 0x9c, 0x63, - 0x7d, 0x74, 0x69, 0x7d, 0x5a, 0x6a, 0x8d, 0x72, 0x6b, 0x69, 0x4c, 0x6f, - 0x7c, 0x8e, 0xa6, 0x83, 0x70, 0x65, 0x5f, 0x78, 0x69, 0x67, 0x7f, 0x8d, - 0x58, 0x76, 0x4a, 0x85, 0x80, 0x89, 0x9f, 0x91, 0x52, 0x62, 0x72, 0x60, - 0x7b, 0x5c, 0x77, 0x6f, 0x9d, 0xa4, 0x98, 0x70, 0x6f, 0xad, 0x94, 0x9f, - 0x7b, 0x89, 0x74, 0x7e, 0x5d, 0x8d, 0xab, 0x98, 0x8f, 0x90, 0x82, 0x84, - 0x60, 0x7c, 0xb7, 0x8e, 0x79, 0x83, 0x56, 0x86, 0x87, 0x79, 0x95, 0x75, - 0x78, 0x71, 0x58, 0x73, 0x87, 0x5d, 0xc6, 0x9f, 0x75, 0x61, 0x4f, 0x71, - 0x91, 0x88, 0xb3, 0x8c, 0x7d, 0x7c, 0x6a, 0x75, 0x6d, 0x66, 0x8e, 0x94, - 0x96, 0x74, 0x59, 0x6f, 0x6d, 0x65, 0xb0, 0x8e, 0x7b, 0x89, 0x7a, 0x6a, - 0x7d, 0x57, 0x82, 0x7a, 0x61, 0x9f, 0x50, 0xab, 0x57, 0x46, 0x86, 0x8d, - 0xa3, 0x96, 0x18, 0xab, 0x51, 0x6e, 0xb3, 0x7e, 0x90, 0x6d, 0x6d, 0xc0, - 0x54, 0x35, 0x96, 0x84, 0x8e, 0x49, 0x28, 0xe4, 0x81, 0x5f, 0x9b, 0x87, - 0x8c, 0x33, 0x56, 0xb4, 0x61, 0x5e, 0x8b, 0x81, 0x99, 0x61, 0x6b, 0x96, - 0x75, 0x82, 0x9e, 0x7c, 0x90, 0x63, 0x64, 0x6b, 0x55, 0x6e, 0xb6, 0x7f, - 0x5f, 0x55, 0x65, 0x60, 0x35, 0x8a, 0x85, 0x91, 0x4d, 0x62, 0x90, 0x90, - 0x57, 0x5a, 0x9f, 0x7b, 0x4c, 0x86, 0x73, 0x83, 0x4a, 0x6d, 0xb0, 0x67, - 0x65, 0x89, 0x54, 0x68, 0x89, 0x7b, 0x72, 0x4f, 0x7a, 0x93, 0x61, 0x7e, - 0x79, 0x89, 0x8f, 0x9c, 0x7b, 0x70, 0x48, 0x67, 0x82, 0x75, 0xaa, 0x92, - 0x9a, 0x8f, 0x79, 0x8c, 0x64, 0x94, 0x98, 0x83, 0x7c, 0x8f, 0x5c, 0x77, - 0x70, 0x90, 0x91, 0x88, 0x7d, 0x51, 0x5d, 0x5d, 0x8b, 0x9f, 0xbc, 0x78, - 0x9e, 0x73, 0x67, 0x6d, 0x82, 0x8d, 0xc9, 0x86, 0x96, 0x6a, 0x5d, 0x79, - 0x7e, 0x6b, 0xb2, 0x79, 0x88, 0x85, 0x65, 0x73, 0x75, 0x6b, 0x9e, 0x7f, - 0x8e, 0x94, 0x8e, 0x7d, 0x74, 0x61, 0x97, 0x56, 0x97, 0x6b, 0x30, 0xb6, - 0x5f, 0x5a, 0xaa, 0xa5, 0x85, 0x5d, 0x01, 0xbc, 0x79, 0x63, 0x6e, 0x82, - 0x72, 0x26, 0x4f, 0xc8, 0x98, 0x56, 0x85, 0x9a, 0x81, 0x1f, 0x48, 0xcf, - 0x84, 0x74, 0x75, 0x87, 0xae, 0x43, 0x6f, 0xdf, 0x6a, 0x4e, 0x97, 0x5d, - 0x8f, 0x37, 0x55, 0x89, 0x7d, 0x82, 0xb1, 0x89, 0x6d, 0x52, 0x65, 0x8b, - 0x71, 0x87, 0x8d, 0x6a, 0x99, 0x5d, 0x65, 0x78, 0x67, 0x8d, 0x7b, 0x51, - 0x60, 0x8a, 0x59, 0x72, 0x78, 0x93, 0x88, 0x75, 0x46, 0x60, 0x6e, 0x79, - 0x7b, 0x9d, 0x9c, 0x8c, 0x5c, 0x7c, 0x69, 0x71, 0x60, 0x6f, 0xb0, 0x7d, - 0x4c, 0x5e, 0x88, 0x77, 0x74, 0x6a, 0x6f, 0x9a, 0xa2, 0x83, 0x48, 0x5a, - 0x6e, 0xa2, 0x8b, 0x7a, 0x65, 0x5b, 0x4b, 0x80, 0x5b, 0x8f, 0xaf, 0x8e, - 0x93, 0x4a, 0x59, 0x6e, 0x5e, 0x89, 0x91, 0x87, 0x73, 0x6a, 0x47, 0x6c, - 0x6c, 0x81, 0xad, 0x5a, 0x76, 0x51, 0x51, 0x6c, 0x80, 0x92, 0x9d, 0xae, - 0x90, 0x71, 0x6c, 0x7a, 0x7c, 0x84, 0xa7, 0x7d, 0x82, 0x7c, 0x80, 0x59, - 0x7d, 0x86, 0xa9, 0x94, 0x8e, 0x7b, 0x7c, 0x67, 0x67, 0x66, 0x8f, 0x49, - 0x5d, 0xa4, 0x4a, 0xbc, 0x5a, 0x34, 0xa7, 0xaa, 0x9e, 0x86, 0x17, 0xc0, - 0x53, 0x67, 0x76, 0xae, 0x8d, 0x37, 0x4a, 0xd6, 0x76, 0x69, 0x95, 0x7a, - 0x8a, 0x0e, 0x3f, 0xe8, 0x60, 0x4d, 0x9e, 0x90, 0xad, 0x44, 0x46, 0xc5, - 0x4c, 0x6e, 0x72, 0x8c, 0x89, 0x49, 0x51, 0xa0, 0x60, 0x84, 0x84, 0x9d, - 0xa4, 0x5a, 0x84, 0x8d, 0x69, 0x6a, 0x97, 0x78, 0x72, 0x66, 0x72, 0x9b, - 0x74, 0x7a, 0x95, 0x7c, 0x7a, 0x6e, 0x74, 0x7f, 0x65, 0x94, 0x77, 0x7e, - 0x85, 0x6d, 0x65, 0x7b, 0x63, 0x7b, 0x87, 0x49, 0x80, 0x74, 0x74, 0x85, - 0x6e, 0x78, 0xad, 0x66, 0x8a, 0x65, 0x54, 0x7c, 0x4e, 0x62, 0x97, 0x7f, - 0x82, 0x6c, 0x58, 0x79, 0x91, 0x94, 0xb3, 0x7a, 0x88, 0x82, 0x60, 0x7f, - 0x8c, 0xa7, 0x7b, 0x93, 0x77, 0x49, 0x6f, 0x6f, 0x5a, 0x8d, 0x93, 0x8b, - 0x87, 0x59, 0x7d, 0x5e, 0x83, 0x7e, 0x8c, 0x7a, 0x91, 0x4e, 0x6f, 0x89, - 0x8a, 0x87, 0x8b, 0x85, 0x8e, 0x43, 0x63, 0x8d, 0x90, 0x6c, 0xa5, 0x73, - 0x8a, 0x78, 0x5f, 0x73, 0x88, 0x57, 0x9e, 0x8f, 0x7f, 0x91, 0x70, 0x77, - 0x8a, 0x76, 0xa2, 0x77, 0x53, 0x86, 0x51, 0xd8, 0xa9, 0x5b, 0x9b, 0x96, - 0x7c, 0x71, 0x01, 0xd4, 0x56, 0x4a, 0x95, 0xab, 0x91, 0x54, 0x45, 0xe5, - 0x74, 0x4f, 0x87, 0x6a, 0xa2, 0x3e, 0x47, 0xff, 0x91, 0x4d, 0x94, 0x97, - 0x6d, 0x74, 0x77, 0xe0, 0x5d, 0x4e, 0x5f, 0x73, 0x70, 0x3a, 0x68, 0xb2, - 0x78, 0x61, 0x8c, 0x77, 0xa8, 0x57, 0x8c, 0x99, 0x23, 0x5a, 0x84, 0x78, - 0x9b, 0x7f, 0x5e, 0xa0, 0x49, 0x84, 0x83, 0x94, 0x99, 0x4d, 0x8d, 0x9a, - 0x86, 0x90, 0x9b, 0x51, 0x75, 0x73, 0x78, 0x89, 0x59, 0x64, 0x78, 0x91, - 0x72, 0x9c, 0x72, 0x7e, 0x65, 0x6a, 0x80, 0xaa, 0x94, 0x65, 0x6d, 0x87, - 0x73, 0x93, 0x97, 0x7d, 0x99, 0x63, 0x75, 0x89, 0x67, 0xa1, 0x90, 0x7f, - 0x88, 0x65, 0x6d, 0x8f, 0x7d, 0x62, 0x91, 0xa7, 0x8b, 0x73, 0x51, 0x88, - 0x66, 0x66, 0x99, 0xa7, 0x7c, 0x54, 0x82, 0x67, 0x64, 0x8a, 0x95, 0x7c, - 0x8a, 0x5d, 0x5e, 0x68, 0x4b, 0x75, 0x92, 0x7a, 0x9f, 0x66, 0x71, 0x8d, - 0x76, 0x72, 0x8e, 0x77, 0x76, 0x8c, 0x5b, 0x88, 0x9a, 0x92, 0x7c, 0x74, - 0x95, 0xaa, 0x71, 0x77, 0x97, 0x93, 0x9e, 0x62, 0x96, 0x6a, 0x49, 0xd8, - 0x81, 0x99, 0xae, 0x87, 0x6c, 0x76, 0x3e, 0xd9, 0x6e, 0x95, 0xa3, 0x86, - 0x60, 0x6c, 0x5c, 0xbe, 0x98, 0x8a, 0x99, 0x7c, 0x47, 0x45, 0x69, 0xeb, - 0x9d, 0x7d, 0xbb, 0x90, 0x66, 0x69, 0x70, 0xc6, 0x7b, 0x59, 0x9e, 0x87, - 0x58, 0x76, 0x7c, 0xae, 0x72, 0x7d, 0x9f, 0x92, 0x82, 0x58, 0x51, 0x7a, - 0x5d, 0x77, 0xa8, 0x7c, 0x56, 0x68, 0x88, 0x8a, 0x7e, 0x8a, 0x98, 0x68, - 0x64, 0x79, 0x6e, 0x7a, 0x60, 0x96, 0x98, 0x60, 0x60, 0x71, 0x60, 0x8e, - 0x7c, 0x8c, 0x92, 0x92, 0x77, 0x80, 0x90, 0x91, 0x81, 0x82, 0x9c, 0x80, - 0x61, 0x7f, 0x5a, 0x8e, 0x88, 0x7c, 0x8e, 0x79, 0x69, 0x8e, 0x4e, 0x7e, - 0x84, 0x9e, 0x67, 0x72, 0x5c, 0x78, 0x7b, 0x8c, 0x65, 0x7d, 0x8e, 0xa4, - 0x5e, 0x7a, 0x5c, 0x97, 0x6a, 0x81, 0xab, 0x85, 0x4d, 0x73, 0x83, 0x96, - 0x8b, 0x7d, 0xa6, 0x69, 0x74, 0x86, 0x73, 0x79, 0x52, 0x8c, 0xa0, 0x86, - 0x64, 0x7b, 0x84, 0x77, 0x87, 0x93, 0x7d, 0x6d, 0x98, 0x6d, 0x88, 0x5f, - 0x7c, 0x84, 0x92, 0x82, 0x81, 0x76, 0x85, 0x77, 0x98, 0x85, 0x88, 0x68, - 0x7d, 0x71, 0x3c, 0xf1, 0x83, 0x86, 0xa2, 0xb3, 0x6e, 0x77, 0x53, 0xe8, - 0xa8, 0xc7, 0xb3, 0x83, 0x93, 0x83, 0x63, 0xe8, 0x94, 0xb3, 0x86, 0x6e, - 0x75, 0x5d, 0x54, 0xf0, 0x89, 0xa7, 0x94, 0xb1, 0x7e, 0x91, 0x9a, 0xb8, - 0x91, 0x7e, 0x99, 0x50, 0x71, 0x82, 0x8a, 0x91, 0x7a, 0x8a, 0x8b, 0x80, - 0x64, 0x6a, 0x5f, 0xbe, 0x5d, 0x96, 0xb1, 0x82, 0x45, 0x71, 0x8b, 0x95, - 0x7c, 0x9b, 0x89, 0x6d, 0x5b, 0x73, 0x81, 0x90, 0x76, 0xab, 0xa6, 0x88, - 0x62, 0x7d, 0x75, 0x99, 0x7a, 0x8b, 0x6e, 0x9b, 0x83, 0x89, 0x99, 0x93, - 0x81, 0x9e, 0x8a, 0x76, 0x75, 0x7d, 0x6c, 0x93, 0x68, 0x7a, 0x8d, 0x78, - 0x88, 0x93, 0x66, 0xa5, 0x6c, 0xae, 0xb1, 0x83, 0x72, 0x8f, 0x6b, 0x7b, - 0x79, 0x9b, 0x98, 0x7c, 0x82, 0x84, 0x7d, 0x7d, 0x71, 0x7c, 0xb0, 0x81, - 0x74, 0x89, 0x72, 0x89, 0x98, 0xa0, 0x7d, 0x62, 0x2f, 0x50, 0x7d, 0x8b, - 0x4c, 0x83, 0x87, 0x89, 0x57, 0x9e, 0x92, 0x8c, 0x81, 0x7e, 0xb9, 0x95, - 0x7f, 0x76, 0x8e, 0x90, 0x9d, 0x68, 0x78, 0x95, 0x7d, 0xab, 0x84, 0x8a, - 0x64, 0x9f, 0x80, 0x94, 0x8d, 0x89, 0x76, 0x8e, 0x6f, 0x8b, 0x75, 0x7d, - 0x89, 0x74, 0x67, 0x8a, 0x7d, 0x63, 0x79, 0x6d, 0x79, 0x8a, 0x78, 0x7f, - 0x7a, 0x9b, 0x70, 0x70, 0x84, 0x86, 0x80, 0x95, 0x5a, 0x77, 0x80, 0x91, - 0x9c, 0x92, 0x76, 0x81, 0x69, 0x89, 0x78, 0xa5, 0x7a, 0x8d, 0x86, 0x64, - 0x8f, 0x8d, 0x7d, 0xa1, 0x8c, 0x7b, 0x77, 0x7e, 0x80, 0x93, 0x86, 0x68, - 0x90, 0x9c, 0x71, 0x8c, 0x68, 0x52, 0x85, 0x88, 0x89, 0x92, 0x64, 0x8f, - 0x74, 0x64, 0x7c, 0x88, 0x8d, 0x97, 0x77, 0x97, 0x91, 0xac, 0x74, 0x7f, - 0x60, 0x7e, 0x6e, 0x70, 0x86, 0x83, 0x7f, 0x81, 0x6f, 0x94, 0x62, 0xa4, - 0x86, 0x7d, 0x90, 0x7c, 0x89, 0x63, 0x7b, 0x89, 0x75, 0xa1, 0x67, 0x69, - 0xa6, 0x76, 0x69, 0x9c, 0x71, 0x79, 0x76, 0x7a, 0x8e, 0x78, 0x94, 0x75, - 0x5a, 0x76, 0x6b, 0x91, 0x84, 0x75, 0x72, 0x93, 0x79, 0x7e, 0x75, 0x9a, - 0x6f, 0x7a, 0x7b, 0x80, 0x5f, 0x90, 0x74, 0x7d, 0x9b, 0x76, 0x70, 0x89, - 0x8f, 0x5f, 0x7f, 0x9c, 0x93, 0x6d, 0x81, 0x7f, 0x8d, 0x7d, 0x74, 0x5d, - 0x75, 0x88, 0x7b, 0x91, 0x75, 0x6b, 0x7f, 0x8c, 0x71, 0x74, 0x87, 0x88, - 0x83, 0x75, 0x77, 0x96, 0x7f, 0x67, 0x7d, 0x95, 0x81, 0x5c, 0x71, 0x5c, - 0x6e, 0x75, 0x86, 0x92, 0x5d, 0x7a, 0x77, 0x9f, 0x6e, 0x79, 0x68, 0x60, - 0x94, 0x88, 0x88, 0x88, 0x79, 0x7e, 0x8a, 0x6d, 0x84, 0xa7, 0x5b, 0x8e, - 0x67, 0x9c, 0x7e, 0x75, 0x82, 0x96, 0x7c, 0x7b, 0x72, 0x85, 0x8c, 0xa3, - 0x96, 0x5b, 0x93, 0x67, 0x7e, 0x9f, 0x71, 0x82, 0x79, 0x8c, 0x93, 0x9d, - 0x6b, 0x90, 0x8a, 0x8a, 0x55, 0x82, 0x94, 0x74, 0x7d, 0xaa, 0x81, 0x78, - 0x8a, 0x8d, 0x83, 0x7b, 0x97, 0x92, 0x68, 0x64, 0x8c, 0x5d, 0x78, 0x9b, - 0x73, 0x95, 0x78, 0x77, 0x6f, 0x61, 0x7c, 0x9d, 0x85, 0x6e, 0x84, 0x4c, - 0x87, 0x57, 0x93, 0x68, 0x8e, 0x77, 0x78, 0x72, 0x87, 0x91, 0x5f, 0x7e, - 0xa6, 0x75, 0x66, 0x86, 0x7a, 0x7d, 0x70, 0x6f, 0x87, 0x8b, 0x74, 0x85, - 0x7d, 0x8b, 0x7f, 0x70, 0x7e, 0x82, 0x84, 0x75, 0x89, 0xa6, 0x7b, 0x7a, - 0xa5, 0x69, 0x73, 0x74, 0x82, 0x65, 0x8f, 0x98, 0x7b, 0x77, 0x84, 0x92, - 0x73, 0x8a, 0xa1, 0x93, 0x80, 0x81, 0x72, 0x8a, 0x6b, 0x75, 0x8f, 0x98, - 0x73, 0x74, 0x6f, 0x70, 0x51, 0x6a, 0x84, 0x9e, 0x78, 0x9b, 0x8c, 0x81, - 0x7e, 0x75, 0x80, 0x88, 0x73, 0x4e, 0x71, 0x74, 0x8c, 0x74, 0x6a, 0x84, - 0x7f, 0x6b, 0x78, 0xab, 0x77, 0xa2, 0x98, 0x93, 0x77, 0x75, 0x72, 0x5c, - 0x60, 0x74, 0x84, 0x67, 0x83, 0x7d, 0x7f, 0x7c, 0x5c, 0x72, 0x70, 0x7f, - 0x6c, 0x84, 0x90, 0xab, 0x97, 0x7f, 0x6b, 0x82, 0x7f, 0x78, 0x73, 0x7d, - 0x8f, 0x8e, 0x8a, 0x8f, 0x8d, 0xa3, 0x74, 0x6e, 0x5e, 0x8c, 0x94, 0x86, - 0x57, 0xb0, 0x79, 0xa8, 0x7b, 0x8d, 0x83, 0x77, 0x89, 0xb6, 0x60, 0x9d, - 0x77, 0x59, 0x72, 0x4d, 0x6f, 0x94, 0x71, 0x75, 0x61, 0x96, 0x86, 0x5d, - 0x84, 0x68, 0x86, 0x82, 0x8d, 0x70, 0x9a, 0x86, 0x73, 0x64, 0x74, 0x7d, - 0x80, 0x5a, 0x64, 0x81, 0xa1, 0x71, 0x77, 0x65, 0xa3, 0x76, 0xa3, 0x9d, - 0x73, 0x7b, 0x8f, 0x7b, 0x79, 0x7d, 0x6c, 0x85, 0x8e, 0x75, 0x65, 0x6a, - 0x87, 0x70, 0x68, 0x8e, 0x76, 0x5d, 0x66, 0x7c, 0x83, 0x83, 0x7e, 0x89, - 0x59, 0x8c, 0x75, 0x59, 0x87, 0x7e, 0x7f, 0x90, 0x6b, 0x7b, 0x7e, 0x6d, - 0x6e, 0x86, 0x69, 0x92, 0x83, 0x8f, 0x8a, 0x60, 0x78, 0x75, 0x61, 0x91, - 0x73, 0x66, 0x86, 0x86, 0x9f, 0x6f, 0x7b, 0x9a, 0x7c, 0x54, 0x75, 0x8e, - 0x7e, 0x72, 0x8e, 0x98, 0x94, 0x5f, 0x71, 0x7c, 0x95, 0x9f, 0x8e, 0x83, - 0x96, 0x4b, 0x8d, 0x84, 0x81, 0x7d, 0x70, 0x84, 0x70, 0x53, 0x8d, 0x84, - 0x5a, 0x91, 0x88, 0x9a, 0x8f, 0x69, 0x8b, 0x52, 0x85, 0x89, 0x6e, 0x99, - 0x79, 0x89, 0x9a, 0x82, 0x6e, 0x8b, 0x65, 0x62, 0x80, 0xa8, 0x8f, 0x8a, - 0x71, 0x61, 0x7e, 0x7d, 0x7e, 0xaa, 0x7f, 0xa0, 0x5e, 0x67, 0x90, 0x86, - 0x6d, 0xac, 0x74, 0x50, 0x61, 0x91, 0x7d, 0x69, 0x8b, 0x7f, 0x81, 0x7a, - 0x93, 0x8c, 0x72, 0x64, 0x98, 0x88, 0x91, 0x83, 0x69, 0x6d, 0x78, 0x7a, - 0x68, 0x7c, 0x76, 0x81, 0xa7, 0x88, 0x8f, 0x79, 0x7d, 0x6c, 0x8a, 0x60, - 0x88, 0x6d, 0x79, 0x9d, 0x80, 0x82, 0x66, 0x7d, 0x7e, 0x96, 0x78, 0x70, - 0x9b, 0x70, 0x7e, 0x90, 0x77, 0x94, 0x7b, 0x89, 0x78, 0x84, 0x74, 0x6d, - 0x7d, 0xa7, 0x75, 0x97, 0x85, 0x83, 0x86, 0x65, 0x75, 0x9a, 0x7c, 0x68, - 0x87, 0x82, 0x75, 0x68, 0x4c, 0x8a, 0x68, 0x93, 0x7d, 0x88, 0x84, 0x72, - 0x58, 0x81, 0x5d, 0x83, 0x89, 0x63, 0x83, 0x7d, 0x8e, 0x75, 0x8c, 0x88, - 0x7f, 0x57, 0x8c, 0x8f, 0xa6, 0x71, 0x8a, 0x95, 0x88, 0x51, 0x74, 0x8a, - 0x8a, 0x98, 0x72, 0x80, 0x8a, 0x52, 0x90, 0x66, 0x54, 0x8e, 0x7f, 0x94, - 0x81, 0x49, 0x84, 0x70, 0x5c, 0x93, 0x89, 0x6d, 0x82, 0x7f, 0x70, 0x5d, - 0x87, 0x8a, 0x71, 0x70, 0x6f, 0xa1, 0x90, 0x9f, 0x74, 0x7c, 0x8c, 0x8b, - 0x72, 0xbf, 0x89, 0x90, 0x5c, 0x8c, 0x75, 0x72, 0x6f, 0xb2, 0x84, 0x6d, - 0x61, 0x80, 0x7d, 0x7a, 0x66, 0xaa, 0x75, 0x71, 0x89, 0x6d, 0x69, 0x72, - 0x73, 0x98, 0x8c, 0x78, 0x5a, 0x8e, 0x8c, 0x81, 0x55, 0x81, 0x96, 0x67, - 0x6f, 0x71, 0x74, 0x7d, 0x8e, 0x66, 0x9a, 0x67, 0xaa, 0x81, 0x90, 0x79, - 0x89, 0x59, 0x86, 0x66, 0x8f, 0x7d, 0x7e, 0xa2, 0xa4, 0x99, 0x68, 0x7a, - 0x8c, 0x73, 0x85, 0x77, 0x8b, 0x74, 0x75, 0x66, 0xaa, 0x98, 0x59, 0x8b, - 0x91, 0x6c, 0x76, 0x73, 0x87, 0xa4, 0x82, 0x82, 0x63, 0x70, 0x7e, 0x73, - 0x96, 0x97, 0x6f, 0x86, 0x81, 0x6f, 0x83, 0x82, 0x7b, 0x82, 0xa3, 0xa7, - 0x95, 0x77, 0x84, 0x65, 0x9b, 0x94, 0x6e, 0xb0, 0x75, 0x66, 0x78, 0x82, - 0x9c, 0x7a, 0x5f, 0xab, 0x99, 0x2f, 0x7f, 0x68, 0xa4, 0x69, 0x8f, 0x9a, - 0x91, 0x56, 0x6e, 0x75, 0x63, 0x9b, 0x9e, 0x97, 0x95, 0x68, 0x80, 0x6a, - 0x40, 0x95, 0x53, 0x72, 0x6f, 0x6b, 0x91, 0x78, 0x7f, 0x93, 0x70, 0x8d, - 0x62, 0x83, 0x7e, 0x64, 0x5b, 0xaa, 0x70, 0x6c, 0x7e, 0x9c, 0x88, 0x76, - 0x60, 0x70, 0x66, 0x69, 0x84, 0x97, 0x9d, 0x63, 0x5e, 0x9a, 0x7e, 0x52, - 0x58, 0xb8, 0x95, 0x7c, 0x4d, 0x96, 0x8f, 0x70, 0x71, 0xbf, 0x83, 0x83, - 0x9e, 0x70, 0x6f, 0x57, 0x70, 0x9a, 0x8d, 0x6e, 0x98, 0x5a, 0x69, 0x6f, - 0x90, 0x71, 0x8a, 0x5d, 0x8e, 0x6e, 0x69, 0x7a, 0x90, 0x86, 0x89, 0x88, - 0xb6, 0x77, 0x84, 0x79, 0x76, 0x86, 0x86, 0x7c, 0xbf, 0x6d, 0x5c, 0x90, - 0xa1, 0x93, 0x72, 0x63, 0x9a, 0x82, 0x7b, 0x61, 0x91, 0x76, 0x82, 0x96, - 0xb9, 0x80, 0x77, 0x7f, 0xa0, 0x73, 0x61, 0x80, 0x83, 0xc1, 0x92, 0x67, - 0x7c, 0x81, 0x90, 0x67, 0x8b, 0xbe, 0x81, 0x91, 0x6c, 0x7e, 0x8d, 0x6c, - 0x62, 0x83, 0x7e, 0x72, 0x64, 0x8a, 0x83, 0x82, 0xaa, 0x8c, 0x74, 0xab, - 0x79, 0x85, 0x91, 0x79, 0x90, 0x68, 0x5c, 0x9a, 0x7c, 0x36, 0x80, 0x6e, - 0x93, 0x76, 0x5e, 0xa0, 0xa5, 0x63, 0x73, 0x7e, 0x8d, 0x94, 0x63, 0x99, - 0x8f, 0x6a, 0x7f, 0x57, 0x57, 0x6f, 0x6d, 0x86, 0x8e, 0x6b, 0x8d, 0x53, - 0x94, 0xba, 0x84, 0x6f, 0x5a, 0x7b, 0x8c, 0x5f, 0x73, 0x93, 0x8b, 0x87, - 0x6f, 0x9e, 0x8a, 0x87, 0x62, 0x97, 0x86, 0x7c, 0x69, 0xab, 0xa1, 0x95, - 0x42, 0x8c, 0x8b, 0x66, 0x68, 0x99, 0xa8, 0x74, 0x80, 0xa5, 0x7d, 0x82, - 0x55, 0xb3, 0x6f, 0x81, 0xa8, 0x9a, 0x80, 0x67, 0x62, 0x7f, 0x78, 0x93, - 0x90, 0x83, 0x83, 0x7b, 0x77, 0x73, 0x8c, 0x56, 0xa7, 0x85, 0x7b, 0x71, - 0x8f, 0x5d, 0x92, 0x69, 0xbe, 0x5e, 0x7f, 0x7f, 0x8e, 0x71, 0x84, 0x75, - 0x95, 0x69, 0x88, 0x6b, 0x96, 0x85, 0x78, 0x39, 0xc2, 0x86, 0x7c, 0x99, - 0xa1, 0x94, 0x6b, 0x86, 0xb5, 0x5e, 0x7e, 0x6e, 0x81, 0x95, 0x6a, 0x88, - 0x7b, 0x92, 0x8f, 0x68, 0x97, 0x77, 0x84, 0x73, 0x68, 0x96, 0x5a, 0x92, - 0x66, 0x74, 0x74, 0x6c, 0x7d, 0x81, 0x6c, 0x93, 0x7f, 0x72, 0x86, 0x74, - 0xbf, 0x8f, 0x53, 0xa4, 0x89, 0x76, 0xa0, 0x87, 0x97, 0x6a, 0x6b, 0xb1, - 0x91, 0x50, 0x74, 0x68, 0xa3, 0x60, 0x8d, 0xbc, 0xc1, 0x3e, 0x62, 0x59, - 0x71, 0x72, 0x6d, 0x80, 0x9f, 0x52, 0x82, 0x6b, 0x5d, 0x7f, 0x74, 0x7e, - 0x74, 0x84, 0x8a, 0x59, 0x5c, 0x85, 0x6d, 0x9c, 0x75, 0x9a, 0x88, 0x89, - 0x81, 0x9f, 0x81, 0x88, 0x6a, 0x94, 0x84, 0x5f, 0x6b, 0x9b, 0x83, 0x4f, - 0x7e, 0xca, 0x99, 0x6d, 0x45, 0x7f, 0x87, 0x71, 0x69, 0xad, 0x95, 0x53, - 0x6e, 0x9b, 0x90, 0x73, 0x5d, 0xb0, 0x8d, 0x67, 0x83, 0x82, 0xa3, 0x70, - 0x70, 0x92, 0x82, 0x9a, 0x8a, 0x69, 0x6a, 0x6e, 0x7f, 0x89, 0xa4, 0x76, - 0x97, 0x62, 0x94, 0x80, 0x87, 0x55, 0x80, 0x76, 0xb3, 0x7e, 0x7e, 0x71, - 0x94, 0x88, 0x8e, 0x74, 0xb6, 0x4d, 0x7b, 0x73, 0x90, 0x86, 0x7c, 0x66, - 0xb5, 0x80, 0x7f, 0x84, 0x87, 0x82, 0x67, 0x83, 0x97, 0x91, 0x8a, 0x78, - 0x8b, 0x83, 0x5d, 0x84, 0x82, 0x9f, 0x8c, 0x91, 0x84, 0x8b, 0x6a, 0x68, - 0x86, 0x82, 0x73, 0x77, 0x7b, 0x83, 0x6a, 0x84, 0x92, 0x93, 0x90, 0x8b, - 0x4c, 0x94, 0x98, 0x76, 0xb8, 0x7b, 0xa0, 0xa2, 0x7d, 0x3e, 0x95, 0x88, - 0xa3, 0x6f, 0x5e, 0xc8, 0x9a, 0x52, 0x81, 0x86, 0xa3, 0x79, 0x88, 0xc3, - 0xbd, 0x54, 0x6c, 0x5e, 0x83, 0x8a, 0x98, 0x88, 0x92, 0x66, 0x73, 0x5b, - 0x6c, 0x7f, 0x6e, 0x97, 0x8d, 0x58, 0x89, 0x6e, 0x65, 0x7a, 0x7d, 0x7c, - 0x7e, 0x89, 0x94, 0x89, 0x55, 0xb8, 0x8f, 0x82, 0x6c, 0x9c, 0x96, 0x5e, - 0x6f, 0xb2, 0x70, 0x76, 0x95, 0xc8, 0x86, 0x78, 0x49, 0xac, 0x7e, 0x6c, - 0x68, 0xb6, 0xaf, 0x89, 0x68, 0xa5, 0x72, 0x85, 0x69, 0x9c, 0x94, 0x84, - 0xa4, 0x97, 0x91, 0x61, 0x7a, 0xa3, 0x8f, 0x8e, 0x93, 0x80, 0x8d, 0x76, - 0x74, 0x84, 0x9b, 0x79, 0x97, 0x4e, 0x67, 0x87, 0x9b, 0x69, 0x85, 0x7d, - 0xb2, 0x68, 0x76, 0x63, 0xa2, 0x86, 0x97, 0x7f, 0xb5, 0x63, 0x79, 0x76, - 0x8a, 0x7c, 0x7c, 0x91, 0xb1, 0x42, 0x7d, 0x7a, 0x8c, 0x8e, 0x72, 0xab, - 0xb8, 0x76, 0xab, 0x81, 0x98, 0x85, 0x56, 0x98, 0x84, 0x9f, 0x70, 0x86, - 0x76, 0x88, 0x70, 0x8d, 0x71, 0x7b, 0x7a, 0x8d, 0x76, 0x75, 0x62, 0x80, - 0x81, 0x94, 0x82, 0x6e, 0x57, 0x8d, 0xaf, 0x84, 0xbf, 0x85, 0x82, 0xa7, - 0x80, 0x89, 0x95, 0x81, 0x91, 0x49, 0x72, 0xa1, 0xa7, 0x3f, 0x72, 0x8b, - 0x99, 0x72, 0x86, 0xb2, 0xc3, 0x61, 0x55, 0x77, 0x86, 0x77, 0x83, 0xa7, - 0x95, 0x5a, 0x68, 0x68, 0x6a, 0x63, 0x6a, 0x77, 0x93, 0x7c, 0x88, 0x62, - 0x79, 0x84, 0x8b, 0x82, 0x58, 0x8f, 0x9c, 0x56, 0x77, 0xb1, 0x65, 0x8c, - 0x76, 0x91, 0x83, 0x5b, 0x62, 0x91, 0x87, 0x68, 0x71, 0xb0, 0x87, 0x64, - 0x62, 0x91, 0x94, 0x58, 0x7f, 0xac, 0xa3, 0x84, 0x75, 0xaa, 0xa3, 0x4d, - 0x7a, 0xc2, 0x84, 0x8a, 0x6d, 0xa2, 0x76, 0x74, 0x8c, 0x9e, 0x7c, 0x71, - 0x86, 0x70, 0x6d, 0x79, 0x9a, 0x74, 0xb0, 0x8d, 0xa5, 0x7e, 0x6b, 0x63, - 0x96, 0x74, 0x99, 0x76, 0xd0, 0x62, 0x85, 0x9d, 0x8f, 0x6d, 0x83, 0x88, - 0xb0, 0x62, 0x9b, 0x87, 0x91, 0x82, 0x7a, 0x90, 0x9c, 0x61, 0x6d, 0x97, - 0x84, 0x7c, 0x74, 0x8e, 0x8b, 0x75, 0x9a, 0x7e, 0x7c, 0x7d, 0x96, 0x81, - 0x94, 0x69, 0x83, 0x6f, 0x8e, 0x7c, 0x7b, 0x7a, 0x73, 0x98, 0x74, 0x9e, - 0x72, 0x8c, 0x5f, 0x7d, 0x99, 0x79, 0x5b, 0x73, 0x65, 0x78, 0xa5, 0x7d, - 0xa2, 0x98, 0x91, 0x91, 0x87, 0x7b, 0x8c, 0x82, 0xb8, 0x6b, 0x82, 0xba, - 0xa5, 0x3f, 0x83, 0x7a, 0x9b, 0x73, 0x93, 0xa1, 0xbe, 0x55, 0x6b, 0x75, - 0x94, 0x7d, 0x9c, 0xa1, 0x82, 0x50, 0x75, 0x5a, 0x88, 0x6e, 0x72, 0x7f, - 0x99, 0x64, 0x72, 0x49, 0x69, 0x79, 0x6d, 0x94, 0x73, 0x79, 0x80, 0x6f, - 0x72, 0xbc, 0x9d, 0x71, 0x7a, 0x9d, 0x8a, 0x55, 0x74, 0xaa, 0xa1, 0x85, - 0x7e, 0xc4, 0xa0, 0x7e, 0x50, 0x99, 0x68, 0x8c, 0x8a, 0xb0, 0x99, 0x6c, - 0x6d, 0xaf, 0x7b, 0x7b, 0x79, 0xba, 0x8a, 0x7a, 0x9d, 0x8b, 0x67, 0x87, - 0x76, 0xa9, 0x7f, 0x7e, 0x8b, 0x7b, 0x87, 0x84, 0x82, 0x74, 0xa3, 0x91, - 0x9a, 0x6a, 0x93, 0x7e, 0x87, 0x5b, 0x95, 0x89, 0xbb, 0x5d, 0x74, 0x6c, - 0x88, 0x7e, 0x81, 0x7e, 0xb6, 0x6b, 0x91, 0x92, 0x83, 0x78, 0x79, 0x95, - 0x90, 0x5e, 0x68, 0x8f, 0xa8, 0x92, 0x66, 0x8e, 0x6b, 0x8c, 0x86, 0x80, - 0x7e, 0x7e, 0x70, 0x84, 0x7d, 0x71, 0x67, 0x94, 0x71, 0x69, 0x84, 0x8f, - 0x6c, 0x72, 0x85, 0x83, 0x69, 0x76, 0x57, 0x62, 0x83, 0x96, 0x83, 0x77, - 0x64, 0x5f, 0xae, 0x7c, 0xa7, 0x88, 0x91, 0x8c, 0x9e, 0x7f, 0xa8, 0x8a, - 0x93, 0x6f, 0x58, 0xae, 0xb4, 0x4b, 0x7f, 0x64, 0x9f, 0x5a, 0x9e, 0xb6, - 0xa6, 0x6b, 0x79, 0x84, 0x6b, 0x7c, 0x8b, 0x94, 0x85, 0x60, 0x6b, 0x55, - 0x79, 0x68, 0x77, 0x75, 0x85, 0x5c, 0x91, 0x5e, 0x5a, 0x71, 0x68, 0x7b, - 0x73, 0x91, 0x6c, 0x6e, 0x71, 0x8b, 0x76, 0x86, 0x99, 0xb8, 0x91, 0x68, - 0x51, 0xa7, 0x6f, 0x7a, 0x8a, 0xc3, 0x8e, 0x65, 0x64, 0x9e, 0x80, 0x78, - 0x6c, 0xc5, 0xa2, 0x75, 0x71, 0xa5, 0x96, 0x4f, 0x70, 0xa4, 0x7a, 0x7c, - 0x8c, 0x80, 0x89, 0x97, 0x9a, 0x9a, 0x85, 0x89, 0x92, 0x8f, 0x81, 0x6f, - 0x82, 0x6a, 0xb8, 0x74, 0x8f, 0x51, 0x7b, 0x8b, 0x8c, 0x55, 0x7e, 0x8c, - 0xb2, 0x41, 0x85, 0x77, 0x9c, 0x73, 0x75, 0x8d, 0x9f, 0x64, 0x92, 0x77, - 0xa0, 0x87, 0x5f, 0x71, 0x85, 0x68, 0x8a, 0x78, 0x91, 0x78, 0x75, 0x7a, - 0x81, 0x67, 0x96, 0x64, 0x96, 0x85, 0x7a, 0x7e, 0x83, 0x74, 0x82, 0x8f, - 0x98, 0x75, 0x77, 0x84, 0x7e, 0x88, 0x94, 0x7d, 0x79, 0x8c, 0x47, 0x79, - 0x96, 0x7f, 0x8e, 0x90, 0x50, 0x7f, 0xa3, 0x77, 0xa8, 0x7f, 0x65, 0x9f, - 0xb9, 0x4c, 0xa7, 0x7f, 0xaa, 0x6e, 0xa2, 0xb0, 0xb8, 0x51, 0x6b, 0x74, - 0xaa, 0x63, 0x6c, 0xa3, 0xb6, 0x5e, 0x74, 0x6a, 0x75, 0x69, 0x87, 0x7f, - 0x9d, 0x71, 0x73, 0x72, 0x70, 0x57, 0x5a, 0x7e, 0x8b, 0x64, 0x9a, 0x4d, - 0x97, 0x81, 0x7b, 0x75, 0x6e, 0x92, 0x5f, 0x67, 0x7e, 0xaa, 0x90, 0x7a, - 0x92, 0xae, 0x92, 0x68, 0x79, 0x9d, 0x4f, 0x6c, 0x79, 0xb4, 0x9c, 0x58, - 0x86, 0x8e, 0x62, 0x72, 0x71, 0xc1, 0xac, 0x7d, 0x7a, 0x94, 0x8f, 0x7b, - 0x88, 0xa8, 0x8d, 0x82, 0x75, 0x9b, 0x5f, 0x83, 0x82, 0xb3, 0x7a, 0x93, - 0x94, 0x76, 0x70, 0x7e, 0x72, 0x7e, 0x8f, 0x8c, 0xa7, 0x53, 0x72, 0x77, - 0x7a, 0x64, 0xa8, 0x83, 0xc5, 0x56, 0x71, 0x7b, 0x96, 0x73, 0x7c, 0x73, - 0x93, 0x49, 0x83, 0x99, 0xa2, 0x83, 0x74, 0x79, 0xa4, 0x61, 0x8e, 0x84, - 0x7a, 0x7d, 0x56, 0x98, 0x97, 0x6d, 0x87, 0x8c, 0x7a, 0x77, 0x6a, 0x67, - 0x8a, 0x6f, 0xa2, 0x82, 0x8d, 0x85, 0x6d, 0x8f, 0x7e, 0x74, 0x72, 0x74, - 0x91, 0x75, 0x58, 0x7f, 0x9e, 0x7c, 0x9c, 0x75, 0x61, 0x6f, 0x85, 0x7b, - 0xbe, 0x84, 0x85, 0x9b, 0x8c, 0x3b, 0x9a, 0x90, 0xab, 0x77, 0x8e, 0xa2, - 0xbd, 0x55, 0x96, 0x70, 0xa8, 0x78, 0x98, 0x9c, 0xc3, 0x67, 0x6e, 0x81, - 0x70, 0x75, 0x96, 0x9c, 0x8a, 0x5b, 0x73, 0x54, 0x69, 0x6c, 0x5d, 0x82, - 0x99, 0x5b, 0x8c, 0x6d, 0x87, 0x80, 0x67, 0x86, 0x88, 0x7c, 0x70, 0x6b, - 0x75, 0xab, 0x8e, 0x79, 0x90, 0x91, 0xaf, 0x67, 0x5c, 0xa1, 0x5c, 0x6f, - 0x75, 0xa1, 0x95, 0x5f, 0x82, 0x8f, 0x78, 0x5d, 0x7c, 0xb8, 0x8a, 0x8a, - 0x6a, 0x98, 0x6e, 0x51, 0x6b, 0xaa, 0x7d, 0x7c, 0x80, 0x94, 0x79, 0x6d, - 0xaa, 0x8a, 0x7e, 0x77, 0xa4, 0x78, 0xa5, 0x6d, 0x7c, 0x75, 0xa8, 0x6f, - 0xa6, 0x51, 0x8e, 0x80, 0x96, 0x5b, 0x9d, 0x7b, 0xb8, 0x4e, 0x6c, 0x87, - 0x95, 0x7c, 0x78, 0x71, 0xb0, 0x5a, 0x99, 0xa0, 0x90, 0x87, 0x65, 0x8b, - 0x98, 0x68, 0x92, 0x76, 0x82, 0x77, 0x6a, 0x8a, 0x91, 0x84, 0x87, 0x8b, - 0x87, 0x84, 0x7a, 0x81, 0x77, 0x55, 0x8e, 0x86, 0x7a, 0x74, 0x65, 0x88, - 0x62, 0x51, 0xa1, 0x91, 0x88, 0x76, 0x5f, 0x89, 0x9f, 0x86, 0x66, 0x67, - 0x64, 0x75, 0x9e, 0x74, 0xc1, 0x80, 0x58, 0xa9, 0x8f, 0x5e, 0x94, 0x88, - 0xaf, 0x6f, 0x6c, 0xa4, 0xa1, 0x4d, 0x68, 0x66, 0xc2, 0x6e, 0x89, 0x9b, - 0xa3, 0x5a, 0x63, 0x5b, 0x9c, 0x7a, 0x93, 0x76, 0x9d, 0x6d, 0x71, 0x5d, - 0x80, 0x66, 0x79, 0x80, 0x7c, 0x65, 0x74, 0x64, 0x88, 0x90, 0x79, 0x89, - 0x72, 0x88, 0x67, 0x75, 0x6a, 0x96, 0x56, 0x67, 0x88, 0xa1, 0x8c, 0x6c, - 0x55, 0xb2, 0x8a, 0x71, 0x88, 0xdc, 0x7a, 0x72, 0x94, 0x9d, 0x7c, 0x76, - 0x6a, 0xaa, 0xa8, 0x7f, 0x80, 0xa0, 0x6b, 0x6f, 0x84, 0xe0, 0x68, 0x93, - 0xa6, 0x99, 0x69, 0x68, 0x93, 0xa0, 0x93, 0x6b, 0x87, 0x8b, 0x80, 0x90, - 0x90, 0x89, 0x8f, 0x7f, 0xaf, 0x6f, 0x82, 0x6d, 0x94, 0x70, 0x97, 0x8f, - 0xb0, 0x40, 0x9b, 0x67, 0x78, 0x86, 0x90, 0x8b, 0xa7, 0x51, 0x7f, 0x79, - 0x90, 0x71, 0x6d, 0x80, 0x95, 0x63, 0x7d, 0x87, 0xa0, 0x7e, 0x7b, 0x85, - 0x8e, 0x6d, 0xa1, 0x76, 0x70, 0x7b, 0x66, 0x87, 0x90, 0x7a, 0x86, 0x88, - 0x89, 0x87, 0x6a, 0x91, 0x78, 0x74, 0x76, 0x8d, 0x7e, 0x86, 0x63, 0x90, - 0x98, 0x7d, 0x4a, 0x85, 0x4f, 0x9d, 0xa2, 0x7c, 0xb4, 0x88, 0x78, 0xb5, - 0x8f, 0x3f, 0xa7, 0x7d, 0xa4, 0x7c, 0x60, 0x9c, 0xa8, 0x41, 0x6b, 0x7f, - 0xa2, 0x7f, 0x68, 0xaa, 0xb4, 0x73, 0x56, 0x62, 0x87, 0x72, 0xa5, 0x7c, - 0x97, 0x69, 0x58, 0x6b, 0x89, 0x57, 0x51, 0x80, 0x92, 0x7a, 0x7c, 0x4c, - 0x7c, 0x7b, 0x69, 0x5f, 0x90, 0x77, 0x78, 0x67, 0x7a, 0xad, 0x79, 0x5c, - 0x9c, 0xbf, 0xa6, 0x64, 0x53, 0xb3, 0x5e, 0x59, 0x86, 0xb9, 0x94, 0x65, - 0x70, 0x9d, 0x7a, 0x80, 0x7c, 0xae, 0x9c, 0x7b, 0x66, 0xae, 0x83, 0x5f, - 0x81, 0xc5, 0x8b, 0x7e, 0x9b, 0x89, 0x84, 0x7f, 0x7c, 0xa5, 0x5c, 0x89, - 0x8a, 0x75, 0x99, 0x6d, 0x8e, 0x90, 0x9f, 0x81, 0x81, 0x6b, 0x87, 0x76, - 0x92, 0x6f, 0xab, 0x95, 0x95, 0x4c, 0x97, 0x72, 0x80, 0x87, 0x83, 0x87, - 0xa3, 0x59, 0xad, 0x74, 0x93, 0x7f, 0x77, 0x78, 0x8d, 0x66, 0x9b, 0x7a, - 0x7d, 0x95, 0x64, 0x7f, 0x6d, 0x5c, 0x8e, 0x94, 0x92, 0x82, 0x60, 0x8d, - 0x75, 0x55, 0x8c, 0x8b, 0x8f, 0x86, 0x7d, 0x7c, 0x74, 0x57, 0x78, 0x9d, - 0x71, 0x65, 0x66, 0x7f, 0xaa, 0x92, 0x66, 0x81, 0x5a, 0x71, 0xa6, 0x78, - 0x9d, 0x8a, 0x5a, 0x8a, 0x91, 0x59, 0xb7, 0x5c, 0xc3, 0x73, 0x89, 0x9d, - 0xa7, 0x62, 0x77, 0x72, 0x9f, 0x92, 0x6a, 0x9f, 0xaa, 0x71, 0x6b, 0x5e, - 0x7d, 0x73, 0x8d, 0x89, 0xba, 0x61, 0x73, 0x6e, 0x71, 0x8a, 0x79, 0x7c, - 0x94, 0x76, 0x76, 0x65, 0x81, 0x6f, 0x4e, 0x75, 0x6e, 0x8b, 0x7d, 0x50, - 0x56, 0xb8, 0x72, 0x67, 0x93, 0xc6, 0x88, 0x6f, 0x57, 0xb7, 0x80, 0x4c, - 0x97, 0xc4, 0xb6, 0x71, 0x72, 0x9e, 0x6f, 0x72, 0x8d, 0xa5, 0x8f, 0x89, - 0x74, 0xae, 0x78, 0x70, 0x6e, 0xbb, 0x8f, 0x73, 0x74, 0x8b, 0x5e, 0x86, - 0x8b, 0x8a, 0x72, 0x71, 0x84, 0x84, 0x77, 0xa3, 0xa6, 0x73, 0xa4, 0x7e, - 0xab, 0x5d, 0x75, 0x96, 0x94, 0x5f, 0x8b, 0x74, 0x9c, 0x63, 0x8d, 0x81, - 0x80, 0x6a, 0x91, 0x88, 0x93, 0x53, 0x80, 0x75, 0x79, 0x8d, 0x78, 0x74, - 0x7c, 0x73, 0xb2, 0x89, 0x8e, 0xab, 0x75, 0x6c, 0x7a, 0x79, 0x99, 0x77, - 0x7d, 0x89, 0x5a, 0x81, 0x7c, 0x75, 0x6a, 0x7e, 0x8c, 0x83, 0x78, 0x8e, - 0x62, 0x76, 0x77, 0x6b, 0x79, 0x66, 0x6e, 0x82, 0xa1, 0x8d, 0x52, 0x79, - 0x70, 0x7d, 0xa9, 0x6a, 0x95, 0x7f, 0x59, 0x94, 0x8f, 0x73, 0xb7, 0x85, - 0xb3, 0x80, 0x77, 0x9f, 0xb8, 0x4d, 0x82, 0x7c, 0xa0, 0xa4, 0x7b, 0x8c, - 0xa9, 0x78, 0x62, 0x6b, 0x8a, 0x93, 0x80, 0x68, 0x9b, 0x6d, 0x6b, 0x7b, - 0x84, 0x8f, 0x86, 0x70, 0x70, 0x73, 0x84, 0x4f, 0x7c, 0x75, 0x64, 0x8d, - 0x6e, 0x81, 0x7c, 0x72, 0x81, 0xb0, 0x74, 0x65, 0xa7, 0xae, 0x80, 0x70, - 0x5e, 0xa4, 0x58, 0x54, 0x8e, 0xa7, 0x96, 0x65, 0x66, 0x8b, 0x6c, 0x5d, - 0x6b, 0xbe, 0x94, 0x79, 0x80, 0xa1, 0x91, 0x78, 0x6d, 0xc2, 0x82, 0x85, - 0x81, 0x7d, 0x88, 0x79, 0x93, 0x96, 0x7f, 0x7e, 0x7d, 0x92, 0x75, 0xa2, - 0x9f, 0x7b, 0x92, 0x77, 0x8a, 0x7c, 0x80, 0x8b, 0x9b, 0x64, 0xa5, 0x74, - 0xa1, 0x74, 0x7f, 0x7e, 0x85, 0x78, 0x9c, 0x86, 0x9f, 0x62, 0x8f, 0x7f, - 0x8a, 0x90, 0x6d, 0x7d, 0x93, 0x61, 0x9d, 0x81, 0x9b, 0x99, 0x69, 0x87, - 0x74, 0x7d, 0x8e, 0x8e, 0x7b, 0x7c, 0x6a, 0x71, 0x7d, 0x7f, 0x74, 0x74, - 0x7b, 0x65, 0x6e, 0x91, 0x7c, 0x6e, 0x80, 0x8c, 0x8a, 0x6c, 0x6b, 0x76, - 0xad, 0x94, 0x64, 0x81, 0x69, 0x7b, 0xac, 0x76, 0x9f, 0x71, 0x85, 0x85, - 0x8b, 0x66, 0xb5, 0x87, 0xb3, 0x63, 0x8b, 0x95, 0x8e, 0x50, 0x91, 0x77, - 0xa1, 0x99, 0x64, 0x81, 0xb3, 0x63, 0x6e, 0x7a, 0x7f, 0x73, 0x7a, 0x7b, - 0x93, 0x6d, 0x75, 0x75, 0x7c, 0x7b, 0x59, 0x7c, 0x7c, 0x68, 0x67, 0x78, - 0x79, 0x75, 0x53, 0x86, 0x84, 0x84, 0x91, 0x71, 0x85, 0xb1, 0x84, 0x64, - 0x88, 0xc0, 0x94, 0x5f, 0x6f, 0x9b, 0x69, 0x67, 0x97, 0x94, 0x88, 0x6a, - 0x7e, 0x94, 0x9e, 0x7f, 0x81, 0x9c, 0xa7, 0x7f, 0x7a, 0xa2, 0x63, 0x69, - 0x82, 0xc2, 0x5e, 0x8d, 0x7c, 0x89, 0x63, 0x93, 0x84, 0xb8, 0x76, 0x89, - 0x96, 0x87, 0x79, 0x88, 0xa6, 0x8e, 0x9b, 0x93, 0x9c, 0x5d, 0x92, 0x92, - 0x82, 0x5e, 0x85, 0x88, 0xad, 0x73, 0xa4, 0x6f, 0x74, 0x8e, 0x77, 0x89, - 0x9b, 0x6e, 0x82, 0x76, 0x93, 0xae, 0x82, 0x87, 0x76, 0x6f, 0x80, 0x76, - 0x95, 0x8e, 0x5e, 0x85, 0x7b, 0x68, 0x7f, 0x7c, 0x82, 0x94, 0x80, 0x91, - 0x77, 0x71, 0x7c, 0x94, 0x80, 0x62, 0x65, 0x7c, 0x5e, 0x70, 0x76, 0x75, - 0x7b, 0x60, 0x5f, 0x69, 0xb3, 0x6e, 0x95, 0x9d, 0x5a, 0x5b, 0x9e, 0x6e, - 0xa6, 0x80, 0x5d, 0xa5, 0x83, 0x5b, 0xa4, 0x80, 0xb3, 0x79, 0x83, 0xb6, - 0xa3, 0x73, 0x84, 0x67, 0x8d, 0x8f, 0x9d, 0x78, 0xb8, 0x8a, 0x7b, 0x6c, - 0x85, 0x87, 0x6d, 0x75, 0xae, 0x75, 0x53, 0x71, 0x6b, 0x87, 0x67, 0x7b, - 0x7f, 0x86, 0x58, 0x73, 0x7d, 0x87, 0x5d, 0x7f, 0x7d, 0x63, 0x92, 0x65, - 0x7a, 0x9c, 0x6f, 0x87, 0x81, 0xa9, 0x91, 0x54, 0x66, 0x8e, 0x58, 0x6d, - 0x92, 0xc2, 0xa9, 0x7b, 0x6e, 0x96, 0x7c, 0x60, 0x7e, 0xa8, 0x85, 0x94, - 0x90, 0x8b, 0x77, 0x79, 0x77, 0xa7, 0x8f, 0x83, 0x80, 0x99, 0x8c, 0x80, - 0x93, 0x9c, 0x73, 0x9e, 0x75, 0x90, 0x67, 0x74, 0x99, 0x98, 0x7e, 0x76, - 0x9f, 0x82, 0x90, 0x95, 0x9d, 0x5f, 0x95, 0x98, 0x8c, 0x5f, 0x77, 0x83, - 0x7b, 0x72, 0x85, 0x7c, 0x97, 0x74, 0x81, 0x80, 0x8d, 0x89, 0x7d, 0x69, - 0x95, 0x85, 0x83, 0x5e, 0x95, 0x74, 0x54, 0x7f, 0x6c, 0x67, 0x9b, 0x83, - 0x88, 0x8e, 0x6f, 0x96, 0x81, 0x7f, 0x6e, 0x87, 0x8f, 0x6f, 0x61, 0x87, - 0x63, 0x66, 0x72, 0x77, 0x75, 0x6d, 0x59, 0x7d, 0xaa, 0x85, 0x62, 0x83, - 0x97, 0x94, 0x96, 0x89, 0x9d, 0x90, 0x7d, 0x91, 0x78, 0x57, 0xa0, 0x7f, - 0xa2, 0x62, 0x63, 0x99, 0x77, 0x71, 0x7f, 0x61, 0x99, 0x89, 0x6f, 0xa2, - 0xae, 0x92, 0x88, 0x51, 0x87, 0x7a, 0x6f, 0x89, 0xa8, 0x89, 0x64, 0x81, - 0x84, 0x79, 0x5b, 0x73, 0x82, 0x6e, 0x7e, 0x5d, 0x8f, 0x82, 0x51, 0x69, - 0x8e, 0x76, 0x8b, 0x58, 0x89, 0xb2, 0x52, 0x72, 0x7f, 0xae, 0x96, 0x5a, - 0x80, 0xa1, 0x74, 0x62, 0x8d, 0xbe, 0x87, 0x6c, 0x6d, 0xad, 0x83, 0x5a, - 0x6c, 0xa5, 0x7f, 0x7c, 0x7a, 0xa1, 0x75, 0x6d, 0x85, 0xbe, 0x91, 0x8e, - 0x96, 0x8c, 0x87, 0x74, 0x8b, 0x82, 0x96, 0x8f, 0x8f, 0x93, 0x8f, 0x8c, - 0x9a, 0x78, 0x73, 0x6e, 0x91, 0x8d, 0x7e, 0x81, 0x81, 0x52, 0x90, 0x85, - 0x77, 0x66, 0x7e, 0x75, 0x8a, 0x67, 0x72, 0x76, 0x82, 0x7b, 0x6e, 0x67, - 0x96, 0x7b, 0x75, 0x76, 0x8d, 0x76, 0x7f, 0x79, 0x84, 0x7b, 0x57, 0x81, - 0x76, 0x80, 0x67, 0x8c, 0x7c, 0x80, 0x67, 0x85, 0x79, 0x5b, 0x97, 0x74, - 0x91, 0x75, 0x82, 0x75, 0x6b, 0x94, 0x7e, 0x85, 0x8e, 0x77, 0x5d, 0x78, - 0xb5, 0x8b, 0x73, 0x7f, 0x62, 0x8f, 0xb1, 0x7d, 0xa2, 0x85, 0x6b, 0x92, - 0x75, 0x75, 0xb8, 0x7d, 0xb3, 0x67, 0x5f, 0xa6, 0x9b, 0x85, 0x9a, 0x67, - 0xbe, 0x8d, 0x92, 0x88, 0xa5, 0x7c, 0xaa, 0x5a, 0x71, 0x7b, 0x70, 0x77, - 0xa0, 0xa4, 0x5e, 0x55, 0x6b, 0x8e, 0x53, 0x89, 0x8a, 0x5a, 0x7c, 0x54, - 0x7c, 0x8b, 0x53, 0x77, 0x67, 0x77, 0x67, 0x5d, 0x91, 0xac, 0x78, 0x81, - 0x8e, 0xb5, 0x6d, 0x58, 0x78, 0xa6, 0x7c, 0x85, 0x87, 0xb3, 0x76, 0x5d, - 0x7c, 0x87, 0x57, 0x68, 0x82, 0x8f, 0x89, 0x76, 0x86, 0x9f, 0x6c, 0x68, - 0x7c, 0x87, 0x79, 0x9f, 0x86, 0x9e, 0x83, 0x70, 0x8d, 0xb2, 0x84, 0x71, - 0x71, 0x91, 0x9f, 0x8e, 0x83, 0x84, 0x87, 0x80, 0x94, 0x80, 0x7d, 0x8d, - 0x7c, 0x56, 0x5f, 0x80, 0x7d, 0x84, 0x61, 0x6e, 0x69, 0x80, 0x8b, 0x67, - 0xa4, 0x8b, 0x98, 0x7a, 0x8a, 0x6c, 0x77, 0x66, 0x7d, 0x6e, 0x84, 0x78, - 0x82, 0x7d, 0x61, 0x88, 0x6e, 0x53, 0x92, 0x75, 0x88, 0x77, 0x82, 0x9f, - 0x9e, 0x6f, 0x9c, 0x76, 0x91, 0x78, 0x69, 0x7f, 0x71, 0x6c, 0x6f, 0x7d, - 0x83, 0x6e, 0x3c, 0x84, 0x90, 0x8b, 0x71, 0x69, 0x75, 0x81, 0xc8, 0x84, - 0xa7, 0x8a, 0x8a, 0x90, 0x96, 0x86, 0x9e, 0x68, 0x99, 0x84, 0x8c, 0xa0, - 0x8a, 0x71, 0x7d, 0x41, 0xa1, 0x98, 0x77, 0x91, 0xaa, 0x86, 0x96, 0x5e, - 0x86, 0x76, 0xa7, 0x83, 0xac, 0x86, 0x66, 0x46, 0x6a, 0x81, 0x64, 0x77, - 0x67, 0x53, 0x80, 0x59, 0x73, 0x71, 0x63, 0x71, 0x76, 0x86, 0x62, 0x4f, - 0x83, 0xa4, 0x5d, 0x66, 0x93, 0x87, 0x87, 0x5b, 0x7f, 0x9d, 0x61, 0x9d, - 0x94, 0xa4, 0x84, 0x75, 0x67, 0xb3, 0x7b, 0x6d, 0x64, 0x98, 0x62, 0x77, - 0x7d, 0x98, 0x8e, 0x75, 0x7d, 0xa6, 0xa4, 0x8c, 0x83, 0x8b, 0x7a, 0x97, - 0x6c, 0x7f, 0x66, 0x7f, 0x8f, 0x98, 0x72, 0x6e, 0x75, 0x65, 0x80, 0x8d, - 0x88, 0x7d, 0x8c, 0x8d, 0x67, 0x68, 0xab, 0x8c, 0x8b, 0x76, 0x87, 0x69, - 0x88, 0x6c, 0x83, 0x6e, 0x88, 0x64, 0xa8, 0x67, 0xa5, 0x5b, 0x65, 0x60, - 0x6b, 0x62, 0x76, 0x78, 0x8c, 0x5b, 0x61, 0x6f, 0x66, 0x65, 0x92, 0x67, - 0x84, 0x7b, 0x80, 0x86, 0x7b, 0x6c, 0x86, 0x7a, 0x72, 0x7b, 0x4d, 0x94, - 0x80, 0x67, 0x8e, 0x8d, 0x7f, 0x79, 0x65, 0x78, 0xa3, 0x71, 0x80, 0x74, - 0xa7, 0xa8, 0x97, 0x78, 0x91, 0x77, 0x98, 0x86, 0x82, 0x64, 0xa5, 0x6e, - 0x7a, 0x5d, 0x6f, 0xad, 0x9b, 0x7a, 0x91, 0x4b, 0xa1, 0x75, 0x95, 0x76, - 0xac, 0x9d, 0xa3, 0x65, 0x65, 0x6a, 0x81, 0x8b, 0x9f, 0x67, 0x6b, 0x6a, - 0x60, 0x5b, 0x77, 0x96, 0x73, 0x78, 0x5a, 0x77, 0x5f, 0x68, 0x70, 0x72, - 0x78, 0x65, 0x81, 0x20, 0x86, 0x99, 0x80, 0x7a, 0xa5, 0xb1, 0x69, 0x45, - 0x7d, 0xa6, 0x7d, 0x85, 0xaa, 0xa9, 0x65, 0x60, 0x75, 0x9b, 0x61, 0x92, - 0x91, 0x8f, 0x8a, 0x81, 0x88, 0x9c, 0x81, 0x7d, 0x7b, 0x8f, 0x7e, 0x9e, - 0x82, 0x94, 0x95, 0x80, 0x73, 0xae, 0x7b, 0x7a, 0x79, 0x8c, 0x8b, 0x65, - 0x71, 0x75, 0x8d, 0x7a, 0x90, 0x83, 0x7b, 0x77, 0x71, 0x4f, 0x70, 0x95, - 0x87, 0x69, 0x97, 0x8e, 0x70, 0x92, 0x6e, 0x91, 0x9d, 0x72, 0x75, 0x82, - 0xad, 0x81, 0x78, 0x8d, 0x6f, 0x65, 0x88, 0x86, 0x8c, 0x8e, 0x59, 0x8b, - 0x67, 0x69, 0x8b, 0x78, 0x7f, 0x59, 0x73, 0x87, 0x6f, 0x86, 0x66, 0x7c, - 0x96, 0x68, 0x59, 0x78, 0x67, 0x92, 0x7b, 0x76, 0x80, 0x6e, 0x4a, 0x7b, - 0x99, 0x67, 0x72, 0x9c, 0x7a, 0x80, 0x76, 0x5f, 0x8e, 0x4f, 0x71, 0x77, - 0xab, 0x78, 0x99, 0x50, 0x83, 0x65, 0x78, 0x8c, 0xbb, 0x8d, 0x4e, 0x54, - 0x81, 0x6f, 0x7f, 0x91, 0xb9, 0x79, 0x9c, 0x65, 0x5a, 0x5a, 0x73, 0x8c, - 0x9a, 0xac, 0x99, 0x44, 0x7d, 0x4f, 0x78, 0x5a, 0x7d, 0x79, 0x57, 0x44, - 0x6f, 0x6a, 0x75, 0x7f, 0x5f, 0x6f, 0x72, 0x62, 0x7f, 0x89, 0x57, 0x91, - 0x8d, 0x83, 0x7e, 0x63, 0x8c, 0x95, 0x48, 0x78, 0xa9, 0x88, 0x84, 0x5b, - 0x8c, 0xa5, 0x65, 0x71, 0x88, 0x82, 0x7e, 0xa4, 0x8d, 0x7d, 0x7d, 0x8d, - 0x91, 0x7c, 0x73, 0x7d, 0x99, 0x89, 0x6d, 0xa1, 0x98, 0x84, 0x8b, 0x6b, - 0x89, 0x86, 0x84, 0x7e, 0x86, 0x87, 0x78, 0x8c, 0x96, 0x92, 0x5a, 0xa0, - 0x64, 0x73, 0x91, 0x88, 0x8f, 0x6b, 0x96, 0x5c, 0x99, 0x62, 0x78, 0x6c, - 0x87, 0x4d, 0x5d, 0x69, 0x7b, 0x81, 0x4a, 0x61, 0x71, 0x69, 0x7d, 0x91, - 0x67, 0x92, 0x68, 0x6f, 0x50, 0x5e, 0x61, 0x7e, 0x81, 0x70, 0x5f, 0x7b, - 0x6b, 0x55, 0x71, 0x6c, 0x70, 0x53, 0x3f, 0x80, 0x6e, 0x57, 0x96, 0x84, - 0x75, 0x51, 0x60, 0x9a, 0x7f, 0xa5, 0x80, 0x94, 0x95, 0x74, 0x7c, 0x83, - 0xa0, 0x93, 0x5d, 0x92, 0x83, 0x66, 0x67, 0x8a, 0x8b, 0x9b, 0x81, 0x69, - 0x73, 0x91, 0x6b, 0x79, 0x93, 0x88, 0x64, 0x68, 0x81, 0x8c, 0x6f, 0x81, - 0x6f, 0x80, 0x68, 0x5f, 0x9c, 0x95, 0x76, 0x93, 0x87, 0x68, 0x83, 0x94, - 0x8b, 0x85, 0x72, 0x7f, 0x64, 0x8c, 0x6a, 0x95, 0x8d, 0x80, 0x69, 0x6b, - 0x98, 0x86, 0x75, 0x92, 0x7a, 0x7f, 0x5b, 0x7f, 0x9b, 0x57, 0x99, 0x8d, - 0x8a, 0x7b, 0x58, 0x73, 0x88, 0x6d, 0x8a, 0x8c, 0x8e, 0x82, 0x85, 0xaa, - 0x72, 0xa6, 0x7f, 0x7a, 0x83, 0x59, 0x6d, 0x6e, 0x79, 0x83, 0x88, 0x84, - 0x74, 0x85, 0x74, 0x78, 0x80, 0x7c, 0x97, 0x86, 0x94, 0x65, 0x7e, 0x80, - 0x6f, 0x97, 0x70, 0x74, 0x92, 0x76, 0x71, 0x91, 0x85, 0x72, 0x6e, 0x84, - 0x78, 0x7e, 0x88, 0x79, 0x7f, 0x80, 0x83, 0x7a, 0x85, 0x75, 0x82, 0x81, - 0x82, 0x7b, 0x7a, 0xa0, 0x76, 0x7f, 0x75, 0xa7, 0x67, 0x8e, 0x81, 0x98, - 0xa5, 0x86, 0x77, 0x78, 0x7f, 0x97, 0x90, 0x86, 0x80, 0x6b, 0x89, 0x66, - 0x9b, 0x5c, 0x8b, 0x74, 0xac, 0x89, 0x89, 0x92, 0x92, 0xa8, 0x61, 0x85, - 0x8c, 0x86, 0x88, 0x91, 0x92, 0x66, 0x63, 0x6c, 0x7a, 0x80, 0x7d, 0x90, - 0x6f, 0x7f, 0x92, 0x94, 0x8e, 0x7a, 0x86, 0x98, 0xa1, 0x59, 0x71, 0x8c, - 0x63, 0xa3, 0x60, 0x7d, 0x88, 0x6a, 0x83, 0x6e, 0x7a, 0x94, 0x7b, 0x81, - 0x7d, 0x83, 0x77, 0x7e, 0x63, 0xab, 0x75, 0x7b, 0x71, 0x8f, 0x76, 0x6e, - 0x78, 0x7b, 0x79, 0x86, 0x69, 0x67, 0x67, 0x70, 0x6c, 0x7a, 0x6c, 0x84, - 0x74, 0xa2, 0x74, 0x77, 0x8a, 0x58, 0x7d, 0xa0, 0x65, 0x7b, 0x79, 0x71, - 0x7c, 0x3c, 0x85, 0x96, 0x59, 0x76, 0x6a, 0x94, 0xa5, 0x5b, 0x70, 0x99, - 0x7f, 0x9a, 0x69, 0x7c, 0x6f, 0x79, 0x72, 0x8b, 0x83, 0x6e, 0x73, 0x7f, - 0x6f, 0x6d, 0x7e, 0xa3, 0x72, 0x87, 0x83, 0x8c, 0x8c, 0x70, 0x77, 0x75, - 0xa4, 0x5a, 0x89, 0x7d, 0xa0, 0x97, 0x67, 0x80, 0x78, 0x7e, 0x86, 0x6a, - 0x7b, 0x9c, 0x77, 0x67, 0x7b, 0x74, 0x7f, 0xa5, 0x90, 0x94, 0x92, 0x4d, - 0x7a, 0x79, 0x9f, 0x87, 0x64, 0x6e, 0x6d, 0x59, 0x83, 0x54, 0x79, 0x82, - 0x6c, 0x74, 0x82, 0x98, 0x77, 0x90, 0x85, 0xa4, 0x88, 0x81, 0x71, 0x85, - 0x90, 0x8e, 0x88, 0x68, 0x51, 0x6d, 0x71, 0x7b, 0x80, 0xbc, 0xa5, 0x57, - 0x8f, 0x9f, 0x95, 0x89, 0xb1, 0x96, 0x69, 0x65, 0x61, 0x73, 0x6f, 0x6c, - 0x5b, 0x95, 0x99, 0x7f, 0x76, 0x9d, 0x7c, 0x7d, 0x8d, 0xb1, 0x8f, 0x6a, - 0x76, 0x95, 0x74, 0x7a, 0x7b, 0xae, 0x77, 0x76, 0x6d, 0x99, 0x7d, 0x80, - 0x6e, 0x89, 0x7f, 0x74, 0x6f, 0x72, 0x89, 0x8b, 0x86, 0x7b, 0x7c, 0x72, - 0x6b, 0x4f, 0x71, 0x94, 0x80, 0x96, 0x83, 0x7e, 0x75, 0x74, 0x68, 0x83, - 0x95, 0x8c, 0x85, 0x7a, 0x82, 0x74, 0x85, 0x83, 0x8c, 0x7e, 0x7a, 0xa0, - 0x8e, 0x67, 0x6b, 0x82, 0x9b, 0x66, 0x6c, 0x8a, 0x88, 0x7e, 0x74, 0x9e, - 0x88, 0x82, 0x73, 0x73, 0x79, 0x7c, 0x72, 0x6b, 0x74, 0x8b, 0xa4, 0xa4, - 0xa3, 0x73, 0x73, 0x88, 0x8d, 0x94, 0x84, 0x9a, 0x9e, 0x93, 0x6c, 0x86, - 0x7a, 0x7a, 0x7e, 0xaa, 0x66, 0x8f, 0x99, 0xa4, 0x70, 0x4c, 0x6f, 0x66, - 0x8a, 0xaa, 0x69, 0x80, 0x6a, 0x5e, 0x71, 0x8f, 0x8b, 0x84, 0x75, 0x9d, - 0x5c, 0x60, 0x61, 0x4a, 0x6f, 0x91, 0x78, 0x6e, 0x8c, 0x62, 0x88, 0x75, - 0x64, 0x7c, 0x7d, 0x92, 0x9b, 0x96, 0x62, 0x72, 0x6c, 0x6f, 0x87, 0x5d, - 0xa0, 0xa7, 0x7c, 0x58, 0x6e, 0x8c, 0x82, 0x84, 0x7f, 0x8b, 0x54, 0x77, - 0x5b, 0x9a, 0x6a, 0x78, 0x5d, 0xb9, 0x8e, 0x7d, 0x6e, 0xa1, 0x66, 0x7c, - 0x87, 0xd2, 0x7a, 0x6c, 0x82, 0xa1, 0x83, 0x59, 0x64, 0x9e, 0x65, 0x6d, - 0x77, 0x80, 0x7c, 0x9a, 0x50, 0x9f, 0x8b, 0x7a, 0x73, 0x80, 0x92, 0x6d, - 0x97, 0x7f, 0x74, 0x6a, 0x5f, 0x44, 0x7d, 0x99, 0x95, 0x91, 0x8f, 0x6a, - 0x63, 0x56, 0x89, 0x96, 0xba, 0xa6, 0x71, 0x98, 0x9d, 0x3a, 0x8f, 0x77, - 0x6d, 0x76, 0x68, 0xb4, 0x8d, 0x79, 0x7a, 0x83, 0x7f, 0x96, 0x75, 0x94, - 0x9e, 0x51, 0x83, 0x5b, 0x66, 0x73, 0xa1, 0xbc, 0x8c, 0x70, 0x88, 0x80, - 0x92, 0x60, 0x7d, 0xa9, 0x97, 0x74, 0x7d, 0x98, 0x7b, 0x78, 0x85, 0xa7, - 0x8f, 0x8c, 0x91, 0x9d, 0x6a, 0x80, 0x6c, 0x8e, 0x8e, 0x91, 0x76, 0x8b, - 0x79, 0x59, 0x7d, 0x9c, 0x69, 0x83, 0x8c, 0x95, 0x8e, 0x75, 0x9d, 0x83, - 0x92, 0x99, 0x8a, 0x59, 0x61, 0x54, 0x63, 0x86, 0x83, 0x86, 0x98, 0x83, - 0x73, 0x74, 0x91, 0x52, 0x60, 0x8a, 0x7c, 0x57, 0xbc, 0x9d, 0x86, 0x6b, - 0x63, 0xa2, 0x78, 0x80, 0x75, 0xb1, 0x74, 0x76, 0x69, 0x8b, 0x7e, 0x76, - 0x7b, 0xb3, 0x77, 0x5b, 0x6c, 0x8b, 0x83, 0x80, 0x7f, 0xd1, 0x7c, 0x58, - 0x6f, 0x98, 0x71, 0x57, 0x60, 0xd0, 0x84, 0x62, 0x74, 0xa6, 0x8f, 0x7b, - 0x70, 0xaa, 0x81, 0x6b, 0x7f, 0x89, 0x6a, 0x74, 0x5a, 0x8c, 0x9c, 0x77, - 0x5d, 0x84, 0x63, 0x94, 0x8e, 0x91, 0x83, 0x4a, 0x49, 0x74, 0x6b, 0x70, - 0xc0, 0xa0, 0x6a, 0x90, 0x8e, 0x5a, 0x70, 0x96, 0xab, 0x72, 0x7e, 0xba, - 0xa7, 0x46, 0x86, 0x5d, 0x90, 0x76, 0x95, 0x8d, 0xa5, 0x40, 0x82, 0x8a, - 0x7d, 0x5e, 0x73, 0x94, 0x9d, 0x58, 0x8c, 0x8b, 0x69, 0x6c, 0x9a, 0x90, - 0xaa, 0x6f, 0x85, 0x8d, 0x64, 0x58, 0x7b, 0x97, 0xa9, 0x79, 0xa5, 0xa2, - 0x5f, 0x57, 0x9a, 0xb4, 0x89, 0x70, 0x84, 0x73, 0x46, 0x6c, 0x6e, 0x87, - 0x70, 0x94, 0x8a, 0x8a, 0x69, 0x7b, 0x6c, 0x68, 0x8e, 0xa2, 0x90, 0x84, - 0x78, 0x45, 0x63, 0x78, 0x7f, 0x90, 0x9f, 0x90, 0x68, 0x43, 0x92, 0x77, - 0x78, 0x77, 0x82, 0x7d, 0x8f, 0x6a, 0x7a, 0x70, 0x76, 0x75, 0x87, 0x63, - 0xbc, 0x8e, 0x6a, 0x71, 0x51, 0x51, 0x75, 0x6b, 0x8a, 0xb4, 0x6a, 0x5b, - 0x99, 0x84, 0x76, 0x84, 0x74, 0xaf, 0x86, 0x6a, 0x53, 0x97, 0x6e, 0x8e, - 0x61, 0xc4, 0x7e, 0x5d, 0x4d, 0x96, 0x73, 0x73, 0x53, 0xc0, 0x8f, 0x68, - 0x58, 0xae, 0x81, 0x83, 0x62, 0x98, 0x7b, 0x89, 0x54, 0x86, 0x78, 0x67, - 0x70, 0x9b, 0x63, 0x5f, 0x2d, 0x77, 0x84, 0x79, 0x6b, 0xa4, 0x7b, 0x65, - 0x45, 0x65, 0x56, 0x86, 0xbb, 0x8a, 0x8e, 0x92, 0x86, 0x48, 0x7c, 0x6d, - 0xb4, 0x7d, 0x56, 0xa4, 0x86, 0x52, 0x8b, 0x6a, 0x8d, 0x5b, 0x9d, 0xa2, - 0xbf, 0x36, 0x7c, 0x99, 0x9d, 0x65, 0x75, 0xa4, 0x9f, 0x6a, 0x7c, 0x6b, - 0x6f, 0x55, 0x70, 0x7f, 0xc2, 0x38, 0x6e, 0xa4, 0x74, 0x4c, 0x75, 0xbb, - 0xa4, 0x75, 0x8e, 0x8f, 0x56, 0x65, 0x57, 0x92, 0x73, 0x7f, 0x7d, 0x86, - 0x65, 0x76, 0x92, 0x84, 0x70, 0xa8, 0x91, 0x5b, 0x69, 0x74, 0x8e, 0x82, - 0x78, 0x8a, 0xaa, 0x71, 0x70, 0x50, 0x85, 0x82, 0x7d, 0x94, 0xa0, 0x76, - 0x6d, 0x55, 0x86, 0x79, 0x71, 0x7f, 0x9b, 0x71, 0x8a, 0x42, 0x87, 0x64, - 0x57, 0x88, 0xa0, 0x77, 0xa8, 0x91, 0x72, 0x65, 0x7e, 0x6b, 0x7e, 0x81, - 0x8d, 0x97, 0x7e, 0x6a, 0x92, 0x88, 0x84, 0x7a, 0x61, 0xa9, 0x86, 0x59, - 0x6c, 0x87, 0x61, 0x72, 0x4f, 0xc8, 0x99, 0x6c, 0x66, 0xa3, 0x80, 0x8b, - 0x5c, 0xc0, 0x69, 0x7a, 0x6c, 0xb8, 0x8e, 0x91, 0x51, 0x9f, 0x8c, 0x85, - 0x75, 0x96, 0x8c, 0x84, 0x6b, 0xa6, 0x71, 0x62, 0x42, 0x60, 0x74, 0x72, - 0x92, 0x91, 0x70, 0x5b, 0x3d, 0x71, 0x5e, 0x91, 0xa3, 0xa5, 0x6a, 0x7c, - 0x60, 0x58, 0x82, 0x80, 0xa3, 0x73, 0x8f, 0xa0, 0xb2, 0x4b, 0x94, 0x5e, - 0x9f, 0x75, 0x4d, 0x83, 0xbc, 0x42, 0x5e, 0x80, 0x8f, 0x59, 0x53, 0xac, - 0xb2, 0x45, 0x68, 0x7d, 0x9a, 0x65, 0x8a, 0xaa, 0xa0, 0x4e, 0x77, 0x72, - 0x4d, 0x62, 0x6e, 0x98, 0x8c, 0x73, 0x92, 0x5a, 0x49, 0x55, 0x7b, 0x98, - 0x8d, 0x84, 0x80, 0x8e, 0x2e, 0x56, 0x78, 0x73, 0x7b, 0x8f, 0x9a, 0x69, - 0x73, 0x68, 0x7a, 0x88, 0x78, 0xa5, 0xb1, 0x5c, 0x8f, 0x55, 0x71, 0x99, - 0x7a, 0xa9, 0xb0, 0x75, 0x69, 0x44, 0x5f, 0x66, 0x81, 0x7d, 0x9e, 0x4f, - 0x66, 0x7f, 0x87, 0x7d, 0x5d, 0x7c, 0x95, 0x62, 0xa5, 0x86, 0x90, 0x6f, - 0x60, 0xa5, 0x6e, 0x70, 0x80, 0x96, 0x6f, 0x55, 0x77, 0x87, 0x99, 0x7b, - 0x21, 0xaa, 0x7f, 0x60, 0x63, 0xae, 0x47, 0x79, 0x44, 0xb5, 0x83, 0x6e, - 0x6d, 0x93, 0x76, 0x54, 0x4b, 0xad, 0x91, 0x6b, 0x6a, 0x9c, 0x8c, 0x83, - 0x62, 0x8a, 0x88, 0x71, 0x73, 0xa0, 0x75, 0x95, 0x54, 0x80, 0x92, 0x65, - 0x45, 0x80, 0x63, 0x9a, 0x93, 0x9b, 0x78, 0x4e, 0x4d, 0x5f, 0x69, 0x9e, - 0xbd, 0xa5, 0x75, 0x6b, 0x6e, 0x6a, 0x82, 0x97, 0xab, 0x60, 0x76, 0xb3, - 0xc1, 0x39, 0x82, 0x5b, 0x71, 0x31, 0x7b, 0x9c, 0xb5, 0x4f, 0x75, 0x79, - 0x6c, 0x5d, 0x80, 0xa6, 0x9c, 0x53, 0x6f, 0x85, 0x84, 0x5e, 0x7d, 0xb5, - 0x95, 0x5f, 0x7c, 0x98, 0x72, 0x7c, 0x67, 0x99, 0xbb, 0x6c, 0x73, 0x66, - 0x59, 0x5c, 0x6c, 0x9a, 0x9b, 0x72, 0x9b, 0x5f, 0x4b, 0x51, 0x63, 0x84, - 0x74, 0xa0, 0xb3, 0x6e, 0x63, 0xa0, 0x84, 0x90, 0x71, 0x91, 0xba, 0x64, - 0x6d, 0x72, 0x78, 0x83, 0x6f, 0x8e, 0xbd, 0x64, 0x69, 0x60, 0x95, 0x67, - 0x70, 0x93, 0x78, 0x4d, 0x91, 0x3f, 0x7b, 0x6d, 0x69, 0x87, 0x7d, 0x8a, - 0xa3, 0x95, 0x9d, 0x66, 0x6d, 0x8b, 0x7a, 0x75, 0x94, 0x7b, 0x89, 0x52, - 0x66, 0x65, 0x79, 0x84, 0x49, 0x9c, 0x60, 0x66, 0x3e, 0xab, 0x4a, 0x86, - 0x54, 0xcd, 0x7c, 0x83, 0x7c, 0xac, 0x8b, 0x53, 0x67, 0xbb, 0x7c, 0x6d, - 0x72, 0xb3, 0x83, 0x85, 0x4f, 0x97, 0x86, 0x60, 0x7d, 0x93, 0x70, 0x8b, - 0x64, 0x78, 0x82, 0x73, 0x54, 0x87, 0x6c, 0xaa, 0x6f, 0x97, 0x8d, 0x51, - 0x2d, 0x50, 0x75, 0xa9, 0xc2, 0x94, 0x8d, 0x6f, 0x6d, 0x71, 0x7b, 0x87, - 0x93, 0x67, 0x7d, 0xa5, 0xa2, 0x4f, 0x99, 0x83, 0x95, 0x49, 0x70, 0x9c, - 0xcf, 0x37, 0x84, 0x86, 0x94, 0x5c, 0x95, 0xa1, 0xb6, 0x73, 0x80, 0x8d, - 0x89, 0x62, 0x6f, 0xb4, 0xa1, 0x5b, 0x64, 0x91, 0x41, 0x4f, 0x53, 0xa6, - 0xae, 0x75, 0x84, 0x82, 0x58, 0x8e, 0x63, 0x95, 0xa3, 0x8d, 0x8b, 0x76, - 0x5d, 0x78, 0x80, 0x82, 0x6e, 0x9d, 0xb8, 0x7d, 0x64, 0x8a, 0x7e, 0x80, - 0x72, 0x99, 0xcf, 0x76, 0x66, 0x77, 0x7c, 0x81, 0x71, 0x6f, 0xa1, 0x6c, - 0x6b, 0x70, 0x80, 0x7c, 0x6d, 0x83, 0x8e, 0x74, 0x7a, 0x58, 0x69, 0x53, - 0x58, 0x7d, 0x7f, 0x84, 0x96, 0x9c, 0x75, 0x6e, 0x62, 0x7c, 0x88, 0x7e, - 0x7f, 0x98, 0x93, 0x61, 0x98, 0x98, 0x80, 0x83, 0x2e, 0x7d, 0x64, 0x69, - 0x50, 0xa5, 0x38, 0x96, 0x2e, 0xc5, 0x66, 0x56, 0x64, 0xaa, 0x63, 0x64, - 0x6d, 0xb3, 0x8a, 0x6c, 0x59, 0xb6, 0x69, 0x7a, 0x54, 0x91, 0x58, 0x96, - 0x6b, 0x9f, 0x6d, 0x88, 0x4a, 0x82, 0x94, 0x67, 0x38, 0x93, 0x60, 0x87, - 0x8c, 0x93, 0x8c, 0x52, 0x31, 0x43, 0x66, 0xa9, 0xb3, 0x7a, 0x88, 0x64, - 0x60, 0x5b, 0x80, 0x84, 0xb7, 0x5a, 0x7a, 0x9d, 0x92, 0x50, 0x89, 0x80, - 0x72, 0x51, 0x7f, 0x85, 0xae, 0x47, 0x76, 0x9a, 0x7a, 0x74, 0x6d, 0x93, - 0xbd, 0x42, 0x72, 0x6d, 0x58, 0x5e, 0x6e, 0xa4, 0xb5, 0x4e, 0x76, 0x8f, - 0x75, 0x9b, 0x5d, 0x92, 0xad, 0x77, 0x7f, 0x73, 0x62, 0x7d, 0x65, 0xaf, - 0x98, 0x87, 0x80, 0x7c, 0x61, 0x81, 0x45, 0xa0, 0x84, 0x99, 0xbb, 0x72, - 0x86, 0x8f, 0x70, 0x97, 0x6a, 0x8a, 0xd3, 0x70, 0x7c, 0x91, 0x77, 0x82, - 0x70, 0x8c, 0xd5, 0x6c, 0x7f, 0x51, 0x5f, 0x69, 0x72, 0x89, 0x9a, 0x68, - 0x79, 0x70, 0x8b, 0x80, 0x52, 0x98, 0x86, 0x7a, 0xa0, 0x7b, 0x61, 0x6e, - 0x66, 0x6f, 0x77, 0x78, 0x64, 0xac, 0x7e, 0x73, 0x5d, 0x71, 0x6f, 0x80, - 0x2e, 0xa9, 0x90, 0x5c, 0x56, 0xa1, 0x32, 0x88, 0x55, 0xb9, 0x67, 0x6f, - 0x5c, 0xa5, 0x87, 0x61, 0x6b, 0xbd, 0x77, 0x7c, 0x62, 0xae, 0x7c, 0x7a, - 0x66, 0xac, 0x7a, 0x62, 0x5c, 0x9a, 0x58, 0x89, 0x5a, 0x74, 0x72, 0x66, - 0x5c, 0x8e, 0x51, 0x8e, 0x99, 0x92, 0xa0, 0x49, 0x31, 0x55, 0x68, 0x99, - 0xba, 0x82, 0xa2, 0x7a, 0x5e, 0x6f, 0x84, 0x98, 0x96, 0x52, 0x73, 0x99, - 0xb4, 0x5e, 0x7c, 0x59, 0x7d, 0x4a, 0x7e, 0xa0, 0xbe, 0x63, 0x67, 0x8e, - 0x7f, 0x71, 0x80, 0xaf, 0x93, 0x4e, 0x78, 0x7e, 0x6d, 0x52, 0x66, 0xb3, - 0x94, 0x56, 0x84, 0x8f, 0x50, 0x6d, 0x65, 0xa8, 0xb3, 0x4b, 0x91, 0x7f, - 0x4c, 0x8d, 0x69, 0x79, 0x95, 0x8f, 0x8f, 0x7c, 0x66, 0x98, 0x75, 0x9b, - 0x73, 0x9b, 0xac, 0x79, 0x6e, 0x84, 0x69, 0x9e, 0x80, 0xa0, 0xb0, 0x6c, - 0x46, 0x8b, 0x3f, 0x7a, 0x79, 0x79, 0xb3, 0x62, 0x6b, 0x60, 0x67, 0x81, - 0x4a, 0x7e, 0xa7, 0x8c, 0x74, 0x7f, 0x67, 0x4c, 0x4b, 0x8c, 0x8e, 0x67, - 0x78, 0x9d, 0x94, 0x79, 0x75, 0x7c, 0x86, 0x7b, 0x67, 0x9f, 0xa4, 0x61, - 0x5b, 0x6e, 0x85, 0x70, 0x20, 0xa5, 0x66, 0x5e, 0x55, 0xad, 0x3e, 0x7c, - 0x2d, 0xb4, 0x78, 0x6f, 0x4c, 0xc6, 0x7e, 0x6d, 0x54, 0xb4, 0x71, 0x78, - 0x54, 0xc3, 0x66, 0x6e, 0x4a, 0xa0, 0x7b, 0x85, 0x66, 0x94, 0x75, 0x8d, - 0x34, 0x88, 0x71, 0x4e, 0x49, 0x8a, 0x3b, 0x9c, 0x88, 0x76, 0x7f, 0x6a, - 0x37, 0x64, 0x66, 0xb6, 0xa3, 0x82, 0x76, 0x82, 0x6d, 0x65, 0x6f, 0x8c, - 0x99, 0x5e, 0x77, 0xa1, 0x99, 0x51, 0xa1, 0x67, 0x6f, 0x4c, 0x7f, 0x9e, - 0xad, 0x40, 0x65, 0x82, 0x76, 0x66, 0x72, 0xb5, 0xb2, 0x5b, 0x71, 0x8a, - 0x76, 0x74, 0x52, 0xa0, 0x91, 0x37, 0x86, 0x72, 0x6c, 0x75, 0x62, 0xa5, - 0xb6, 0x57, 0x75, 0x90, 0x3e, 0x7f, 0x49, 0x9f, 0x8e, 0x92, 0x81, 0x87, - 0x69, 0x9e, 0x6b, 0x86, 0x8d, 0xb1, 0x9e, 0x65, 0x6f, 0x93, 0x70, 0x79, - 0x7b, 0x87, 0xbe, 0x59, 0x69, 0x7a, 0x56, 0x7a, 0x81, 0x7d, 0xb8, 0x67, - 0x67, 0x7f, 0x54, 0x8f, 0x71, 0x85, 0xa0, 0x74, 0x89, 0x5d, 0x67, 0x52, - 0x65, 0x96, 0x89, 0x84, 0x81, 0x83, 0x82, 0x9a, 0x85, 0x73, 0x78, 0x62, - 0x87, 0x98, 0x75, 0x6a, 0x73, 0x95, 0x86, 0x71, 0x11, 0x9a, 0x91, 0x66, - 0x6e, 0xa4, 0x35, 0x89, 0x47, 0xbb, 0x5e, 0x46, 0x3a, 0xa8, 0x70, 0x4a, - 0x65, 0xb9, 0x70, 0x96, 0x66, 0xcf, 0x80, 0x79, 0x60, 0xa4, 0x79, 0x70, - 0x68, 0x92, 0x7f, 0x89, 0x6b, 0x87, 0x77, 0x67, 0x5b, 0x74, 0x3f, 0x9e, - 0x94, 0x9b, 0xa1, 0x61, 0x4b, 0x66, 0x70, 0xad, 0xb7, 0x67, 0x70, 0x6c, - 0x3f, 0x5b, 0x94, 0x88, 0xb3, 0x4f, 0x97, 0x97, 0x8c, 0x55, 0xb8, 0x78, - 0x60, 0x25, 0x51, 0x91, 0xcd, 0x44, 0x6f, 0x85, 0x5c, 0x65, 0x67, 0xa5, - 0x9e, 0x5f, 0x6d, 0x85, 0x6d, 0x56, 0x80, 0xae, 0x79, 0x63, 0x4f, 0x7d, - 0x5f, 0x6b, 0x6e, 0xa7, 0x8e, 0x76, 0x8f, 0x90, 0x6e, 0x8c, 0x88, 0x92, - 0x81, 0x81, 0x96, 0x7d, 0x48, 0x6b, 0x3f, 0xa1, 0x8c, 0xa2, 0x9f, 0x7f, - 0x77, 0x97, 0x73, 0x9c, 0x67, 0x95, 0xae, 0x77, 0x7f, 0x7a, 0x52, 0x7e, - 0x91, 0x77, 0xa8, 0x54, 0x6a, 0x74, 0x52, 0x8a, 0x67, 0x8e, 0x90, 0x8d, - 0x8b, 0x52, 0x72, 0x5a, 0x73, 0x8f, 0x94, 0x87, 0x7c, 0x88, 0x89, 0x76, - 0x77, 0x88, 0x5c, 0x77, 0x8f, 0x94, 0xac, 0x58, 0x70, 0x79, 0x75, 0x8a, - 0x20, 0x9c, 0x91, 0x55, 0x55, 0xa4, 0x5b, 0x84, 0x30, 0xc6, 0x8a, 0x51, - 0x31, 0xc3, 0x72, 0x6b, 0x65, 0xb9, 0x79, 0x7d, 0x62, 0xad, 0x88, 0x75, - 0x37, 0xb0, 0x76, 0x8a, 0x7d, 0x85, 0x7f, 0xb4, 0x46, 0x9c, 0x83, 0x7b, - 0x79, 0x78, 0x56, 0xac, 0x8d, 0xa2, 0xa9, 0x54, 0x44, 0x5a, 0x63, 0xb2, - 0xa8, 0x72, 0xa4, 0x6b, 0x5d, 0x4d, 0x8e, 0x95, 0x9e, 0x4a, 0x98, 0x8c, - 0xb0, 0x5c, 0xa5, 0x75, 0x83, 0x3b, 0x46, 0x92, 0xa7, 0x3b, 0x6a, 0x75, - 0x59, 0x57, 0x52, 0xa1, 0xab, 0x54, 0x68, 0x7c, 0x94, 0x6e, 0x5b, 0x9a, - 0xa3, 0x5d, 0x73, 0x74, 0x5a, 0x63, 0x56, 0x9e, 0xc1, 0x71, 0x82, 0x79, - 0x49, 0x92, 0x63, 0xa6, 0x99, 0x7d, 0x71, 0x81, 0x5e, 0x90, 0x5c, 0x8b, - 0x7e, 0xb4, 0xa0, 0x8c, 0x67, 0x93, 0x4e, 0x72, 0x65, 0x83, 0xb5, 0x77, - 0x83, 0x92, 0x43, 0x67, 0x8c, 0x81, 0xb1, 0x75, 0x6a, 0x61, 0x66, 0x6f, - 0x5d, 0x7f, 0x8d, 0x7b, 0x6b, 0x68, 0x6f, 0x85, 0x6e, 0x87, 0x97, 0x89, - 0x9b, 0x81, 0x7e, 0x7e, 0x9d, 0x83, 0x6b, 0x6a, 0xa5, 0x92, 0x7e, 0x70, - 0x60, 0x8f, 0x6f, 0x8b, 0x15, 0xa6, 0x66, 0x4e, 0x61, 0xbc, 0x38, 0x67, - 0x46, 0xab, 0x84, 0x5e, 0x3a, 0xac, 0x74, 0x58, 0x76, 0xc4, 0x7a, 0x76, - 0x67, 0xc0, 0x76, 0x6f, 0x52, 0xa6, 0xa2, 0x97, 0x76, 0xa6, 0x7f, 0x99, - 0x5d, 0xa5, 0x5f, 0x60, 0x58, 0x88, 0x3f, 0x9e, 0x7d, 0x81, 0x71, 0x63, - 0x42, 0x55, 0x3e, 0xbd, 0xa9, 0x7a, 0xa5, 0x67, 0x62, 0x7a, 0x80, 0x9e, - 0xc3, 0x54, 0x7f, 0x9f, 0x93, 0x73, 0xbd, 0x79, 0x74, 0x2e, 0x54, 0x9e, - 0xaa, 0x76, 0x68, 0x80, 0x78, 0x64, 0x57, 0x93, 0xa4, 0x56, 0x75, 0x72, - 0x81, 0x7f, 0x48, 0xad, 0x89, 0x67, 0x60, 0x7e, 0x7a, 0x83, 0x6e, 0x95, - 0xb0, 0x57, 0x89, 0x91, 0x4d, 0x86, 0x78, 0x7b, 0x74, 0x8c, 0x8f, 0x8d, - 0x67, 0xa4, 0x64, 0x8d, 0x77, 0x9a, 0xa1, 0x88, 0x6e, 0x94, 0x33, 0x95, - 0x81, 0x76, 0xc6, 0x7d, 0x7d, 0x85, 0x5a, 0x6e, 0x8e, 0x69, 0x9e, 0x71, - 0x82, 0x81, 0x59, 0x5b, 0x71, 0x9a, 0x91, 0x8e, 0x80, 0x69, 0x71, 0x73, - 0x6e, 0x9a, 0x95, 0x94, 0x7b, 0x80, 0x82, 0x7e, 0x76, 0x84, 0x70, 0x72, - 0x9c, 0xa0, 0x77, 0x66, 0x55, 0xa1, 0x8c, 0x73, 0x35, 0xa0, 0x68, 0x4d, - 0x3b, 0xaa, 0x44, 0x6f, 0x3c, 0xc0, 0x96, 0x78, 0x33, 0xbd, 0x64, 0x5b, - 0x75, 0xd2, 0x83, 0x87, 0x59, 0xbd, 0x80, 0x80, 0x6e, 0x8e, 0x65, 0x7a, - 0x87, 0xb6, 0x8d, 0x94, 0x39, 0x95, 0x8b, 0x5d, 0x66, 0x71, 0x4e, 0x9f, - 0x96, 0x8a, 0x98, 0x47, 0x41, 0x6c, 0x4c, 0xac, 0x95, 0x81, 0x90, 0x75, - 0x59, 0x4c, 0xa2, 0x93, 0x99, 0x58, 0x7b, 0xaf, 0xa3, 0x52, 0xb0, 0x6c, - 0x5f, 0x47, 0x6e, 0x8e, 0xae, 0x3d, 0x81, 0x6d, 0x78, 0x52, 0x4f, 0x81, - 0x80, 0x68, 0x4b, 0x81, 0x74, 0x71, 0x67, 0xa7, 0x9a, 0x55, 0x84, 0x72, - 0x64, 0x6b, 0x6e, 0x9d, 0xab, 0x76, 0x79, 0x85, 0x40, 0x84, 0x80, 0x85, - 0x70, 0x91, 0x9a, 0x81, 0x5b, 0x89, 0x6b, 0x8a, 0x92, 0x8c, 0xa4, 0x7b, - 0x75, 0x89, 0x54, 0x76, 0x69, 0x69, 0xb3, 0x6c, 0x47, 0x7d, 0x4c, 0x7f, - 0x81, 0x86, 0x8f, 0x63, 0x71, 0x6a, 0x63, 0x67, 0x7c, 0x8f, 0xa0, 0x68, - 0x86, 0x58, 0x5b, 0x87, 0x6a, 0x82, 0x89, 0x78, 0x9d, 0x8d, 0xaa, 0x82, - 0x6e, 0xa4, 0x6f, 0x6d, 0x70, 0x9f, 0x7f, 0x77, 0x41, 0xa5, 0x86, 0x61, - 0x2d, 0x99, 0xa9, 0x5f, 0x5a, 0xb3, 0x51, 0x70, 0x5a, 0xce, 0x77, 0x68, - 0x2c, 0xb8, 0x90, 0x44, 0x58, 0xb9, 0x74, 0x8e, 0x70, 0xb3, 0x9a, 0x75, - 0x6d, 0xc0, 0x9e, 0x8e, 0x8d, 0xa8, 0x7b, 0xa8, 0x4a, 0x89, 0x6e, 0x7f, - 0x5d, 0x6e, 0x46, 0x91, 0x6d, 0x81, 0x89, 0x3e, 0x35, 0x69, 0x44, 0xaf, - 0x99, 0x8d, 0x94, 0x54, 0x60, 0x5b, 0xaf, 0x97, 0x92, 0x4e, 0x80, 0xae, - 0x9e, 0x62, 0xa3, 0x77, 0x6e, 0x5d, 0x71, 0xa0, 0xa6, 0x59, 0x84, 0x5d, - 0x65, 0x4a, 0x69, 0xa1, 0xa1, 0x40, 0x75, 0x65, 0x6b, 0x68, 0x60, 0xb3, - 0x92, 0x27, 0x70, 0x67, 0x9b, 0x5e, 0x50, 0xaf, 0xae, 0x64, 0x7a, 0x6e, - 0x61, 0x94, 0x3b, 0x8f, 0x86, 0x7f, 0x98, 0x88, 0x7a, 0x7f, 0x61, 0x7b, - 0x64, 0x96, 0x96, 0x79, 0x5c, 0x96, 0x52, 0x92, 0x76, 0x7e, 0xc4, 0x60, - 0x6d, 0x7b, 0x41, 0x8c, 0x7b, 0x8e, 0x9a, 0x66, 0x79, 0x95, 0x67, 0x6a, - 0x7a, 0x9b, 0xa9, 0x85, 0x6d, 0x66, 0x55, 0x65, 0x76, 0x8b, 0x90, 0x86, - 0x88, 0x8b, 0x8f, 0x7e, 0x83, 0x7c, 0x75, 0x5f, 0x78, 0x96, 0x76, 0x47, - 0x54, 0x9c, 0x8d, 0x7d, 0x24, 0x9f, 0x79, 0x5c, 0x55, 0xb2, 0x3b, 0x67, - 0x4e, 0xd2, 0x90, 0x79, 0x3c, 0xc3, 0x8b, 0x4a, 0x7c, 0xd7, 0x70, 0x75, - 0x5b, 0xaf, 0xa8, 0x6b, 0x59, 0xc1, 0x6d, 0x5f, 0x5d, 0x96, 0x87, 0x9a, - 0x5d, 0x7f, 0x8e, 0x6d, 0x5c, 0x75, 0x3f, 0xb6, 0x8e, 0x81, 0x7b, 0x31, - 0x47, 0x67, 0x56, 0xb6, 0x90, 0x71, 0x89, 0x63, 0x61, 0x75, 0x8d, 0x8b, - 0x97, 0x62, 0x62, 0x85, 0x9c, 0x64, 0xb7, 0x61, 0x71, 0x3f, 0x6c, 0x8b, - 0xaa, 0x43, 0x82, 0x70, 0x52, 0x52, 0x80, 0xaa, 0x9e, 0x5d, 0x90, 0x69, - 0x8a, 0x77, 0x6d, 0x9f, 0x9e, 0x5f, 0x84, 0x61, 0x87, 0x70, 0x43, 0xab, - 0x97, 0x6e, 0x84, 0x6c, 0x5d, 0x82, 0x64, 0x85, 0x83, 0x7e, 0x82, 0x7c, - 0x7b, 0x91, 0x55, 0x7e, 0x77, 0x88, 0xba, 0x71, 0x6d, 0x7b, 0x71, 0x8a, - 0x7f, 0x84, 0xb5, 0x63, 0x4a, 0x9a, 0x3c, 0x70, 0x7a, 0x99, 0xa3, 0x50, - 0x84, 0x82, 0x56, 0x4c, 0x74, 0x8e, 0xa3, 0x77, 0x8f, 0x4e, 0x5f, 0x6d, - 0x97, 0x89, 0xa0, 0x6b, 0x7c, 0x8c, 0x85, 0x82, 0x8e, 0xa1, 0x89, 0x5b, - 0x7f, 0x8b, 0x8f, 0x5e, 0x74, 0x96, 0x8a, 0x7d, 0x15, 0x7b, 0x8f, 0x88, - 0x5f, 0xa7, 0x63, 0x5b, 0x39, 0xbd, 0x96, 0x56, 0x4c, 0xb4, 0x7b, 0x53, - 0x5a, 0xaf, 0x79, 0x7b, 0x5c, 0xa6, 0xaa, 0x74, 0x5f, 0xa0, 0x76, 0x9e, - 0x71, 0x9a, 0x60, 0xa4, 0x33, 0x87, 0x66, 0x66, 0x64, 0x7d, 0x6d, 0xac, - 0x9e, 0x8c, 0x78, 0x4f, 0x3d, 0x7b, 0x53, 0xb1, 0x97, 0x8a, 0x96, 0x6e, - 0x60, 0x4b, 0xa9, 0x9e, 0x93, 0x6e, 0x93, 0xb7, 0xae, 0x46, 0xb9, 0x60, - 0x72, 0x46, 0x80, 0x95, 0xb5, 0x57, 0x82, 0x53, 0x6e, 0x4e, 0x5b, 0xa2, - 0x9a, 0x3d, 0x8b, 0x6c, 0x84, 0x65, 0x69, 0xa1, 0x8c, 0x60, 0x83, 0x74, - 0x73, 0x53, 0x5d, 0x7e, 0x7f, 0x79, 0x6e, 0x81, 0x89, 0x8f, 0x51, 0x81, - 0x99, 0x97, 0x81, 0x8a, 0x87, 0x83, 0x43, 0x90, 0x89, 0x94, 0x93, 0x7a, - 0x66, 0x80, 0x82, 0x82, 0x79, 0x85, 0xb0, 0x6b, 0x87, 0x7b, 0x53, 0x89, - 0x79, 0x9d, 0xab, 0x6e, 0x82, 0x84, 0x50, 0x8f, 0x7e, 0x74, 0x90, 0x74, - 0x6e, 0x65, 0x84, 0x70, 0x82, 0x7a, 0x9e, 0x6d, 0x8f, 0x62, 0xb2, 0x84, - 0x78, 0x7e, 0x72, 0x5a, 0x7a, 0x85, 0x8c, 0x4b, 0x70, 0x99, 0x87, 0x78, - 0x26, 0x95, 0xb9, 0x77, 0x4d, 0xb6, 0x51, 0x6a, 0x41, 0xbf, 0x76, 0x68, - 0x56, 0xb6, 0x80, 0x53, 0x83, 0xaf, 0x87, 0x79, 0x79, 0xb4, 0x89, 0x7d, - 0x47, 0x9d, 0xa0, 0x86, 0x89, 0xc3, 0x6d, 0x99, 0x41, 0x89, 0x9a, 0x59, - 0x54, 0x83, 0x79, 0x9d, 0x7b, 0x73, 0x88, 0x4a, 0x42, 0x64, 0x7a, 0x9f, - 0x7b, 0x6e, 0x71, 0x7b, 0x6a, 0x61, 0xae, 0xa3, 0xa0, 0x68, 0x95, 0x9d, - 0x94, 0x49, 0x8b, 0x70, 0x8a, 0x5f, 0x49, 0xbb, 0xa7, 0x4a, 0xa1, 0x59, - 0x59, 0x59, 0x6d, 0xa0, 0x9f, 0x50, 0xa0, 0x7b, 0x75, 0x49, 0x5a, 0x8c, - 0x84, 0x68, 0x78, 0x57, 0x7a, 0x6e, 0x6b, 0x87, 0x9c, 0x7b, 0x84, 0x83, - 0x79, 0x7d, 0x5a, 0x77, 0x77, 0x6f, 0x6f, 0x7c, 0x8f, 0x83, 0x40, 0x62, - 0x6a, 0x87, 0xab, 0x74, 0x86, 0x96, 0x7a, 0x7d, 0x7b, 0x81, 0x9a, 0x65, - 0x60, 0x82, 0x61, 0x73, 0x71, 0x77, 0xa7, 0x79, 0x87, 0x8c, 0x4e, 0x72, - 0x8d, 0x89, 0x94, 0x6d, 0x75, 0x6d, 0x6e, 0x82, 0x7a, 0x8d, 0xa9, 0x77, - 0x77, 0x7c, 0x74, 0xa7, 0xb7, 0x67, 0x75, 0x67, 0x7e, 0x9f, 0x73, 0x60, - 0x6c, 0x95, 0x7f, 0x62, 0x31, 0x70, 0x85, 0x7a, 0x5f, 0xc0, 0x69, 0x66, - 0x71, 0xb0, 0x81, 0x5d, 0x48, 0xc9, 0x86, 0x39, 0x93, 0xa4, 0x8e, 0x7c, - 0x5e, 0xbb, 0x98, 0x5c, 0x74, 0x9c, 0x89, 0x6d, 0x74, 0xbd, 0x8e, 0x6e, - 0x5f, 0x9a, 0x6d, 0x70, 0x57, 0x9c, 0x58, 0xb7, 0x8e, 0x94, 0xa0, 0x3f, - 0x39, 0x75, 0x6f, 0xb4, 0xa2, 0x94, 0xa9, 0x70, 0x61, 0x8a, 0x70, 0x92, - 0xa7, 0x7f, 0x7f, 0x8d, 0x7a, 0x73, 0xa1, 0x5f, 0x8a, 0x4a, 0x65, 0xaa, - 0x92, 0x6e, 0x98, 0x51, 0x81, 0x47, 0x57, 0xb8, 0x89, 0x50, 0x8a, 0x6d, - 0x8b, 0x50, 0x8a, 0x86, 0x9b, 0x7d, 0x5b, 0x4a, 0x68, 0x74, 0x53, 0x9b, - 0x94, 0x74, 0x7c, 0x6f, 0x62, 0x86, 0x5b, 0x8f, 0x82, 0x96, 0x6e, 0x7c, - 0x80, 0x8f, 0x47, 0x5b, 0x70, 0x95, 0x97, 0x77, 0x8d, 0x8e, 0x69, 0x62, - 0x78, 0x8f, 0xbf, 0x5e, 0x76, 0xae, 0x4d, 0x84, 0x73, 0x76, 0xab, 0x6f, - 0x7f, 0x8c, 0x4b, 0x7d, 0x96, 0x7d, 0xb3, 0x55, 0x78, 0x8d, 0x76, 0x73, - 0x8d, 0x8e, 0x98, 0x6a, 0x91, 0x86, 0x6d, 0x8c, 0x7d, 0x93, 0x97, 0x56, - 0x79, 0x8f, 0xa3, 0x7f, 0x7e, 0x82, 0xa0, 0x63, 0x3d, 0x6b, 0x88, 0x5e, - 0x61, 0xc0, 0x45, 0x5f, 0x66, 0xb0, 0x6c, 0x6d, 0x29, 0xd5, 0x95, 0x3b, - 0x77, 0xaa, 0x62, 0x70, 0x63, 0xce, 0x8c, 0x6e, 0x56, 0xaa, 0x77, 0x6e, - 0x90, 0xcc, 0x6d, 0x7e, 0x41, 0x9f, 0x88, 0x4f, 0x5d, 0xb4, 0x4c, 0x9b, - 0x80, 0x97, 0x98, 0x59, 0x4c, 0x71, 0x53, 0xb4, 0x90, 0x97, 0x93, 0x90, - 0x46, 0x63, 0xa6, 0x87, 0x9d, 0x56, 0x7f, 0xab, 0x8e, 0x68, 0xc6, 0x5d, - 0x6e, 0x58, 0x4b, 0x85, 0xa1, 0x70, 0x8a, 0x60, 0x84, 0x44, 0x68, 0x8e, - 0x9b, 0x3a, 0x8c, 0x57, 0x91, 0x4c, 0x6b, 0x9c, 0xa7, 0x64, 0x82, 0x5f, - 0x68, 0x6d, 0x4d, 0xa1, 0x6c, 0x91, 0x6c, 0x6b, 0x64, 0x97, 0x86, 0x81, - 0x8d, 0x8e, 0x80, 0x72, 0x88, 0x96, 0x5d, 0x6e, 0x7c, 0x67, 0x97, 0x69, - 0x95, 0x93, 0x61, 0x8b, 0x9b, 0x7d, 0xc8, 0x6f, 0x85, 0x80, 0x67, 0x68, - 0x90, 0x6b, 0xcc, 0x7c, 0xa3, 0xa0, 0x58, 0x81, 0x7a, 0x8d, 0x9f, 0x65, - 0x81, 0x82, 0x78, 0x6b, 0x85, 0x7b, 0x9b, 0x69, 0x86, 0x6c, 0x83, 0x6c, - 0x8e, 0x59, 0xab, 0x56, 0x7c, 0x7f, 0x7b, 0x84, 0x71, 0x63, 0x7d, 0x73, - 0x60, 0x8b, 0x7a, 0x7b, 0x5e, 0xbb, 0x4b, 0x40, 0x30, 0xcc, 0x80, 0x65, - 0x6c, 0xb7, 0x80, 0x35, 0x7d, 0xa3, 0x5c, 0x6c, 0x49, 0xa6, 0x9b, 0x7b, - 0x53, 0xba, 0x62, 0x76, 0x78, 0xa0, 0x72, 0x80, 0x78, 0x93, 0x87, 0x62, - 0x64, 0x84, 0x6f, 0xa1, 0x70, 0x90, 0x9a, 0x6b, 0x42, 0x55, 0x6d, 0xc5, - 0xa6, 0x8a, 0x79, 0x64, 0x4c, 0x72, 0x7b, 0xa9, 0xa3, 0x70, 0x84, 0x8f, - 0x63, 0x7a, 0x9c, 0x4e, 0x5a, 0x76, 0x91, 0x67, 0xaf, 0x76, 0xbf, 0x46, - 0x62, 0x3f, 0x7d, 0xa7, 0x8d, 0x62, 0x90, 0x5b, 0x9a, 0x44, 0x51, 0x80, - 0xa6, 0x7e, 0x8d, 0x6a, 0x73, 0x65, 0x72, 0x82, 0x99, 0xb4, 0x6a, 0x75, - 0x85, 0x90, 0x47, 0x62, 0x9e, 0x95, 0x94, 0x78, 0x89, 0x74, 0x5d, 0xa3, - 0x7f, 0x9d, 0x7d, 0x63, 0x96, 0x86, 0x8d, 0xa2, 0x95, 0xab, 0xae, 0x5d, - 0x93, 0x8d, 0x3d, 0x76, 0x9e, 0x9c, 0xc4, 0x71, 0x7d, 0xa3, 0x75, 0x7e, - 0x6d, 0x9d, 0xa3, 0x7f, 0x94, 0x89, 0x47, 0x71, 0x8b, 0x95, 0xb1, 0x72, - 0x90, 0x53, 0x7e, 0x8f, 0x8c, 0x90, 0xa1, 0x4d, 0x59, 0x62, 0x73, 0xa0, - 0x69, 0x88, 0x86, 0x71, 0x60, 0x3b, 0x81, 0x57, 0x7d, 0x86, 0x58, 0x63, - 0x7d, 0x98, 0x74, 0x67, 0x5d, 0xb0, 0x67, 0x45, 0x9b, 0xa9, 0x94, 0x68, - 0x43, 0x8b, 0x85, 0x56, 0x63, 0x96, 0x87, 0x78, 0x88, 0xbf, 0x92, 0x8d, - 0x60, 0xa8, 0x7e, 0x7e, 0x78, 0x80, 0x66, 0x92, 0x6e, 0x97, 0xab, 0x7f, - 0x4f, 0x65, 0x59, 0xb0, 0x9b, 0x6b, 0x9f, 0x70, 0x6f, 0x5c, 0xac, 0x95, - 0xa3, 0x54, 0x8e, 0xa9, 0x9e, 0x8c, 0xa5, 0x66, 0x5f, 0x5b, 0x6c, 0x83, - 0x90, 0x73, 0x85, 0x64, 0x61, 0x51, 0x4a, 0x63, 0xa1, 0x96, 0x7e, 0x4e, - 0x87, 0x60, 0x68, 0xb5, 0x9a, 0x8d, 0x75, 0x4e, 0x8a, 0x7a, 0x5f, 0x9f, - 0x74, 0x80, 0x69, 0x6d, 0x73, 0x92, 0x79, 0x7e, 0x85, 0x68, 0x83, 0x9d, - 0xb6, 0x9d, 0x6e, 0x8f, 0x78, 0x91, 0xaf, 0x8f, 0xa0, 0x9d, 0x73, 0x55, - 0x91, 0x8f, 0xb2, 0x76, 0x97, 0xab, 0x63, 0x63, 0x68, 0x7b, 0xab, 0x5c, - 0x77, 0xae, 0x4c, 0x72, 0x6e, 0x93, 0xb8, 0x51, 0x79, 0x84, 0x7d, 0x6b, - 0x7f, 0x8a, 0xba, 0x68, 0x7a, 0x43, 0x9a, 0x8d, 0x77, 0x8a, 0x6d, 0x56, - 0x79, 0x41, 0x7a, 0x4b, 0x81, 0x7a, 0x5c, 0x68, 0x58, 0x36, 0x6f, 0x6f, - 0x9f, 0xa6, 0x5f, 0x60, 0x4e, 0x67, 0x70, 0x4c, 0x69, 0x69, 0x94, 0x63, - 0x6d, 0x7b, 0x88, 0x9e, 0x6d, 0x98, 0x69, 0x68, 0x88, 0x80, 0x80, 0x7a, - 0x8e, 0x78, 0x5e, 0x8d, 0x7e, 0x91, 0x76, 0x64, 0x7e, 0x7f, 0x4e, 0xc9, - 0x79, 0x8f, 0x9c, 0x82, 0x3d, 0x62, 0x63, 0xc3, 0xb8, 0x7b, 0x72, 0x7b, - 0x50, 0x56, 0x95, 0x72, 0x8f, 0x6b, 0x90, 0x9d, 0x76, 0xa4, 0xa5, 0x79, - 0x54, 0x4f, 0x59, 0x85, 0xc5, 0x92, 0x97, 0x4d, 0x6f, 0x69, 0x77, 0x7f, - 0x71, 0x7c, 0x87, 0x59, 0x98, 0x61, 0x80, 0x81, 0x88, 0x6b, 0x6d, 0x7f, - 0x7f, 0x77, 0x60, 0xa2, 0x96, 0x73, 0x69, 0x86, 0x83, 0x8d, 0x60, 0x66, - 0x88, 0x8c, 0x93, 0x67, 0x98, 0x82, 0x7e, 0x91, 0x99, 0x59, 0x8e, 0x6e, - 0x90, 0xa1, 0x62, 0x8a, 0x98, 0x7b, 0xc8, 0x67, 0x85, 0x8d, 0x6c, 0xa1, - 0xa1, 0x92, 0xd0, 0x49, 0x85, 0x76, 0x89, 0x75, 0x88, 0x83, 0xa3, 0x77, - 0x85, 0x68, 0x82, 0x83, 0x7f, 0x79, 0xae, 0x85, 0x76, 0x84, 0x80, 0x9a, - 0x9d, 0x7b, 0x83, 0x90, 0x79, 0x88, 0x79, 0x9a, 0x93, 0x6c, 0x69, 0x79, - 0x5f, 0x90, 0x81, 0x7b, 0x87, 0x9d, 0x86, 0x82, 0x7a, 0x77, 0x71, 0x85, - 0x8b, 0x99, 0x8f, 0x7b, 0x58, 0x98, 0x84, 0x6e, 0x9a, 0xa1, 0x7a, 0x8c, - 0x77, 0xa8, 0x86, 0x93, 0x7b, 0x90, 0x79, 0x8a, 0x85, 0x8f, 0x84, 0x97, - 0x73, 0x83, 0x7b, 0x76, 0x8e, 0xa1, 0x89, 0x8a, 0x83, 0x9c, 0x65, 0x68, - 0x7b, 0x89, 0x92, 0x84, 0x6d, 0x90, 0x61, 0x78, 0x98, 0x8c, 0x8d, 0x87, - 0xa0, 0x99, 0x79, 0x7b, 0x69, 0xa4, 0x7a, 0x8d, 0x73, 0x71, 0x70, 0x80, - 0x82, 0x77, 0x81, 0x67, 0x75, 0x97, 0x71, 0x73, 0x85, 0x6d, 0x8e, 0x86, - 0x6e, 0x80, 0x86, 0x9e, 0x6f, 0x70, 0x67, 0x59, 0x65, 0x89, 0x67, 0x8b, - 0x7d, 0x68, 0x69, 0x7a, 0x5b, 0x7e, 0x87, 0xa1, 0x92, 0x7b, 0x64, 0x7e, - 0x76, 0x72, 0x71, 0xab, 0x7c, 0x83, 0x6f, 0xa1, 0x86, 0x76, 0x71, 0x6f, - 0x91, 0x77, 0x6c, 0x71, 0x92, 0x78, 0x70, 0x7f, 0x6e, 0x65, 0x77, 0x93, - 0x7e, 0x6c, 0x85, 0x9d, 0x78, 0x8b, 0x7c, 0x5f, 0x94, 0x86, 0x7c, 0x7f, - 0x83, 0x6e, 0x72, 0x9e, 0x6e, 0x6b, 0x8d, 0x91, 0x97, 0x8b, 0x7b, 0x72, - 0x86, 0x75, 0x7f, 0x96, 0x7d, 0x81, 0xa1, 0x55, 0xa6, 0x88, 0x96, 0x87, - 0x93, 0x68, 0x89, 0x72, 0x6f, 0x9c, 0x75, 0x7c, 0x79, 0x6c, 0x74, 0x84, - 0x7d, 0xa4, 0x86, 0x84, 0x84, 0x8d, 0x63, 0x7a, 0x63, 0xbc, 0x7e, 0x93, - 0x80, 0x8d, 0x71, 0x7a, 0x5f, 0x8c, 0x74, 0x96, 0x7e, 0x9b, 0x9d, 0x8d, - 0x5b, 0xa4, 0x71, 0x5e, 0x83, 0x78, 0x86, 0x7f, 0x70, 0x99, 0x87, 0x85, - 0x8e, 0x81, 0x93, 0x80, 0x89, 0xa0, 0x7a, 0x77, 0x8e, 0x73, 0x5f, 0x80, - 0x6d, 0x87, 0x5b, 0x7a, 0x85, 0x7c, 0x85, 0x63, 0x61, 0x9d, 0x6f, 0x68, - 0x77, 0x86, 0x61, 0x6d, 0x84, 0x98, 0x7c, 0x78, 0x69, 0x84, 0x91, 0x6d, - 0x81, 0xa1, 0x6c, 0x62, 0x95, 0x6d, 0x86, 0x8b, 0x95, 0x8f, 0x5e, 0x86, - 0x73, 0xa1, 0x83, 0x58, 0x5f, 0x8e, 0x76, 0x79, 0x9e, 0x92, 0x7c, 0x7b, - 0x81, 0x8b, 0x83, 0x7b, 0x78, 0x75, 0x70, 0x83, 0x70, 0x5a, 0x6a, 0x59, - 0xa3, 0x82, 0x7a, 0x91, 0x8b, 0x6e, 0x82, 0x8e, 0x70, 0x73, 0x91, 0x76, - 0xa5, 0x7f, 0x70, 0x81, 0x6f, 0x85, 0x94, 0xa6, 0x8c, 0x50, 0x76, 0x6e, - 0x64, 0x95, 0xa0, 0x64, 0x6c, 0x68, 0x8e, 0x8b, 0xa1, 0x7d, 0xa0, 0x7f, - 0x76, 0x8b, 0x7b, 0x93, 0x7b, 0x6e, 0x7e, 0x64, 0x8a, 0xa7, 0x78, 0x64, - 0x93, 0x67, 0x7d, 0x68, 0x5c, 0xa0, 0x76, 0x98, 0xaf, 0x80, 0x55, 0x96, - 0x97, 0x9c, 0x78, 0x75, 0x87, 0x85, 0x77, 0x77, 0x62, 0x93, 0x76, 0x68, - 0xa0, 0x80, 0x81, 0x7f, 0x9a, 0x68, 0x74, 0x69, 0x94, 0x77, 0x77, 0x72, - 0x90, 0x9a, 0x6f, 0x95, 0x89, 0x6b, 0x6b, 0x94, 0x7e, 0x9c, 0x6f, 0x67, - 0x8f, 0x82, 0x80, 0x92, 0x76, 0x80, 0x65, 0x9b, 0x6a, 0x7c, 0x75, 0x5a, - 0x87, 0xa1, 0x69, 0x7a, 0x79, 0x9e, 0x9a, 0x58, 0x81, 0x92, 0x72, 0x67, - 0x90, 0x80, 0x82, 0x61, 0x9f, 0x9e, 0x6a, 0x8d, 0x8d, 0x8a, 0x73, 0x81, - 0x68, 0x7f, 0x5b, 0x59, 0x98, 0x89, 0x71, 0x72, 0x58, 0x7b, 0x94, 0x5d, - 0xa9, 0x8b, 0x72, 0x7b, 0x65, 0x73, 0x5b, 0x8b, 0x7d, 0x86, 0x6e, 0x8c, - 0x66, 0x6f, 0x6b, 0x8b, 0x71, 0x80, 0x7f, 0x70, 0x70, 0x88, 0x70, 0x7e, - 0x84, 0x89, 0x7f, 0x81, 0x87, 0x77, 0x71, 0x88, 0x7f, 0x8f, 0x5e, 0x80, - 0x5d, 0xa1, 0x89, 0x77, 0x93, 0x8e, 0x55, 0x64, 0x88, 0x9a, 0x8b, 0x80, - 0x77, 0x6f, 0x91, 0x83, 0x6b, 0x9b, 0x85, 0x5c, 0x57, 0x7e, 0xa9, 0x63, - 0x83, 0xaa, 0x7c, 0xa1, 0x91, 0x5f, 0x68, 0x76, 0x7a, 0x97, 0x96, 0x84, - 0xca, 0x8d, 0x8c, 0x8b, 0x71, 0x81, 0x88, 0x92, 0xaa, 0x74, 0x49, 0x7a, - 0x90, 0x93, 0x7a, 0x61, 0x8c, 0x66, 0x71, 0xa0, 0xab, 0x7d, 0x86, 0x6c, - 0x9f, 0x77, 0x67, 0x6a, 0x89, 0x89, 0x88, 0x70, 0xad, 0x88, 0x69, 0x84, - 0x70, 0x8f, 0x79, 0x7c, 0x66, 0xa6, 0x71, 0x8d, 0x77, 0x99, 0x69, 0x76, - 0x79, 0x7d, 0x9c, 0x6f, 0x64, 0x8b, 0x70, 0x82, 0x69, 0xa4, 0x65, 0x6e, - 0x7f, 0x9e, 0x7e, 0x84, 0x8c, 0x9c, 0x6c, 0x5b, 0x6e, 0xa7, 0x6d, 0x7a, - 0x92, 0x78, 0x9a, 0x6f, 0x81, 0x91, 0x71, 0x7d, 0x6b, 0x99, 0x6b, 0x92, - 0x5e, 0x7e, 0x64, 0x95, 0x78, 0x90, 0x6f, 0x68, 0x8a, 0x85, 0x6f, 0x88, - 0x64, 0x66, 0x7f, 0x78, 0x7c, 0x95, 0x66, 0x6c, 0x76, 0x6a, 0x9b, 0x8f, - 0x9d, 0x78, 0x86, 0x95, 0x73, 0x66, 0x6d, 0x71, 0x8b, 0x7f, 0x6f, 0x70, - 0x64, 0x94, 0xa0, 0x83, 0x6b, 0x6d, 0x85, 0x89, 0x68, 0x92, 0x8e, 0x51, - 0x81, 0x85, 0x86, 0x6e, 0x83, 0x85, 0x8a, 0x5e, 0x68, 0xbf, 0xc4, 0xa5, - 0x8b, 0x67, 0x86, 0x59, 0x85, 0x9e, 0x96, 0x67, 0x82, 0x7c, 0x6c, 0x80, - 0x84, 0xae, 0x9d, 0x80, 0xc2, 0x58, 0x5d, 0x95, 0x85, 0x8b, 0x7f, 0x5d, - 0xc7, 0x75, 0x75, 0x87, 0xa2, 0x8c, 0x62, 0x71, 0x9c, 0x61, 0x7f, 0x9c, - 0xca, 0x8d, 0x89, 0x6e, 0x7c, 0x71, 0x81, 0x99, 0x95, 0xa4, 0x76, 0x6f, - 0x64, 0x7b, 0x6c, 0x72, 0x8b, 0x83, 0x70, 0x70, 0x8b, 0xa4, 0x69, 0x76, - 0x6e, 0x8d, 0x7a, 0x80, 0x8f, 0x9e, 0x73, 0x4b, 0x75, 0x78, 0x77, 0x7b, - 0x8e, 0x92, 0x88, 0x49, 0x54, 0x9f, 0x7a, 0x7f, 0x68, 0x9f, 0x7f, 0x57, - 0x6b, 0xad, 0x85, 0x6f, 0x81, 0xa1, 0x96, 0x6f, 0x73, 0x8d, 0x5e, 0x65, - 0x7a, 0x8c, 0x7c, 0x6a, 0x7e, 0x7a, 0x6a, 0x97, 0x59, 0x86, 0x62, 0x77, - 0x70, 0x7a, 0x68, 0x62, 0x68, 0x86, 0x7e, 0x76, 0x9a, 0x7f, 0x6c, 0x7e, - 0x8a, 0x76, 0x65, 0x8f, 0x7d, 0x65, 0x76, 0xa4, 0x95, 0x62, 0x78, 0x97, - 0x7a, 0x6e, 0x7a, 0x7a, 0x7e, 0x91, 0x8c, 0x8a, 0x91, 0x82, 0x89, 0x6d, - 0x87, 0x90, 0x69, 0x71, 0x96, 0xa6, 0x7c, 0x7c, 0xa8, 0xa8, 0x62, 0x77, - 0x76, 0x99, 0xdd, 0x76, 0x8a, 0x5c, 0x86, 0x6a, 0x69, 0x9c, 0xa5, 0x7d, - 0x78, 0x6a, 0x88, 0x77, 0x77, 0xae, 0x8a, 0x99, 0xcb, 0x85, 0x59, 0x84, - 0x7b, 0x97, 0x8a, 0x82, 0xc5, 0x65, 0x8c, 0x93, 0xc3, 0x8c, 0x87, 0x64, - 0x91, 0x41, 0x70, 0xa8, 0xd1, 0x8b, 0x82, 0x71, 0x9c, 0x71, 0x4e, 0x86, - 0x98, 0x86, 0x7f, 0x7e, 0x69, 0x99, 0x79, 0x78, 0x77, 0xb3, 0x6b, 0x80, - 0x84, 0x8b, 0x56, 0x73, 0x84, 0x95, 0x82, 0x94, 0x5b, 0x92, 0x83, 0x46, - 0x66, 0x89, 0x6d, 0x61, 0x99, 0xa6, 0x99, 0x3f, 0x6c, 0xab, 0x5d, 0x5f, - 0x6c, 0x8e, 0x6b, 0x4a, 0x72, 0xb6, 0x6c, 0x75, 0x78, 0xa6, 0x6f, 0x5b, - 0x56, 0x8b, 0x57, 0x74, 0x8f, 0xab, 0x53, 0x56, 0x5d, 0x63, 0x63, 0x8b, - 0x65, 0x78, 0x71, 0x67, 0x7a, 0x62, 0x8d, 0x78, 0x99, 0x76, 0x94, 0x7a, - 0xa3, 0x70, 0x55, 0x87, 0x7e, 0x7c, 0x57, 0x57, 0x6e, 0x79, 0x94, 0x8f, - 0x86, 0x80, 0x90, 0x7d, 0x7d, 0x7f, 0x7f, 0x68, 0x41, 0x86, 0x8c, 0x6f, - 0x8a, 0x7f, 0x87, 0x8a, 0x7e, 0x7f, 0x5d, 0x71, 0x91, 0x81, 0x93, 0x71, - 0x91, 0xc6, 0x70, 0x4a, 0x74, 0xa8, 0xf3, 0x72, 0xa7, 0x80, 0x7e, 0x41, - 0x84, 0xa3, 0xb6, 0x94, 0xba, 0x84, 0x70, 0x74, 0x71, 0xac, 0x9f, 0x9d, - 0xe4, 0x67, 0x6a, 0x87, 0x92, 0x8e, 0x92, 0x82, 0xdb, 0x5e, 0x9b, 0x90, - 0xd5, 0x87, 0x8d, 0x7c, 0x9c, 0x3c, 0x6c, 0xab, 0xc2, 0x86, 0x83, 0x79, - 0x6c, 0x61, 0x51, 0xa9, 0x99, 0x79, 0x72, 0x80, 0x6f, 0x85, 0x57, 0x6c, - 0x81, 0x86, 0x6e, 0x88, 0x87, 0x8d, 0x8e, 0x81, 0x67, 0x88, 0x62, 0x99, - 0x87, 0xab, 0x8f, 0x57, 0x60, 0x77, 0x64, 0x81, 0x96, 0xa3, 0x81, 0x3d, - 0x4e, 0xb9, 0x57, 0x6e, 0x99, 0xad, 0x6a, 0x3e, 0x74, 0x96, 0x7e, 0x79, - 0x65, 0xa4, 0x7c, 0x6a, 0x53, 0x87, 0x56, 0x6f, 0x5e, 0x97, 0x85, 0x42, - 0x56, 0x6b, 0x67, 0x78, 0x7d, 0xa6, 0x7c, 0x7c, 0x7d, 0x78, 0x7b, 0x84, - 0x99, 0x7b, 0x89, 0x71, 0x76, 0x8b, 0x76, 0x73, 0x7d, 0x83, 0x56, 0x4f, - 0x86, 0x72, 0x83, 0x88, 0x6a, 0x93, 0x69, 0x90, 0x6c, 0x73, 0x6f, 0x63, - 0x55, 0x88, 0x6b, 0x88, 0x7c, 0x86, 0x87, 0x7b, 0x6c, 0x7e, 0x60, 0x57, - 0xa8, 0x81, 0xa3, 0x72, 0xba, 0xbf, 0x66, 0x65, 0x70, 0xb9, 0xe4, 0x78, - 0x99, 0x67, 0x8c, 0x72, 0x88, 0x96, 0xb5, 0x72, 0x8a, 0x66, 0x81, 0x39, - 0x85, 0x93, 0xa0, 0x9c, 0xdf, 0x74, 0x8a, 0x6d, 0x93, 0xa1, 0x8c, 0x7a, - 0xb5, 0x4b, 0x89, 0xae, 0xba, 0x9c, 0x96, 0x9a, 0xb4, 0x33, 0x5a, 0xb1, - 0xcd, 0x88, 0x84, 0x63, 0x8c, 0x5e, 0x71, 0x6d, 0xa7, 0x8a, 0x62, 0x85, - 0x77, 0x75, 0x62, 0x79, 0x96, 0x73, 0x4f, 0x7d, 0x93, 0x8a, 0x88, 0x7e, - 0x59, 0x6c, 0x7f, 0x87, 0x6f, 0x91, 0x88, 0x59, 0x6d, 0x83, 0x70, 0x7c, - 0x7f, 0x8d, 0x7f, 0x26, 0x41, 0xcf, 0x6b, 0x6e, 0x75, 0xa3, 0x90, 0x5e, - 0x3a, 0x94, 0x61, 0x9a, 0x6f, 0x9f, 0x69, 0x7d, 0x55, 0x8c, 0x60, 0x7c, - 0x93, 0x85, 0x85, 0x4b, 0x54, 0x71, 0x60, 0x8a, 0x6d, 0x8c, 0x9c, 0x7e, - 0x5b, 0x79, 0x74, 0x7b, 0x7b, 0x9d, 0x5b, 0x65, 0x81, 0x82, 0x66, 0x89, - 0x82, 0x72, 0x77, 0x78, 0x75, 0x76, 0x6b, 0x74, 0x89, 0x73, 0x6c, 0x6b, - 0x77, 0x7e, 0x67, 0x84, 0x41, 0x90, 0x58, 0x87, 0x98, 0x60, 0x96, 0x81, - 0x6b, 0x74, 0x7d, 0x56, 0x72, 0x71, 0x9a, 0x7d, 0xc5, 0xd0, 0x88, 0x6e, - 0x4d, 0xbe, 0xef, 0x8a, 0xa7, 0x92, 0x82, 0x67, 0x7f, 0x91, 0xc5, 0x7d, - 0xad, 0x77, 0x6b, 0x4e, 0x8e, 0x99, 0x9b, 0x8e, 0xc7, 0x7f, 0x8a, 0x8e, - 0x8f, 0x87, 0x9c, 0x75, 0xb0, 0x53, 0x75, 0x97, 0xc7, 0x98, 0xa4, 0xa4, - 0x80, 0x41, 0x79, 0xc3, 0xdb, 0x86, 0x9d, 0x75, 0x7f, 0x67, 0x7a, 0x96, - 0xc3, 0x83, 0x54, 0x8e, 0x6f, 0xa8, 0x7c, 0x65, 0x78, 0x7e, 0x59, 0xa3, - 0x8a, 0x97, 0x8b, 0x82, 0x5e, 0x66, 0x82, 0x9b, 0x9e, 0x9f, 0x70, 0x49, - 0x55, 0x88, 0x8a, 0x7e, 0x90, 0xa7, 0x6b, 0x3b, 0x28, 0xc0, 0x63, 0x7e, - 0x60, 0x90, 0x7c, 0x3f, 0x54, 0x9c, 0x7d, 0x8a, 0x6a, 0xa9, 0x6f, 0x61, - 0x76, 0x86, 0x64, 0x88, 0x72, 0xa5, 0x6b, 0x4d, 0x56, 0x6c, 0x52, 0xa1, - 0x84, 0x69, 0x69, 0x5b, 0x71, 0x84, 0x76, 0x9b, 0x92, 0x70, 0x86, 0x8b, - 0x71, 0x68, 0x56, 0x92, 0x76, 0x8f, 0x8f, 0x72, 0x5a, 0x77, 0x6f, 0x92, - 0x72, 0x72, 0x5e, 0x7a, 0x70, 0x73, 0x60, 0x7d, 0x5a, 0x93, 0x7f, 0x6b, - 0x89, 0x6b, 0xa1, 0x85, 0x5c, 0x8d, 0x76, 0x7c, 0x6f, 0x73, 0x96, 0x6d, - 0xbb, 0xad, 0x53, 0x53, 0x5f, 0x9a, 0xe2, 0x8d, 0xa7, 0x6d, 0x8a, 0x5b, - 0x85, 0x9c, 0xb4, 0x7b, 0xb3, 0x52, 0x75, 0x7f, 0x7a, 0x8c, 0x91, 0x7e, - 0xca, 0x5f, 0x64, 0x71, 0x85, 0x9a, 0x91, 0x72, 0xbd, 0x6e, 0x9b, 0x81, - 0x8f, 0xa8, 0xac, 0x7d, 0xb4, 0x5f, 0x45, 0xc5, 0xc8, 0x7a, 0x93, 0x8e, - 0x7b, 0x41, 0x69, 0x94, 0x8b, 0x76, 0x59, 0x81, 0x73, 0x92, 0x8e, 0x63, - 0x8e, 0x74, 0x33, 0xa5, 0x9c, 0xa2, 0x88, 0x48, 0x5d, 0x8c, 0x7d, 0xa6, - 0x68, 0x9a, 0x6f, 0x58, 0x6c, 0x8f, 0x77, 0x65, 0x97, 0x9d, 0x7a, 0x37, - 0x59, 0xab, 0x6e, 0x8f, 0x7a, 0xae, 0x65, 0x3e, 0x46, 0xa9, 0x82, 0x82, - 0x9c, 0x9d, 0x62, 0x79, 0x66, 0x7f, 0x5e, 0x88, 0x9e, 0x8f, 0x84, 0x71, - 0x5d, 0x6d, 0x70, 0xa0, 0x69, 0x92, 0x7f, 0x70, 0x66, 0x6f, 0x75, 0x8c, - 0x96, 0x7a, 0x85, 0x6a, 0x5a, 0x7c, 0x72, 0x8a, 0x8d, 0x7b, 0x8b, 0x5c, - 0x76, 0x69, 0x70, 0x7f, 0x74, 0xa1, 0x71, 0x91, 0x5a, 0x8c, 0x6e, 0x83, - 0x52, 0x78, 0x71, 0x6d, 0xa9, 0x63, 0x9d, 0x81, 0x52, 0x9e, 0x5d, 0x60, - 0x76, 0x93, 0x97, 0x67, 0xce, 0xc1, 0x75, 0x5e, 0x5f, 0x8c, 0xea, 0x76, - 0xad, 0x7a, 0x7d, 0x62, 0x85, 0x92, 0xd0, 0x6a, 0xbc, 0x53, 0x55, 0x5c, - 0x6d, 0x89, 0x9e, 0x71, 0xd2, 0x8b, 0x64, 0x61, 0x85, 0x9a, 0x77, 0x75, - 0xb9, 0x67, 0x8a, 0xac, 0x90, 0x8a, 0xb4, 0x91, 0xbb, 0x58, 0x94, 0xaf, - 0xb2, 0x76, 0xa2, 0x71, 0x95, 0x5e, 0x73, 0xa5, 0x92, 0x8c, 0x52, 0x96, - 0x53, 0x95, 0x84, 0x91, 0x93, 0x7a, 0x40, 0x88, 0xab, 0xa5, 0x63, 0x70, - 0x66, 0x88, 0x7e, 0x92, 0x89, 0x84, 0x78, 0x57, 0x3d, 0x8d, 0x84, 0x77, - 0x9b, 0x87, 0x5e, 0x4e, 0x42, 0xa0, 0x76, 0x8a, 0x77, 0x90, 0x83, 0x4c, - 0x42, 0x9b, 0x75, 0x7a, 0x88, 0x94, 0x98, 0x69, 0x4c, 0xa2, 0x6b, 0x7b, - 0x6e, 0x9b, 0x5d, 0x5f, 0x53, 0x6a, 0x63, 0x95, 0x69, 0x8a, 0x61, 0x75, - 0x6c, 0x7a, 0x58, 0x89, 0x84, 0x8f, 0x6b, 0x5a, 0x71, 0x6f, 0x59, 0x89, - 0x7d, 0x87, 0x5f, 0x77, 0x4b, 0x61, 0x77, 0x92, 0x67, 0x8e, 0x5c, 0x6f, - 0x5b, 0x77, 0x76, 0x6b, 0x44, 0x9d, 0x9f, 0x7f, 0x8b, 0x94, 0x9e, 0x7c, - 0x62, 0x94, 0x60, 0x55, 0x77, 0x8f, 0xa6, 0x62, 0xb5, 0xb2, 0x3c, 0x61, - 0x5c, 0x99, 0xeb, 0x5b, 0x90, 0x6c, 0x7f, 0x5f, 0x75, 0xa6, 0xcf, 0x77, - 0x98, 0x5d, 0x75, 0x69, 0x7f, 0x8a, 0xa7, 0x73, 0xc8, 0x74, 0x70, 0x82, - 0x76, 0x8f, 0xa2, 0x7a, 0xa4, 0x7a, 0x66, 0x81, 0x9b, 0x8f, 0x9e, 0x8b, - 0xa1, 0x51, 0x7b, 0xba, 0xc8, 0x90, 0xab, 0x92, 0x72, 0x57, 0x5b, 0xa3, - 0xb0, 0x7f, 0x4c, 0x7d, 0x5f, 0x8e, 0x6c, 0x7d, 0x71, 0x7e, 0x4e, 0x87, - 0xb7, 0x97, 0x7a, 0x4c, 0x5f, 0x72, 0x78, 0x84, 0x82, 0x7e, 0x63, 0x65, - 0x68, 0x78, 0x73, 0x85, 0x90, 0x99, 0x80, 0x57, 0x42, 0x8b, 0x8a, 0x77, - 0x71, 0x97, 0x6d, 0x44, 0x41, 0x8f, 0x78, 0x7d, 0x95, 0x81, 0x95, 0x5f, - 0x64, 0x87, 0x66, 0x80, 0x89, 0x9a, 0x61, 0x4d, 0x68, 0x7b, 0x72, 0x73, - 0x85, 0x92, 0x77, 0x7d, 0x73, 0x77, 0x54, 0x7a, 0x77, 0x7d, 0x7d, 0x7a, - 0x6e, 0x8e, 0x4f, 0x7d, 0x80, 0x9a, 0x79, 0x8b, 0x7b, 0x68, 0x6e, 0x86, - 0x7f, 0x93, 0x7a, 0x76, 0x72, 0x85, 0x6a, 0x7b, 0x57, 0x84, 0x96, 0x9a, - 0x8f, 0x91, 0x9b, 0x72, 0x73, 0x91, 0x53, 0x66, 0x76, 0x80, 0xae, 0x63, - 0xbf, 0x99, 0x5e, 0x77, 0x73, 0x9c, 0xd8, 0x74, 0xa7, 0x79, 0x52, 0x64, - 0x82, 0x95, 0xc7, 0x4f, 0xa8, 0x4f, 0x6d, 0x42, 0x7c, 0x89, 0xab, 0x83, - 0xc0, 0x82, 0x6a, 0x5f, 0x83, 0x92, 0xa8, 0x76, 0xc1, 0x77, 0x6e, 0x7b, - 0xa3, 0x9b, 0xaf, 0x87, 0xab, 0x60, 0x8d, 0xc2, 0xd2, 0x83, 0xb2, 0x78, - 0x8d, 0x39, 0x57, 0x9c, 0x90, 0x8e, 0x6e, 0x6a, 0x74, 0x79, 0x81, 0x6d, - 0x6f, 0x8e, 0x77, 0x92, 0x93, 0x7d, 0x5f, 0x68, 0x6a, 0x6c, 0x80, 0x8f, - 0x99, 0x84, 0x4f, 0x64, 0x5c, 0x93, 0x7c, 0x91, 0x98, 0x82, 0x62, 0x3f, - 0x41, 0x9f, 0x5d, 0x89, 0x98, 0x89, 0x73, 0x50, 0x32, 0xa8, 0xa0, 0x7a, - 0xa0, 0x95, 0x78, 0x69, 0x74, 0x7c, 0x89, 0x7b, 0x80, 0x65, 0x56, 0x6b, - 0x69, 0x78, 0x62, 0x87, 0xaf, 0x94, 0x7a, 0x64, 0x53, 0x86, 0x45, 0x99, - 0x88, 0x79, 0x4d, 0x74, 0x59, 0x91, 0x5f, 0x7b, 0x88, 0x90, 0x80, 0x86, - 0x7d, 0x7b, 0x64, 0xa3, 0x7f, 0x74, 0x89, 0x80, 0x7d, 0x7c, 0x7a, 0x87, - 0x5f, 0x8a, 0x5a, 0x72, 0x79, 0x74, 0x8c, 0x7c, 0x86, 0x91, 0x6e, 0x5d, - 0x61, 0x8e, 0xa2, 0x68, 0xd4, 0x92, 0x67, 0x66, 0x62, 0xa1, 0xf3, 0x63, - 0x91, 0x81, 0x74, 0x5f, 0x88, 0x98, 0xbb, 0x5a, 0x9b, 0x54, 0x6a, 0x5c, - 0x75, 0x88, 0xad, 0x7c, 0xb4, 0x7c, 0x69, 0x74, 0x84, 0x76, 0x9d, 0x9a, - 0xb0, 0x91, 0x5d, 0xa3, 0xa4, 0x7f, 0xbb, 0x80, 0xa4, 0x5d, 0x83, 0xaf, - 0xb7, 0x66, 0xb0, 0x7f, 0x89, 0x4b, 0x72, 0x9e, 0x99, 0x7c, 0x66, 0x71, - 0x6a, 0x6f, 0x6d, 0x67, 0x8d, 0x6d, 0x46, 0xa5, 0x9b, 0x84, 0x7a, 0x61, - 0x64, 0x5c, 0x88, 0x89, 0x95, 0x8c, 0x70, 0x4b, 0x6c, 0x85, 0x83, 0x8b, - 0x98, 0x87, 0x6a, 0x44, 0x4d, 0x9d, 0x78, 0x71, 0x78, 0x7e, 0x91, 0x5b, - 0x3f, 0x9f, 0x80, 0x62, 0xa7, 0x95, 0x5d, 0x74, 0x65, 0x9c, 0x6d, 0x7a, - 0x98, 0x79, 0x80, 0x61, 0x49, 0x82, 0x65, 0x92, 0x80, 0x96, 0x7c, 0x72, - 0x4f, 0x76, 0x5e, 0x8d, 0x97, 0xa5, 0x72, 0x57, 0x79, 0x87, 0x67, 0x87, - 0x80, 0x84, 0x7c, 0x6f, 0x66, 0x6b, 0x70, 0x9b, 0x64, 0x90, 0x59, 0x96, - 0x7a, 0x6f, 0x75, 0x89, 0x4e, 0x8a, 0x62, 0x6e, 0x9c, 0x8c, 0x9a, 0x78, - 0x8e, 0x91, 0x3d, 0x50, 0x72, 0x92, 0x9f, 0x63, 0xda, 0x92, 0x72, 0x60, - 0x59, 0xa6, 0xd0, 0x56, 0xc1, 0x6b, 0x5e, 0x76, 0x6e, 0x81, 0xbb, 0x4b, - 0xbb, 0x59, 0x68, 0x4f, 0x77, 0x87, 0xa1, 0x73, 0xbf, 0x65, 0x56, 0x67, - 0x77, 0x84, 0x8a, 0x7e, 0xb8, 0x85, 0x66, 0xa6, 0x99, 0xa0, 0xa5, 0x73, - 0x8d, 0x4a, 0x7d, 0xab, 0xb0, 0x6a, 0x94, 0x84, 0x87, 0x4c, 0x74, 0xa3, - 0xb3, 0xa9, 0x62, 0x7a, 0x71, 0x7f, 0x53, 0x79, 0x7a, 0x7c, 0x5e, 0x8f, - 0xa0, 0x90, 0x5c, 0x76, 0x6c, 0x92, 0x70, 0x9c, 0xb3, 0x8b, 0x7e, 0x57, - 0x5b, 0x9d, 0x96, 0x85, 0x70, 0x93, 0x8b, 0x67, 0x4c, 0x9c, 0x6a, 0x83, - 0x84, 0x90, 0x8e, 0x60, 0x56, 0xb3, 0x87, 0x7d, 0x86, 0x88, 0x79, 0x5b, - 0x58, 0x94, 0x92, 0x8e, 0x90, 0x76, 0x58, 0x51, 0x52, 0x63, 0x57, 0x88, - 0x9b, 0x7a, 0x85, 0x6c, 0x8b, 0x87, 0x5f, 0x8b, 0x90, 0x92, 0x81, 0x64, - 0x52, 0x8b, 0x77, 0x94, 0x96, 0x98, 0x69, 0x5b, 0x79, 0x87, 0x61, 0x96, - 0x7b, 0x9a, 0x61, 0x74, 0x7e, 0x8b, 0x82, 0x92, 0x4f, 0x87, 0x7f, 0x80, - 0x74, 0x97, 0x98, 0x7a, 0x79, 0x97, 0x65, 0x67, 0x66, 0xb1, 0xb1, 0x49, - 0xd6, 0x97, 0x58, 0x47, 0x62, 0x94, 0xd5, 0x82, 0xa0, 0x60, 0x3f, 0x67, - 0x6c, 0x9d, 0xb6, 0x58, 0xb1, 0x6e, 0x58, 0x4e, 0x7c, 0x83, 0x8b, 0x83, - 0xd5, 0x62, 0x8d, 0x84, 0x84, 0x8c, 0xa9, 0x6e, 0xac, 0x7f, 0x6d, 0x88, - 0xab, 0x8b, 0xb1, 0x77, 0x9b, 0x46, 0x76, 0xa7, 0xb8, 0x7b, 0xc5, 0x6e, - 0x73, 0x62, 0x68, 0x95, 0xab, 0x7c, 0x6f, 0x74, 0x56, 0x71, 0x61, 0x83, - 0x8a, 0x73, 0x54, 0x94, 0x86, 0x91, 0x60, 0x69, 0x65, 0x6b, 0x76, 0x85, - 0xae, 0x87, 0x8f, 0x55, 0x41, 0x98, 0x68, 0x87, 0x5e, 0x7a, 0x80, 0x38, - 0x50, 0xaf, 0x93, 0x79, 0x57, 0x96, 0x7b, 0x53, 0x4e, 0xc0, 0xa0, 0x85, - 0x87, 0x95, 0x86, 0x70, 0x4c, 0x9f, 0x77, 0x7d, 0x8b, 0x7a, 0x7b, 0x6d, - 0x57, 0x74, 0x81, 0x7d, 0xa2, 0x79, 0x64, 0x6c, 0x55, 0x70, 0x3c, 0x88, - 0x8a, 0x7a, 0x58, 0x72, 0x71, 0x7d, 0x6a, 0x8d, 0x78, 0x7e, 0x95, 0x8b, - 0x84, 0x7e, 0x73, 0x7c, 0x7e, 0x67, 0x89, 0x8b, 0x6d, 0x68, 0x66, 0x73, - 0x5a, 0x93, 0x82, 0x85, 0x97, 0x6b, 0x9a, 0x72, 0x51, 0xa2, 0x4f, 0x67, - 0x67, 0x7e, 0xbb, 0x37, 0xe3, 0x9c, 0x57, 0x5b, 0x6f, 0xa0, 0xdc, 0x5c, - 0xa6, 0x7c, 0x71, 0x77, 0x72, 0x88, 0xd0, 0x4d, 0x93, 0x58, 0x74, 0x6d, - 0x8f, 0x77, 0xa3, 0x76, 0xb7, 0x76, 0x6d, 0x6d, 0x6f, 0x7b, 0xaa, 0x6d, - 0xaa, 0x6a, 0x72, 0x98, 0x8d, 0x98, 0xb0, 0x52, 0x76, 0x5d, 0x61, 0xb7, - 0xac, 0x90, 0xa5, 0x75, 0x7e, 0x3d, 0x5b, 0x9a, 0xbf, 0x81, 0x83, 0x7b, - 0x5c, 0x77, 0x74, 0x82, 0x8d, 0x7e, 0x4f, 0x9f, 0x8f, 0x97, 0x7c, 0x75, - 0x5b, 0x73, 0x97, 0x73, 0x85, 0x7f, 0x70, 0x5a, 0x53, 0x81, 0x81, 0x89, - 0x73, 0x8d, 0x8a, 0x5c, 0x5f, 0x84, 0x86, 0x6f, 0x76, 0x78, 0x82, 0x6d, - 0x4f, 0xbb, 0x91, 0x61, 0x7e, 0x97, 0x6c, 0x67, 0x62, 0x83, 0x61, 0x7d, - 0x89, 0x76, 0x7b, 0x67, 0x56, 0x74, 0x49, 0x7b, 0x6b, 0x8b, 0x89, 0x74, - 0x5b, 0x7f, 0x78, 0x7b, 0x80, 0x7e, 0x63, 0x71, 0x5e, 0x91, 0x81, 0x92, - 0x7b, 0x90, 0x9c, 0x7a, 0x73, 0x85, 0x79, 0x9b, 0x66, 0x93, 0x60, 0x87, - 0x79, 0x69, 0x73, 0x8b, 0x53, 0x8c, 0x8d, 0x68, 0x93, 0xa0, 0x91, 0x65, - 0x57, 0x8d, 0x71, 0x65, 0x6c, 0x7e, 0xb3, 0x4f, 0xc7, 0xaa, 0x5a, 0x77, - 0x6e, 0x85, 0xe4, 0x6c, 0xa3, 0x89, 0x69, 0x54, 0x6d, 0x99, 0xb9, 0x77, - 0xa0, 0x80, 0x85, 0x71, 0x70, 0x78, 0x99, 0x66, 0xaf, 0x8a, 0x59, 0x64, - 0x54, 0x62, 0xbf, 0x5c, 0xbd, 0x77, 0x7f, 0xab, 0x95, 0x85, 0xaa, 0x6e, - 0xaa, 0x5a, 0x7b, 0x9f, 0xc3, 0x65, 0x93, 0x64, 0x7c, 0x2d, 0x4e, 0x8f, - 0xb2, 0x5f, 0x4e, 0x61, 0x64, 0x73, 0x56, 0x75, 0x79, 0x90, 0x5c, 0x81, - 0x8a, 0x8c, 0x70, 0x64, 0x74, 0x86, 0x86, 0x82, 0xab, 0x7e, 0x62, 0x4f, - 0x51, 0x89, 0x7b, 0x88, 0x73, 0x97, 0x77, 0x75, 0x5c, 0x9e, 0x97, 0x70, - 0x5a, 0x98, 0x7a, 0x54, 0x47, 0x99, 0xab, 0x5d, 0x91, 0xa0, 0x64, 0x51, - 0x57, 0x88, 0x88, 0x85, 0x81, 0x83, 0xa1, 0x89, 0x6a, 0x88, 0x69, 0x81, - 0x92, 0x63, 0x6a, 0x71, 0x72, 0x6a, 0x75, 0x8e, 0x90, 0x9d, 0x69, 0x60, - 0x73, 0x95, 0x79, 0x7b, 0x79, 0x7f, 0x77, 0x6e, 0x69, 0x63, 0x60, 0xa0, - 0x84, 0x91, 0x80, 0x96, 0x92, 0x70, 0x69, 0x7c, 0x3f, 0x90, 0x5c, 0x79, - 0x82, 0x63, 0x8d, 0x63, 0x56, 0x8a, 0x8e, 0x7a, 0x5c, 0x8d, 0xb8, 0x4e, - 0xb6, 0x84, 0x57, 0x79, 0x59, 0x79, 0xe8, 0x7e, 0xa8, 0x71, 0x61, 0x62, - 0x89, 0x71, 0xb7, 0x83, 0x7b, 0x53, 0x86, 0x88, 0x74, 0x71, 0xb1, 0x61, - 0xae, 0x7e, 0x8f, 0x69, 0x6b, 0x69, 0xb2, 0x6d, 0xb1, 0x7f, 0x5c, 0x9f, - 0xaa, 0x8c, 0xbd, 0x74, 0xaa, 0x5b, 0x7f, 0xa5, 0xb0, 0x6e, 0xc1, 0x5c, - 0x94, 0x34, 0x5b, 0xa6, 0xbc, 0x49, 0x75, 0x5b, 0x6e, 0x74, 0x7a, 0x92, - 0x92, 0x79, 0x78, 0x8a, 0x9e, 0x97, 0x7c, 0x5f, 0x76, 0x86, 0x59, 0x81, - 0x83, 0x7a, 0x65, 0x5b, 0x42, 0x95, 0x84, 0x99, 0x81, 0x8d, 0x6a, 0x5e, - 0x59, 0xb7, 0x96, 0x8a, 0x77, 0x86, 0x7a, 0x67, 0x3b, 0xa8, 0xae, 0x7a, - 0xa0, 0x97, 0x6c, 0x73, 0x5b, 0x9b, 0x77, 0x84, 0x7a, 0x77, 0x75, 0x6f, - 0x7d, 0x7a, 0x71, 0x86, 0x6c, 0x6f, 0x7d, 0x71, 0x68, 0x60, 0x64, 0x86, - 0x90, 0x75, 0x6a, 0x61, 0x60, 0x87, 0x68, 0x99, 0x87, 0x7e, 0x92, 0x87, - 0x87, 0x5f, 0x60, 0x91, 0x68, 0x8c, 0x7b, 0x67, 0x79, 0x5d, 0x67, 0x77, - 0x47, 0x72, 0x76, 0x88, 0x82, 0xa2, 0x7a, 0x5d, 0x64, 0x87, 0x75, 0x78, - 0x5e, 0x6f, 0xa4, 0x52, 0xc2, 0x9d, 0x81, 0x89, 0x55, 0x86, 0xc9, 0x6f, - 0x95, 0x71, 0x9d, 0x87, 0x95, 0x74, 0xac, 0x7f, 0x95, 0x6c, 0x68, 0x66, - 0x8a, 0x5f, 0x96, 0x69, 0x95, 0x79, 0x7f, 0x71, 0x86, 0x7e, 0x98, 0x71, - 0xac, 0x8f, 0x75, 0xa5, 0xac, 0x7a, 0xca, 0x63, 0xa0, 0x63, 0x69, 0xbf, - 0xae, 0x62, 0xc9, 0x46, 0x74, 0x2c, 0x66, 0x96, 0xb7, 0x70, 0x7c, 0x6b, - 0x7b, 0x90, 0x72, 0x74, 0x8d, 0x5f, 0x63, 0x93, 0x97, 0x78, 0x79, 0x64, - 0x67, 0x84, 0x64, 0x82, 0x90, 0x83, 0x91, 0x5f, 0x72, 0x93, 0x91, 0xae, - 0x6d, 0x99, 0x5b, 0x69, 0x54, 0x9f, 0x97, 0x80, 0x80, 0xa4, 0x91, 0x66, - 0x65, 0xa4, 0xa7, 0x7b, 0x97, 0x87, 0x72, 0x68, 0x6a, 0x96, 0x7b, 0x79, - 0x69, 0x83, 0x6f, 0x85, 0x6b, 0x92, 0x7f, 0x71, 0x84, 0x87, 0x6a, 0x7b, - 0x63, 0x72, 0x5f, 0x87, 0x98, 0x7b, 0x96, 0x71, 0x62, 0x90, 0x71, 0xa3, - 0x8c, 0x77, 0x90, 0x6f, 0x83, 0x76, 0x65, 0x87, 0x72, 0x8a, 0x64, 0x87, - 0x75, 0x75, 0x6d, 0x84, 0x54, 0x89, 0x88, 0xa0, 0x87, 0x73, 0x7f, 0x6f, - 0x5f, 0x90, 0x5e, 0x94, 0x5d, 0x61, 0xa6, 0x56, 0xb3, 0x91, 0x95, 0x75, - 0x4d, 0x74, 0xd9, 0x87, 0x92, 0x74, 0x7f, 0x79, 0x97, 0x6e, 0x90, 0x54, - 0x84, 0x5d, 0x5f, 0x75, 0x8b, 0x84, 0xa6, 0x75, 0xb4, 0x77, 0x78, 0x85, - 0x90, 0x76, 0xbd, 0x78, 0xd1, 0xa0, 0x5d, 0x96, 0xa9, 0x7c, 0xc1, 0x61, - 0xc2, 0x71, 0x8b, 0xa5, 0xa5, 0x5b, 0xc8, 0x50, 0x7b, 0x4b, 0x93, 0x99, - 0xae, 0x72, 0x67, 0x54, 0x81, 0x89, 0x96, 0x81, 0x6e, 0x68, 0x55, 0x7f, - 0x93, 0x8c, 0x5e, 0x65, 0x6c, 0x84, 0x7f, 0x8f, 0x9e, 0x7b, 0x73, 0x7f, - 0x51, 0x63, 0x8a, 0x8b, 0x6b, 0x9b, 0x9d, 0x57, 0x68, 0x89, 0x98, 0x70, - 0x73, 0xa3, 0x7f, 0x69, 0x44, 0x89, 0xae, 0x68, 0x89, 0x80, 0x7e, 0x6d, - 0x70, 0x95, 0x85, 0x65, 0x91, 0x7f, 0x66, 0x74, 0x96, 0x72, 0x60, 0x7a, - 0x87, 0x85, 0x79, 0x54, 0x53, 0x6c, 0x88, 0x87, 0xa9, 0x90, 0x75, 0x8b, - 0x69, 0x98, 0x7d, 0x95, 0x85, 0x7a, 0x8b, 0x82, 0x87, 0x6f, 0x86, 0x7f, - 0x74, 0xab, 0x93, 0x6c, 0x8a, 0x78, 0x68, 0x81, 0x62, 0x88, 0x78, 0x91, - 0x8b, 0x55, 0xa7, 0x58, 0x64, 0x88, 0x71, 0x93, 0x7d, 0x69, 0xbc, 0x58, - 0xbe, 0x9a, 0x6f, 0x74, 0x6f, 0x7f, 0xeb, 0x9e, 0xb7, 0x60, 0x63, 0x98, - 0x82, 0x77, 0x94, 0x63, 0x80, 0x6f, 0x7d, 0x8f, 0x8b, 0x85, 0xa5, 0x62, - 0xad, 0x86, 0x5f, 0x76, 0x88, 0x74, 0xa5, 0x66, 0xa5, 0x94, 0x88, 0x9b, - 0x87, 0x9e, 0xa8, 0x5a, 0xc9, 0x81, 0x92, 0xcd, 0xb5, 0x67, 0xb9, 0x63, - 0x86, 0x65, 0x8d, 0xad, 0x98, 0x7c, 0x8a, 0x40, 0x67, 0x65, 0x60, 0x71, - 0x8e, 0x84, 0x73, 0x64, 0x98, 0x80, 0x73, 0x81, 0x48, 0x75, 0x71, 0x9e, - 0x73, 0x89, 0x89, 0x68, 0x73, 0xa6, 0x84, 0x8a, 0x7e, 0x9f, 0x78, 0x83, - 0x60, 0x77, 0xa1, 0x87, 0x76, 0xab, 0x74, 0x57, 0x6d, 0x99, 0xa5, 0x5e, - 0x9d, 0x91, 0x6d, 0x6a, 0x76, 0x9c, 0x7b, 0x66, 0x96, 0x84, 0x85, 0x6e, - 0x6c, 0x75, 0x86, 0x6a, 0x71, 0x67, 0x8a, 0x66, 0x66, 0x68, 0x73, 0x90, - 0x92, 0x68, 0x8f, 0x71, 0x82, 0x7e, 0x71, 0xad, 0x9f, 0x84, 0x9e, 0x7d, - 0x77, 0x6b, 0x67, 0x8f, 0x73, 0x9a, 0x91, 0x74, 0x8a, 0x74, 0x5a, 0x87, - 0x37, 0x80, 0x8c, 0x8f, 0x7f, 0x75, 0xa8, 0x49, 0x63, 0x9b, 0x67, 0x68, - 0x4f, 0x87, 0xbf, 0x59, 0x9c, 0xbe, 0x93, 0x7e, 0x6f, 0x8a, 0xea, 0x77, - 0x83, 0x7a, 0x75, 0x8e, 0x7d, 0x50, 0x95, 0x60, 0x74, 0x60, 0x6f, 0x97, - 0x72, 0x5c, 0xa3, 0x6d, 0xb9, 0x86, 0x7b, 0x89, 0x9a, 0x76, 0xc7, 0x56, - 0xba, 0x86, 0x8d, 0x93, 0xa9, 0x98, 0xbb, 0x6a, 0x97, 0x74, 0x68, 0x84, - 0xc3, 0x65, 0xb6, 0x68, 0x89, 0x58, 0x87, 0xa1, 0xac, 0x60, 0x65, 0x68, - 0x7d, 0x98, 0x67, 0x8f, 0x8e, 0x84, 0x50, 0x75, 0x83, 0x91, 0x8a, 0x90, - 0x66, 0x74, 0x96, 0x89, 0x81, 0x7a, 0x7a, 0x64, 0x7f, 0x73, 0x8f, 0x95, - 0x8c, 0x89, 0x96, 0x76, 0x7a, 0x6c, 0x89, 0x91, 0x6d, 0x84, 0x68, 0x8d, - 0x47, 0x94, 0x9a, 0x67, 0x8f, 0x89, 0x8e, 0x79, 0x73, 0xa8, 0x7f, 0x6c, - 0x80, 0x64, 0x75, 0x81, 0x96, 0x9c, 0x68, 0x65, 0x76, 0x68, 0x74, 0x72, - 0x68, 0x76, 0x62, 0x6d, 0x6e, 0x6a, 0x84, 0x65, 0x8a, 0x73, 0x76, 0x91, - 0x78, 0x7c, 0x7a, 0x88, 0x6a, 0x87, 0x60, 0x99, 0x88, 0x75, 0x7b, 0x71, - 0x81, 0x7b, 0x76, 0x7d, 0x58, 0x75, 0x65, 0xa3, 0x95, 0x7e, 0x96, 0x3e, - 0x4c, 0x97, 0x86, 0x7a, 0x62, 0x92, 0xd1, 0x72, 0x8e, 0xaa, 0x85, 0x8e, - 0x59, 0x5f, 0xec, 0x77, 0x96, 0x66, 0x91, 0x9a, 0x89, 0x6c, 0xa2, 0x69, - 0x7d, 0x6e, 0x76, 0x63, 0x82, 0x72, 0x9c, 0x72, 0xa3, 0x75, 0x85, 0x7b, - 0x6d, 0x96, 0xc2, 0x69, 0xa7, 0x6a, 0x6b, 0x83, 0xa2, 0x7d, 0xce, 0x5c, - 0x94, 0x61, 0x7d, 0xae, 0xc3, 0x6d, 0x9f, 0x3c, 0x52, 0x4d, 0x8e, 0x92, - 0xae, 0x6e, 0x70, 0x5a, 0x76, 0x84, 0x7f, 0x72, 0x92, 0x72, 0x76, 0x5e, - 0x73, 0x8e, 0x82, 0x6d, 0x72, 0x81, 0x79, 0x94, 0x81, 0x88, 0x8b, 0x81, - 0x72, 0x72, 0x69, 0x84, 0x59, 0x6e, 0x74, 0x7d, 0x66, 0x74, 0x8d, 0x7b, - 0x7d, 0x7e, 0x7a, 0x83, 0x4d, 0x7e, 0x6a, 0x5a, 0x87, 0x66, 0x84, 0xa5, - 0x50, 0x5d, 0x6a, 0x8e, 0x87, 0x74, 0x88, 0x7c, 0x7d, 0x6c, 0x93, 0x98, - 0x8c, 0x76, 0x7f, 0xa3, 0x6e, 0x5d, 0x7d, 0x9f, 0x7c, 0x7a, 0x98, 0x88, - 0x74, 0x73, 0x50, 0x8c, 0x78, 0x8b, 0x71, 0x77, 0x9d, 0x56, 0x71, 0x85, - 0x6b, 0x8a, 0x93, 0x82, 0x8c, 0x79, 0x68, 0x8b, 0x57, 0x7b, 0x7c, 0x8a, - 0x6c, 0x87, 0x98, 0x54, 0x63, 0x7e, 0x78, 0x6b, 0x63, 0x77, 0xc1, 0x52, - 0xcd, 0xab, 0x75, 0x8e, 0x64, 0x68, 0xce, 0x68, 0x88, 0x6d, 0x67, 0x6d, - 0x68, 0x76, 0xa7, 0x78, 0x83, 0x67, 0x65, 0x5b, 0x8f, 0x63, 0x90, 0x5b, - 0xa1, 0x6f, 0x6a, 0x88, 0x70, 0x5c, 0x78, 0x49, 0xbc, 0x85, 0x8d, 0x8e, - 0xa3, 0x90, 0x97, 0x84, 0xa2, 0x46, 0x7a, 0x8e, 0x9e, 0xb1, 0xaa, 0x53, - 0x7d, 0x6b, 0x72, 0x86, 0x8c, 0x67, 0x6b, 0x48, 0x6f, 0x9c, 0x51, 0x94, - 0x6d, 0x66, 0x8e, 0x90, 0x79, 0x81, 0x66, 0x9f, 0x82, 0x9f, 0x98, 0x97, - 0x7c, 0x86, 0x7f, 0x57, 0x57, 0x83, 0x97, 0x8f, 0x73, 0x6f, 0x75, 0x6c, - 0x56, 0x8f, 0x7f, 0x73, 0x71, 0x84, 0x7d, 0x5f, 0x69, 0x69, 0x8e, 0x67, - 0x8a, 0x7f, 0x8c, 0x5a, 0x7a, 0x67, 0x82, 0x5a, 0x7a, 0x68, 0x73, 0x58, - 0x84, 0x83, 0x8d, 0x6d, 0x83, 0x72, 0x80, 0x7a, 0x8e, 0x7a, 0x68, 0x88, - 0x65, 0x74, 0x78, 0x73, 0x83, 0x97, 0x7b, 0x84, 0x77, 0x6d, 0x95, 0x99, - 0x76, 0x69, 0x5f, 0x9b, 0x7c, 0x75, 0x91, 0x80, 0x7b, 0x73, 0x6f, 0x9f, - 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x08, 0x00, 0x07, 0x00, 0x0c, 0x00, - 0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, - 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x28, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, - 0x0b, 0x00, 0x00, 0x00, 0x4d, 0x61, 0x74, 0x4d, 0x75, 0x6c, 0x5f, 0x62, - 0x69, 0x61, 0x73, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0xaa, 0xcc, 0xe2, 0x37, 0x10, 0x00, 0x00, 0x00, 0xd6, 0x01, 0x00, 0x00, - 0xfd, 0xfd, 0xff, 0xff, 0x53, 0xfe, 0xff, 0xff, 0x74, 0x01, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xc0, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x09, - 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x24, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, - 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x14, 0x00, 0x1c, 0x00, - 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x07, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x18, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, - 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, - 0x28, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x14, 0x00, 0x18, 0x00, - 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x10, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x14, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, - 0x10, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, - 0x1c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, - 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, - 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, - 0x03, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, - 0x04, 0x00, 0x00, 0x00, 0xfa, 0xff, 0xff, 0xff, 0x00, 0x19, 0x06, 0x00, - 0x06, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x09, 0x06, 0x00, - 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04}; -const int g_model_len = 18288; + 0x01, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xf2, 0xdd, 0xbb, 0x3d, + 0x01, 0x00, 0x00, 0x00, 0x32, 0xa3, 0x25, 0x41, 0x01, 0x00, 0x00, 0x00, + 0xf6, 0xa0, 0x50, 0xc1, 0x05, 0x00, 0x00, 0x00, 0x61, 0x64, 0x64, 0x5f, + 0x31, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x0e, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x02, + 0x2c, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x0f, 0x00, 0x00, 0x00, 0x52, 0x65, 0x73, 0x68, 0x61, 0x70, 0x65, 0x5f, + 0x32, 0x2f, 0x73, 0x68, 0x61, 0x70, 0x65, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x4a, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x09, + 0x5c, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x1c, 0xff, 0xff, 0xff, 0x30, 0x00, 0x00, 0x00, + 0x24, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x50, 0x50, 0xd0, 0x3d, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x80, 0xcf, 0x41, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x52, 0x65, 0x73, 0x68, + 0x61, 0x70, 0x65, 0x5f, 0x32, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0xc2, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x09, + 0x58, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x94, 0xff, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x50, 0x50, 0xd0, 0x3d, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x80, 0xcf, 0x41, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, 0x52, 0x65, 0x73, 0x68, 0x61, 0x70, 0x65, 0x5f, + 0x31, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0xa8, 0x07, 0x00, 0x00, 0x2e, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x09, + 0x60, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, 0x04, 0x00, 0x08, 0x00, + 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0x01, 0x00, 0x00, 0x00, 0x3a, 0x6a, 0xac, 0x3d, 0x01, 0x00, 0x00, 0x00, + 0xd0, 0xbd, 0xab, 0x41, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x52, 0x65, 0x6c, 0x75, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0xaa, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x02, 0x44, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x2c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x9c, 0xff, 0xff, 0xff, + 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x96, 0x08, 0x29, 0x38, 0x0b, 0x00, 0x00, 0x00, + 0x4d, 0x61, 0x74, 0x4d, 0x75, 0x6c, 0x5f, 0x62, 0x69, 0x61, 0x73, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, + 0x18, 0x00, 0x08, 0x00, 0x07, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, + 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xa0, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x88, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x9a, 0xbb, 0x84, 0x38, 0x83, 0x84, 0x73, 0x37, 0x5b, 0xa3, 0xa0, 0x38, + 0x16, 0x41, 0x3a, 0x38, 0xc7, 0x9a, 0x70, 0x38, 0xed, 0x70, 0x4e, 0x38, + 0x54, 0x4f, 0xac, 0x38, 0xfd, 0x07, 0x8d, 0x38, 0x0b, 0x00, 0x00, 0x00, + 0x43, 0x6f, 0x6e, 0x76, 0x32, 0x44, 0x5f, 0x62, 0x69, 0x61, 0x73, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x4c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xe6, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x19, + 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x06, 0x00, 0x05, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x00, 0x16, 0x0a, 0x00, 0x0e, 0x00, 0x07, 0x00, + 0x00, 0x00, 0x08, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x0c, 0x00, 0x07, 0x00, + 0x00, 0x00, 0x08, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, + 0x03, 0x00, 0x00, 0x00}; +const int g_model_len = 18712; diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_features/no_feature_data_slice.cc b/tensorflow/lite/micro/examples/micro_speech/micro_features/no_feature_data_slice.cc index b523a8185d4..80f2b62546b 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_features/no_feature_data_slice.cc +++ b/tensorflow/lite/micro/examples/micro_speech/micro_features/no_feature_data_slice.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/lite/micro/examples/micro_speech/micro_features/no_feature_data_slice.h" -const uint8_t g_no_feature_data_slice[g_no_feature_data_slice_size] = { - 216, 195, 223, 211, 238, 223, 243, 215, 226, 204, 232, 211, 232, 213, - 240, 218, 235, 214, 238, 205, 207, 173, 149, 201, 215, 200, 230, 213, - 208, 195, 175, 151, 195, 175, 182, 163, 235, 217, 218, 190, +const int8_t g_no_feature_data_slice[g_no_feature_data_slice_size] = { + 89, 68, 96, 83, 111, 96, 115, 87, 99, 76, 105, 84, 105, 86, + 113, 91, 108, 87, 110, 78, 80, 46, 22, 74, 88, 72, 103, 86, + 80, 68, 48, 24, 68, 48, 55, 36, 108, 90, 90, 63, }; diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_features/no_feature_data_slice.h b/tensorflow/lite/micro/examples/micro_speech/micro_features/no_feature_data_slice.h index 234e7efc388..7c27379f6de 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_features/no_feature_data_slice.h +++ b/tensorflow/lite/micro/examples/micro_speech/micro_features/no_feature_data_slice.h @@ -24,6 +24,6 @@ limitations under the License. #include <cstdint> constexpr int g_no_feature_data_slice_size = 40; -extern const uint8_t g_no_feature_data_slice[]; +extern const int8_t g_no_feature_data_slice[]; #endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_NO_FEATURE_DATA_SLICE_H_ diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_features/no_micro_features_data.cc b/tensorflow/lite/micro/examples/micro_speech/micro_features/no_micro_features_data.cc index d7a923364a7..2fa4556a273 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_features/no_micro_features_data.cc +++ b/tensorflow/lite/micro/examples/micro_speech/micro_features/no_micro_features_data.cc @@ -15,151 +15,174 @@ limitations under the License. #include "tensorflow/lite/micro/examples/micro_speech/micro_features/no_micro_features_data.h" -/* File automatically created by - * tensorflow/examples/speech_commands/wav_to_features.py \ - * --sample_rate=16000 \ - * --clip_duration_ms=1000 \ - * --window_size_ms=30 \ - * --window_stride_ms=20 \ - * --feature_bin_count=40 \ - * --quantize=1 \ - * --preprocess="micro" \ - * --input_wav="speech_commands_test_set_v0.02/no/f9643d42_nohash_4.wav" \ - * --output_c_file="/tmp/no_micro_features_data.cc" \ - */ +// Golden test values for the expected spectrogram from a "no" sample file +// speech_commands_test_set_v0.02/no/f9643d42_nohash_4.wav. const int g_no_micro_f9643d42_nohash_4_width = 40; const int g_no_micro_f9643d42_nohash_4_height = 49; -const unsigned char g_no_micro_f9643d42_nohash_4_data[] = { - 230, 205, 191, 203, 202, 181, 180, 194, 205, 187, 183, 197, 203, 198, 196, - 186, 202, 159, 151, 126, 110, 138, 141, 142, 137, 148, 133, 120, 110, 126, - 117, 110, 117, 116, 137, 134, 95, 116, 123, 110, 184, 144, 183, 189, 197, - 172, 188, 164, 194, 179, 175, 174, 182, 173, 184, 174, 200, 145, 154, 148, - 147, 135, 143, 122, 127, 138, 116, 99, 122, 105, 110, 125, 127, 133, 131, - 123, 116, 119, 127, 114, 193, 176, 185, 170, 175, 146, 166, 167, 185, 185, - 185, 183, 195, 185, 176, 178, 197, 155, 137, 144, 164, 132, 153, 132, 138, - 137, 134, 95, 120, 116, 131, 122, 99, 120, 120, 110, 116, 110, 126, 127, - 128, 159, 187, 119, 178, 187, 197, 167, 199, 184, 180, 165, 194, 176, 144, - 134, 187, 136, 142, 134, 145, 132, 145, 105, 119, 123, 125, 116, 125, 102, - 129, 138, 130, 99, 99, 90, 120, 123, 134, 95, 194, 172, 187, 123, 191, - 179, 195, 182, 201, 137, 167, 142, 185, 161, 187, 146, 167, 152, 154, 107, - 152, 112, 134, 144, 117, 116, 105, 85, 105, 105, 99, 90, 123, 112, 112, - 68, 107, 105, 117, 99, 116, 143, 139, 90, 154, 142, 188, 172, 178, 135, - 175, 149, 177, 110, 173, 160, 169, 162, 173, 119, 132, 110, 85, 85, 117, - 129, 117, 112, 117, 51, 112, 95, 139, 102, 105, 90, 128, 119, 112, 99, - 170, 168, 195, 152, 174, 173, 180, 0, 157, 130, 169, 149, 149, 123, 170, - 130, 170, 133, 159, 102, 134, 90, 85, 105, 126, 119, 130, 90, 78, 68, - 127, 120, 95, 51, 122, 110, 112, 78, 116, 95, 180, 135, 179, 146, 179, - 162, 197, 153, 172, 135, 154, 0, 149, 95, 145, 114, 166, 0, 114, 110, - 145, 107, 114, 90, 136, 68, 95, 95, 95, 85, 116, 99, 116, 0, 95, - 68, 102, 51, 102, 78, 185, 157, 138, 158, 180, 117, 173, 142, 145, 117, - 169, 130, 159, 99, 138, 123, 169, 90, 78, 0, 123, 85, 107, 51, 114, - 102, 95, 0, 116, 85, 119, 95, 95, 68, 85, 51, 116, 68, 102, 78, - 167, 105, 164, 163, 178, 126, 164, 154, 154, 51, 177, 120, 156, 85, 134, - 139, 168, 90, 161, 102, 114, 116, 122, 95, 112, 102, 107, 51, 114, 85, - 119, 78, 114, 90, 102, 51, 102, 51, 114, 99, 177, 68, 152, 102, 184, - 166, 179, 129, 177, 129, 180, 110, 158, 105, 139, 0, 145, 85, 148, 102, - 117, 102, 116, 0, 78, 68, 90, 51, 107, 85, 78, 0, 51, 0, 51, - 0, 95, 51, 107, 68, 180, 117, 90, 0, 138, 0, 187, 146, 119, 140, - 164, 90, 136, 0, 131, 51, 159, 99, 141, 138, 116, 51, 90, 51, 90, - 68, 105, 0, 85, 78, 112, 51, 122, 95, 128, 68, 85, 0, 112, 68, - 147, 126, 178, 146, 171, 130, 190, 147, 188, 123, 170, 78, 132, 0, 130, - 125, 159, 95, 102, 0, 110, 0, 95, 85, 120, 68, 78, 51, 99, 51, - 105, 0, 112, 102, 105, 68, 90, 51, 90, 0, 127, 95, 166, 175, 187, - 133, 135, 0, 171, 139, 132, 128, 140, 51, 126, 107, 161, 0, 95, 51, - 119, 0, 114, 0, 95, 110, 116, 51, 112, 0, 90, 0, 116, 51, 68, - 0, 105, 68, 105, 0, 164, 78, 173, 0, 194, 166, 145, 114, 116, 51, - 107, 122, 151, 0, 156, 102, 148, 51, 122, 95, 129, 0, 85, 0, 127, - 78, 90, 0, 78, 0, 95, 0, 110, 0, 68, 119, 120, 68, 68, 0, - 122, 99, 147, 127, 200, 167, 85, 114, 161, 85, 161, 125, 143, 99, 156, - 85, 147, 68, 99, 0, 107, 102, 132, 51, 112, 68, 95, 78, 99, 0, - 68, 0, 51, 0, 90, 78, 128, 51, 95, 0, 166, 136, 174, 138, 189, - 144, 130, 129, 138, 134, 132, 120, 134, 0, 51, 78, 147, 51, 51, 0, - 51, 0, 78, 0, 68, 68, 95, 78, 90, 0, 0, 0, 68, 0, 90, - 68, 110, 0, 95, 51, 165, 151, 157, 0, 0, 0, 112, 0, 112, 95, - 149, 107, 119, 68, 126, 68, 138, 0, 78, 0, 78, 0, 99, 51, 112, - 0, 102, 0, 78, 51, 85, 0, 0, 0, 78, 0, 95, 0, 95, 78, - 105, 0, 152, 0, 0, 51, 132, 105, 159, 0, 129, 102, 114, 0, 138, - 51, 123, 0, 129, 78, 119, 51, 51, 51, 105, 0, 78, 85, 95, 0, - 85, 0, 0, 0, 85, 0, 78, 0, 0, 0, 172, 142, 141, 0, 137, - 0, 148, 128, 157, 120, 146, 120, 120, 0, 95, 78, 141, 68, 68, 0, - 68, 0, 90, 0, 85, 0, 107, 0, 78, 0, 85, 51, 102, 0, 68, - 78, 68, 0, 51, 0, 125, 0, 141, 51, 102, 138, 175, 51, 120, 51, - 173, 85, 116, 141, 164, 68, 150, 123, 133, 51, 114, 0, 117, 68, 150, - 51, 116, 68, 78, 0, 68, 0, 68, 0, 85, 0, 78, 0, 51, 78, - 155, 90, 161, 0, 132, 99, 123, 78, 107, 0, 134, 90, 95, 0, 78, - 0, 162, 143, 85, 0, 107, 78, 125, 90, 90, 51, 51, 0, 85, 0, - 0, 0, 132, 102, 102, 154, 128, 0, 99, 68, 162, 102, 151, 0, 99, - 51, 147, 141, 156, 0, 112, 120, 158, 127, 145, 139, 187, 171, 135, 138, - 146, 0, 95, 68, 127, 0, 85, 0, 105, 0, 0, 0, 187, 170, 162, - 188, 165, 51, 51, 78, 243, 215, 225, 196, 205, 181, 205, 168, 176, 134, - 157, 110, 126, 114, 133, 139, 193, 163, 159, 116, 160, 126, 122, 127, 171, - 99, 114, 68, 123, 85, 90, 0, 157, 146, 166, 179, 136, 0, 116, 90, - 242, 219, 240, 204, 216, 164, 188, 171, 176, 164, 154, 158, 190, 157, 190, - 141, 182, 177, 169, 128, 172, 145, 105, 129, 157, 90, 78, 51, 119, 68, - 137, 68, 116, 78, 141, 132, 151, 122, 156, 140, 234, 206, 229, 201, 216, - 174, 191, 144, 162, 85, 122, 157, 194, 167, 204, 149, 180, 166, 166, 139, - 122, 133, 156, 126, 145, 85, 128, 0, 99, 51, 145, 0, 126, 51, 166, - 162, 166, 162, 177, 157, 228, 198, 221, 197, 214, 177, 173, 166, 173, 139, - 185, 191, 202, 163, 205, 172, 206, 189, 135, 68, 166, 134, 149, 134, 135, - 90, 127, 107, 175, 90, 136, 117, 135, 140, 172, 167, 166, 149, 177, 152, - 221, 191, 215, 194, 211, 0, 156, 147, 182, 178, 208, 163, 190, 157, 208, - 200, 195, 164, 179, 154, 181, 150, 143, 99, 132, 137, 185, 143, 163, 85, - 51, 107, 132, 134, 164, 127, 167, 159, 175, 141, 216, 195, 223, 211, 238, - 223, 243, 215, 226, 204, 232, 211, 232, 213, 240, 218, 235, 214, 238, 205, - 207, 173, 149, 201, 215, 200, 230, 213, 208, 195, 175, 151, 195, 175, 182, - 163, 235, 217, 218, 190, 211, 191, 215, 191, 217, 220, 241, 215, 229, 206, - 236, 210, 227, 216, 236, 188, 183, 149, 202, 189, 208, 172, 191, 201, 220, - 193, 221, 207, 216, 208, 201, 131, 170, 187, 229, 197, 211, 194, 226, 201, - 205, 184, 206, 177, 221, 210, 226, 184, 204, 197, 218, 198, 212, 209, 213, - 141, 172, 110, 175, 167, 180, 156, 213, 188, 192, 179, 213, 205, 204, 174, - 200, 147, 162, 181, 203, 167, 198, 187, 210, 164, 196, 169, 189, 168, 224, - 198, 213, 204, 198, 195, 230, 211, 221, 197, 208, 0, 0, 0, 85, 90, - 167, 130, 175, 173, 203, 164, 193, 144, 170, 145, 185, 148, 154, 139, 198, - 159, 180, 171, 216, 174, 178, 161, 166, 136, 216, 184, 215, 197, 199, 190, - 228, 195, 208, 51, 117, 0, 0, 0, 0, 0, 140, 51, 135, 154, 188, - 155, 168, 0, 90, 0, 156, 85, 110, 0, 174, 90, 172, 154, 179, 99, - 142, 166, 179, 157, 177, 95, 192, 142, 204, 198, 217, 147, 173, 0, 112, - 0, 0, 0, 0, 0, 0, 0, 110, 0, 107, 0, 160, 0, 148, 95, - 172, 0, 0, 0, 116, 0, 122, 114, 170, 0, 0, 0, 0, 0, 179, - 110, 196, 85, 205, 183, 169, 0, 99, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 141, 0, 112, 0, 0, 0, 134, 0, 0, 0, 0, - 0, 0, 0, 139, 0, 0, 0, 0, 112, 186, 78, 163, 0, 169, 128, - 174, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 95, - 0, 105, 0, 0, 0, 105, 0, 0, 0, 0, 0, 0, 0, 95, 0, - 0, 0, 0, 0, 0, 0, 119, 0, 164, 78, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 90, 0, 0, 68, - 117, 0, 0, 0, 0, 0, 0, 0, 148, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 116, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, - 0, 0, 0, 99, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 99, 0, 0, 0, 0, 0, 0, 0, 0, 0, 78, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +const signed char g_no_micro_f9643d42_nohash_4_data[] = { + 103, 78, 64, 76, 75, 54, 53, 67, 77, 60, 56, 70, + 76, 71, 68, 58, 74, 32, 23, -2, -18, 11, 13, 15, + 9, 20, 5, -7, -18, -2, -10, -18, -10, -12, 9, 7, + -33, -12, -4, -18, 57, 17, 55, 62, 70, 45, 61, 37, + 67, 52, 48, 47, 55, 46, 57, 47, 73, 17, 27, 20, + 19, 8, 15, -6, -1, 10, -12, -29, -6, -23, -18, -3, + -1, 5, 3, -4, -12, -8, -1, -14, 65, 48, 58, 43, + 48, 19, 39, 39, 57, 57, 58, 55, 67, 58, 49, 50, + 70, 27, 9, 16, 37, 4, 25, 4, 11, 9, 7, -33, + -7, -12, 3, -6, -29, -7, -7, -18, -12, -18, -2, -1, + 0, 31, 60, -8, 51, 59, 70, 40, 71, 57, 52, 38, + 66, 48, 17, 6, 59, 8, 15, 7, 18, 4, 18, -23, + -8, -4, -3, -12, -3, -26, 1, 10, 2, -29, -29, -37, + -7, -4, 6, -33, 67, 44, 59, -4, 64, 51, 68, 55, + 74, 9, 40, 15, 57, 33, 60, 18, 40, 25, 27, -20, + 25, -16, 6, 17, -10, -12, -23, -43, -23, -23, -29, -37, + -4, -16, -16, -60, -20, -23, -10, -29, -12, 15, 12, -37, + 27, 15, 61, 44, 50, 8, 48, 22, 49, -18, 46, 33, + 42, 34, 46, -8, 4, -18, -43, -43, -10, 1, -10, -16, + -10, -77, -16, -33, 11, -26, -23, -37, 0, -8, -16, -29, + 42, 40, 68, 24, 47, 46, 53, -128, 30, 2, 42, 21, + 21, -4, 43, 2, 43, 5, 32, -26, 7, -37, -43, -23, + -2, -8, 2, -37, -50, -60, -1, -7, -33, -77, -6, -18, + -16, -50, -12, -33, 53, 8, 52, 18, 51, 35, 69, 26, + 44, 8, 27, -128, 21, -33, 17, -14, 38, -128, -14, -18, + 17, -20, -14, -37, 8, -60, -33, -33, -33, -43, -12, -29, + -12, -128, -33, -60, -26, -77, -26, -50, 57, 29, 11, 30, + 53, -10, 45, 15, 18, -10, 42, 2, 31, -29, 10, -4, + 42, -37, -50, -128, -4, -43, -20, -77, -14, -26, -33, -128, + -12, -43, -8, -33, -33, -60, -43, -77, -12, -60, -26, -50, + 40, -23, 36, 35, 50, -2, 37, 27, 26, -77, 49, -7, + 28, -43, 6, 11, 41, -37, 33, -26, -14, -12, -6, -33, + -16, -26, -20, -77, -14, -43, -8, -50, -14, -37, -26, -77, + -26, -77, -14, -29, 50, -60, 25, -26, 57, 38, 51, 1, + 50, 1, 53, -18, 30, -23, 11, -128, 18, -43, 20, -26, + -10, -26, -12, -128, -50, -60, -37, -77, -20, -43, -50, -128, + -77, -128, -77, -128, -33, -77, -20, -60, 53, -10, -37, -128, + 10, -128, 60, 18, -8, 13, 37, -37, 8, -128, 3, -77, + 32, -29, 14, 10, -12, -77, -37, -77, -37, -60, -23, -128, + -43, -50, -16, -77, -6, -33, 0, -60, -43, -128, -16, -60, + 20, -2, 51, 19, 43, 2, 63, 20, 60, -4, 42, -50, + 4, -128, 2, -3, 32, -33, -26, -128, -18, -128, -33, -43, + -7, -60, -50, -77, -29, -77, -23, -128, -16, -26, -23, -60, + -37, -77, -37, -128, -1, -33, 39, 48, 60, 5, 8, -128, + 44, 11, 4, 0, 13, -77, -2, -20, 33, -128, -33, -77, + -8, -128, -14, -128, -33, -18, -12, -77, -16, -128, -37, -128, + -12, -77, -60, -128, -23, -60, -23, -128, 36, -50, 46, -128, + 66, 39, 18, -14, -12, -77, -20, -6, 24, -128, 28, -26, + 21, -77, -6, -33, 1, -128, -43, -128, -1, -50, -37, -128, + -50, -128, -33, -128, -18, -128, -60, -8, -7, -60, -60, -128, + -6, -29, 20, -1, 73, 40, -43, -14, 33, -43, 33, -3, + 15, -29, 29, -43, 20, -60, -29, -128, -20, -26, 4, -77, + -16, -60, -33, -50, -29, -128, -60, -128, -77, -128, -37, -50, + 0, -77, -33, -128, 39, 8, 47, 10, 62, 16, 2, 1, + 10, 7, 4, -7, 6, -128, -77, -50, 19, -77, -77, -128, + -77, -128, -50, -128, -60, -60, -33, -50, -37, -128, -128, -128, + -60, -128, -37, -60, -18, -128, -33, -77, 37, 23, 29, -128, + -128, -128, -16, -128, -16, -33, 21, -20, -8, -60, -2, -60, + 11, -128, -50, -128, -50, -128, -29, -77, -16, -128, -26, -128, + -50, -77, -43, -128, -128, -128, -50, -128, -33, -128, -33, -50, + -23, -128, 24, -128, -128, -77, 4, -23, 32, -128, 1, -26, + -14, -128, 10, -77, -4, -128, 1, -50, -8, -77, -77, -77, + -23, -128, -50, -43, -33, -128, -43, -128, -128, -128, -43, -128, + -50, -128, -128, -128, 44, 15, 14, -128, 9, -128, 21, 0, + 29, -7, 18, -7, -7, -128, -33, -50, 14, -60, -60, -128, + -60, -128, -37, -128, -43, -128, -20, -128, -50, -128, -43, -77, + -26, -128, -60, -50, -60, -128, -77, -128, -3, -128, 14, -77, + -26, 11, 47, -77, -7, -77, 45, -43, -12, 14, 37, -60, + 22, -4, 5, -77, -14, -128, -10, -60, 22, -77, -12, -60, + -50, -128, -60, -128, -60, -128, -43, -128, -50, -128, -77, -50, + 27, -37, 33, -128, 4, -29, -4, -50, -20, -128, 6, -37, + -33, -128, -50, -128, 34, 15, -43, -128, -20, -50, -3, -37, + -37, -77, -77, -128, -43, -128, -128, -128, 4, -26, -26, 27, + 0, -128, -29, -60, 35, -26, 23, -128, -29, -77, 19, 14, + 28, -128, -16, -7, 31, -1, 17, 11, 60, 44, 8, 11, + 18, -128, -33, -60, -1, -128, -43, -128, -23, -128, -128, -128, + 59, 43, 35, 61, 37, -77, -77, -50, 116, 88, 98, 69, + 78, 53, 78, 40, 48, 7, 29, -18, -2, -14, 5, 12, + 65, 35, 31, -12, 33, -2, -6, -1, 44, -29, -14, -60, + -4, -43, -37, -128, 29, 18, 38, 51, 8, -128, -12, -37, + 115, 91, 113, 77, 89, 36, 60, 44, 49, 36, 27, 31, + 63, 30, 62, 14, 55, 49, 42, 0, 45, 17, -23, 1, + 30, -37, -50, -77, -8, -60, 9, -60, -12, -50, 13, 4, + 23, -6, 28, 13, 107, 78, 101, 73, 89, 46, 63, 17, + 34, -43, -6, 30, 67, 40, 77, 21, 53, 39, 38, 12, + -6, 5, 28, -2, 18, -43, 0, -128, -29, -77, 18, -128, + -2, -77, 39, 35, 38, 35, 50, 29, 100, 70, 94, 69, + 86, 50, 45, 38, 45, 12, 58, 64, 74, 36, 77, 45, + 78, 62, 8, -60, 38, 6, 21, 7, 8, -37, -1, -20, + 48, -37, 8, -10, 8, 13, 45, 39, 38, 22, 49, 25, + 94, 63, 87, 66, 84, -128, 29, 20, 55, 51, 80, 36, + 62, 30, 81, 72, 68, 37, 51, 27, 54, 22, 16, -29, + 4, 9, 57, 15, 35, -43, -77, -20, 4, 6, 37, -1, + 40, 31, 47, 14, 89, 68, 96, 83, 111, 96, 115, 87, + 99, 76, 105, 84, 105, 86, 113, 91, 108, 87, 110, 78, + 80, 46, 22, 74, 88, 72, 103, 86, 80, 68, 48, 24, + 68, 48, 55, 36, 108, 90, 90, 63, 83, 63, 87, 64, + 90, 92, 113, 88, 102, 79, 109, 83, 100, 89, 109, 60, + 56, 21, 75, 62, 81, 45, 63, 73, 93, 65, 94, 80, + 89, 81, 73, 3, 43, 60, 102, 70, 84, 67, 99, 74, + 78, 57, 79, 50, 93, 82, 98, 56, 77, 70, 91, 71, + 85, 82, 86, 13, 45, -18, 48, 40, 53, 28, 85, 60, + 65, 52, 86, 78, 76, 46, 73, 19, 35, 54, 75, 40, + 71, 60, 82, 37, 69, 42, 62, 40, 96, 70, 85, 77, + 70, 68, 103, 84, 94, 69, 81, -128, -128, -128, -43, -37, + 40, 2, 48, 45, 76, 37, 65, 16, 43, 18, 58, 20, + 27, 12, 71, 31, 53, 44, 88, 47, 50, 33, 39, 8, + 89, 57, 88, 69, 72, 63, 100, 68, 81, -77, -10, -128, + -128, -128, -128, -128, 13, -77, 8, 27, 60, 28, 41, -128, + -37, -128, 28, -43, -18, -128, 47, -37, 45, 27, 51, -29, + 15, 39, 52, 30, 49, -33, 65, 15, 76, 71, 90, 19, + 46, -128, -16, -128, -128, -128, -128, -128, -128, -128, -18, -128, + -20, -128, 32, -128, 21, -33, 45, -128, -128, -128, -12, -128, + -6, -14, 43, -128, -128, -128, -128, -128, 52, -18, 69, -43, + 78, 55, 42, -128, -29, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, 14, -128, -16, -128, -128, -128, 7, -128, + -128, -128, -128, -128, -128, -128, 12, -128, -128, -128, -128, -16, + 59, -50, 35, -128, 42, 0, 47, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -33, -128, -23, -128, + -128, -128, -23, -128, -128, -128, -128, -128, -128, -128, -33, -128, + -128, -128, -128, -128, -128, -128, -8, -128, 36, -50, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -37, -128, -128, -60, -10, -128, -128, -128, -128, -128, + -128, -128, 21, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -12, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -77, -128, -128, -128, -29, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -29, -128, -128, -128, -128, -128, -128, -128, -128, -128, -50, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, }; diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_features/no_micro_features_data.h b/tensorflow/lite/micro/examples/micro_speech/micro_features/no_micro_features_data.h index dc4d45b237e..8c1b6d5b57b 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_features/no_micro_features_data.h +++ b/tensorflow/lite/micro/examples/micro_speech/micro_features/no_micro_features_data.h @@ -18,6 +18,6 @@ limitations under the License. extern const int g_no_micro_f9643d42_nohash_4_width; extern const int g_no_micro_f9643d42_nohash_4_height; -extern const unsigned char g_no_micro_f9643d42_nohash_4_data[]; +extern const signed char g_no_micro_f9643d42_nohash_4_data[]; #endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_NO_MICRO_FEATURES_DATA_H_ diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_feature_data_slice.cc b/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_feature_data_slice.cc index 7597b043d9b..7f077b5ffef 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_feature_data_slice.cc +++ b/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_feature_data_slice.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/lite/micro/examples/micro_speech/micro_features/yes_feature_data_slice.h" -const uint8_t g_yes_feature_data_slice[g_yes_feature_data_slice_size] = { - 214, 215, 236, 202, 235, 203, 225, 191, 203, 188, 199, 194, 212, 127, - 51, 0, 174, 188, 219, 196, 228, 221, 240, 207, 235, 220, 241, 219, - 237, 207, 212, 142, 95, 0, 139, 78, 162, 177, 197, 183, +const int8_t g_yes_feature_data_slice[g_yes_feature_data_slice_size] = { + 86, 88, 108, 75, 108, 76, 98, 64, 75, 61, 71, 66, 85, -1, + -77, -128, 46, 61, 92, 69, 100, 93, 113, 80, 108, 93, 113, 91, + 110, 80, 85, 15, -33, -128, 12, -50, 34, 50, 70, 55, }; diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_feature_data_slice.h b/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_feature_data_slice.h index 1515449b2c2..2427ee70063 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_feature_data_slice.h +++ b/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_feature_data_slice.h @@ -24,6 +24,6 @@ limitations under the License. #include <cstdint> constexpr int g_yes_feature_data_slice_size = 40; -extern const uint8_t g_yes_feature_data_slice[]; +extern const int8_t g_yes_feature_data_slice[]; #endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_YES_FEATURE_DATA_SLICE_H_ diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_micro_features_data.cc b/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_micro_features_data.cc index 9c1fb8be0bb..6d9137af2da 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_micro_features_data.cc +++ b/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_micro_features_data.cc @@ -15,151 +15,174 @@ limitations under the License. #include "tensorflow/lite/micro/examples/micro_speech/micro_features/yes_micro_features_data.h" -/* File automatically created by - * tensorflow/examples/speech_commands/wav_to_features.py \ - * --sample_rate=16000 \ - * --clip_duration_ms=1000 \ - * --window_size_ms=30 \ - * --window_stride_ms=20 \ - * --feature_bin_count=40 \ - * --quantize=1 \ - * --preprocess="micro" \ - * --input_wav="speech_commands_test_set_v0.02/yes/f2e59fea_nohash_1.wav" \ - * --output_c_file="yes_micro_features_data.cc" \ - */ +// Golden test values for the expected spectrogram from a "yes" sample file +// speech_commands_test_set_v0.02/yes/f2e59fea_nohash_1.wav. const int g_yes_micro_f2e59fea_nohash_1_width = 40; const int g_yes_micro_f2e59fea_nohash_1_height = 49; -const unsigned char g_yes_micro_f2e59fea_nohash_1_data[] = { - 244, 226, 245, 223, 234, 213, 228, 208, 194, 110, 95, 116, 102, 0, 137, - 161, 183, 173, 137, 116, 133, 157, 151, 156, 128, 110, 128, 0, 68, 78, - 78, 90, 68, 68, 78, 102, 95, 78, 95, 78, 210, 188, 209, 183, 204, - 188, 201, 191, 166, 119, 90, 107, 110, 107, 175, 157, 179, 168, 182, 145, - 152, 164, 171, 165, 136, 143, 122, 68, 0, 78, 90, 90, 110, 90, 102, - 99, 90, 68, 78, 68, 223, 186, 179, 123, 182, 110, 196, 171, 159, 110, - 102, 95, 90, 99, 160, 134, 125, 136, 153, 152, 164, 134, 164, 151, 141, - 136, 99, 90, 90, 90, 78, 78, 102, 119, 102, 90, 110, 90, 68, 51, - 177, 175, 211, 172, 183, 0, 95, 68, 129, 102, 68, 85, 114, 105, 110, - 85, 102, 95, 140, 51, 85, 51, 95, 90, 143, 116, 90, 78, 78, 51, - 107, 85, 68, 0, 68, 51, 90, 51, 68, 0, 164, 117, 193, 120, 156, - 0, 138, 51, 90, 0, 51, 0, 51, 85, 0, 0, 51, 0, 0, 0, - 0, 0, 114, 0, 85, 78, 90, 51, 0, 0, 51, 85, 99, 85, 107, - 68, 90, 85, 78, 0, 51, 0, 110, 0, 68, 0, 0, 0, 51, 0, - 51, 0, 0, 0, 68, 90, 107, 0, 68, 0, 0, 0, 68, 0, 51, - 68, 0, 78, 68, 0, 51, 0, 78, 68, 90, 68, 78, 51, 51, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 90, 0, 0, 0, 0, - 0, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, 68, - 0, 0, 78, 0, 78, 0, 78, 0, 51, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 51, 0, 51, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, 0, 51, - 0, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, - 0, 0, 0, 0, 51, 78, 0, 0, 51, 51, 0, 0, 0, 78, 0, - 213, 170, 192, 180, 196, 188, 173, 131, 173, 116, 137, 105, 159, 127, 0, - 0, 0, 0, 127, 164, 165, 161, 170, 164, 185, 197, 195, 167, 134, 138, - 159, 134, 136, 105, 51, 0, 99, 0, 51, 0, 228, 215, 229, 218, 237, - 215, 228, 210, 237, 222, 239, 211, 208, 211, 234, 218, 220, 209, 225, 219, - 235, 222, 245, 225, 245, 224, 243, 223, 241, 218, 237, 224, 234, 213, 221, - 193, 197, 164, 157, 128, 227, 188, 232, 196, 220, 220, 240, 219, 234, 213, - 234, 211, 231, 218, 233, 213, 239, 215, 228, 207, 229, 206, 224, 208, 226, - 207, 232, 210, 225, 208, 230, 199, 227, 206, 210, 205, 218, 174, 178, 141, - 235, 208, 220, 206, 225, 203, 233, 203, 225, 167, 205, 199, 208, 190, 221, - 204, 223, 207, 225, 188, 225, 197, 215, 188, 199, 183, 225, 195, 224, 200, - 216, 178, 208, 188, 215, 202, 214, 183, 176, 140, 198, 150, 211, 194, 203, - 120, 175, 188, 204, 189, 219, 192, 223, 202, 216, 186, 203, 185, 210, 182, - 214, 183, 204, 170, 204, 125, 184, 187, 206, 185, 198, 182, 210, 161, 202, - 198, 218, 173, 145, 120, 188, 183, 205, 168, 200, 170, 210, 177, 187, 190, - 209, 193, 193, 166, 210, 162, 175, 119, 174, 147, 182, 161, 181, 134, 176, - 143, 187, 165, 186, 149, 185, 141, 192, 181, 202, 123, 170, 143, 144, 78, - 149, 0, 208, 182, 170, 78, 170, 0, 117, 51, 156, 99, 195, 170, 200, - 130, 152, 68, 175, 141, 173, 134, 194, 132, 189, 164, 198, 134, 173, 117, - 171, 149, 183, 181, 185, 99, 153, 117, 125, 0, 166, 0, 173, 117, 144, - 0, 117, 102, 188, 120, 193, 166, 197, 68, 163, 119, 169, 99, 134, 0, - 162, 0, 164, 68, 171, 116, 126, 0, 120, 68, 68, 0, 105, 0, 159, - 95, 150, 51, 90, 85, 0, 0, 131, 0, 105, 0, 145, 51, 170, 51, - 120, 0, 107, 0, 145, 85, 160, 0, 85, 0, 0, 51, 149, 0, 78, - 0, 0, 0, 0, 0, 0, 0, 90, 0, 112, 0, 78, 102, 122, 0, - 0, 0, 0, 0, 105, 0, 0, 0, 0, 0, 0, 0, 0, 0, 112, - 0, 164, 120, 143, 0, 0, 0, 0, 0, 51, 0, 90, 0, 78, 0, - 0, 0, 0, 0, 110, 0, 139, 0, 112, 51, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 102, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 107, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 78, 0, 51, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 127, 110, 133, 0, 167, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 132, 0, 190, - 194, 202, 0, 197, 187, 161, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 214, 213, 223, 203, 218, 189, 200, 122, 78, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 191, 210, 231, 197, 226, 217, 238, 216, 236, 207, - 199, 0, 0, 0, 0, 0, 107, 122, 155, 160, 214, 215, 236, 202, 235, - 203, 225, 191, 203, 188, 199, 194, 212, 127, 51, 0, 174, 188, 219, 196, - 228, 221, 240, 207, 235, 220, 241, 219, 237, 207, 212, 142, 95, 0, 139, - 78, 162, 177, 197, 183, 211, 199, 235, 208, 238, 215, 227, 207, 211, 201, - 224, 213, 226, 192, 213, 170, 223, 205, 234, 221, 245, 225, 242, 220, 245, - 221, 239, 221, 238, 213, 226, 180, 159, 112, 176, 159, 208, 202, 213, 191, - 205, 191, 225, 197, 238, 219, 224, 201, 227, 200, 221, 201, 225, 203, 212, - 195, 229, 210, 228, 210, 239, 216, 226, 212, 233, 205, 225, 200, 229, 207, - 222, 151, 147, 119, 179, 185, 230, 218, 223, 192, 202, 136, 205, 177, 223, - 204, 228, 215, 232, 209, 221, 189, 221, 205, 209, 200, 226, 209, 229, 205, - 235, 192, 209, 198, 228, 190, 206, 185, 207, 187, 214, 175, 177, 184, 220, - 195, 214, 207, 230, 184, 205, 159, 208, 184, 189, 169, 224, 213, 219, 199, - 229, 203, 216, 205, 222, 204, 224, 206, 231, 208, 231, 176, 197, 184, 216, - 193, 211, 139, 212, 195, 231, 164, 166, 195, 217, 182, 208, 190, 217, 179, - 205, 68, 182, 119, 195, 168, 182, 136, 204, 179, 193, 158, 182, 140, 188, - 154, 197, 169, 190, 99, 184, 0, 125, 0, 131, 0, 99, 68, 179, 85, - 190, 184, 213, 203, 223, 202, 212, 190, 209, 138, 178, 0, 159, 51, 128, - 51, 105, 0, 139, 51, 179, 125, 185, 114, 171, 128, 175, 132, 181, 174, - 155, 0, 0, 0, 90, 0, 125, 0, 176, 188, 227, 217, 244, 215, 234, - 221, 239, 192, 224, 210, 0, 0, 134, 0, 51, 0, 105, 0, 105, 0, - 143, 90, 192, 119, 175, 147, 141, 51, 184, 110, 85, 0, 0, 0, 0, - 0, 0, 0, 151, 139, 201, 203, 232, 203, 226, 208, 236, 206, 230, 212, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 169, 0, 119, - 0, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 68, 0, 0, 133, - 200, 180, 220, 197, 228, 201, 221, 184, 213, 193, 110, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 78, 0, 164, 0, 0, 0, 0, 0, 107, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 150, 164, 202, 182, 224, - 197, 211, 179, 212, 193, 134, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 85, 0, 150, 0, 85, 0, 95, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 102, 90, 193, 160, 203, 164, 200, 178, 205, 174, - 116, 0, 0, 0, 0, 0, 0, 0, 0, 0, 120, 114, 123, 0, 114, - 0, 145, 68, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 102, 68, 199, 170, 195, 180, 208, 176, 200, 164, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 110, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 142, 102, 172, 110, 186, - 167, 185, 147, 189, 154, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 177, 0, 158, 136, 197, 155, 189, 166, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 85, 0, 155, 90, 175, 117, 175, 138, 202, 165, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, 0, 139, - 0, 120, 68, 51, 123, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 119, 0, 78, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +const signed char g_yes_micro_f2e59fea_nohash_1_data[] = { + 116, 98, 118, 95, 106, 85, 101, 81, 67, -18, -33, -12, + -26, -128, 9, 34, 56, 45, 9, -12, 5, 30, 23, 28, + 0, -18, 0, -128, -60, -50, -50, -37, -60, -60, -50, -26, + -33, -50, -33, -50, 83, 61, 81, 55, 76, 61, 73, 64, + 38, -8, -37, -20, -18, -20, 48, 29, 52, 41, 55, 18, + 25, 37, 44, 37, 8, 15, -6, -60, -128, -50, -37, -37, + -18, -37, -26, -29, -37, -60, -50, -60, 95, 59, 52, -4, + 54, -18, 68, 43, 31, -18, -26, -33, -37, -29, 33, 7, + -3, 8, 26, 24, 36, 6, 36, 23, 14, 8, -29, -37, + -37, -37, -50, -50, -26, -8, -26, -37, -18, -37, -60, -77, + 50, 48, 83, 44, 56, -128, -33, -60, 1, -26, -60, -43, + -14, -23, -18, -43, -26, -33, 13, -77, -43, -77, -33, -37, + 16, -12, -37, -50, -50, -77, -20, -43, -60, -128, -60, -77, + -37, -77, -60, -128, 37, -10, 65, -7, 28, -128, 10, -77, + -37, -128, -77, -128, -77, -43, -128, -128, -77, -128, -128, -128, + -128, -128, -14, -128, -43, -50, -37, -77, -128, -128, -77, -43, + -29, -43, -20, -60, -37, -43, -50, -128, -77, -128, -18, -128, + -60, -128, -128, -128, -77, -128, -77, -128, -128, -128, -60, -37, + -20, -128, -60, -128, -128, -128, -60, -128, -77, -60, -128, -50, + -60, -128, -77, -128, -50, -60, -37, -60, -50, -77, -77, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -37, -128, + -128, -128, -128, -128, -77, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -77, -60, -128, -128, -50, -128, -50, -128, + -50, -128, -77, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -77, -128, -77, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -77, -128, -77, -128, -77, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -77, -128, -128, -128, + -128, -77, -50, -128, -128, -77, -77, -128, -128, -128, -50, -128, + 85, 43, 65, 53, 69, 60, 45, 3, 46, -12, 9, -23, + 32, -1, -128, -128, -128, -128, -1, 37, 38, 33, 43, 36, + 58, 70, 68, 39, 6, 10, 32, 6, 8, -23, -77, -128, + -29, -128, -77, -128, 101, 87, 102, 91, 110, 88, 101, 83, + 110, 95, 111, 83, 81, 84, 106, 90, 93, 82, 98, 91, + 108, 95, 118, 97, 118, 97, 116, 96, 113, 90, 110, 96, + 107, 85, 94, 66, 69, 36, 29, 0, 100, 60, 105, 68, + 92, 93, 113, 92, 107, 85, 107, 83, 104, 91, 105, 85, + 112, 88, 101, 80, 101, 79, 96, 80, 98, 80, 105, 83, + 98, 81, 103, 71, 100, 79, 83, 78, 91, 47, 50, 13, + 108, 81, 93, 78, 98, 76, 105, 76, 98, 40, 77, 72, + 81, 62, 93, 77, 96, 80, 98, 61, 97, 69, 88, 61, + 71, 56, 98, 68, 97, 72, 89, 51, 81, 61, 88, 75, + 86, 56, 48, 13, 71, 22, 84, 66, 76, -7, 48, 61, + 77, 62, 91, 65, 95, 74, 88, 59, 75, 58, 83, 55, + 87, 55, 76, 43, 76, -3, 56, 60, 79, 57, 71, 54, + 82, 33, 74, 71, 91, 45, 18, -7, 61, 56, 77, 41, + 73, 42, 82, 49, 59, 63, 82, 65, 66, 38, 83, 34, + 48, -8, 46, 20, 54, 33, 54, 6, 48, 16, 60, 37, + 58, 22, 58, 14, 65, 53, 75, -4, 42, 16, 16, -50, + 22, -128, 80, 54, 43, -50, 42, -128, -10, -77, 28, -29, + 68, 43, 73, 2, 25, -60, 47, 14, 45, 7, 66, 4, + 62, 37, 71, 7, 46, -10, 44, 22, 55, 53, 57, -29, + 26, -10, -3, -128, 38, -128, 46, -10, 16, -128, -10, -26, + 60, -7, 65, 38, 70, -60, 35, -8, 42, -29, 6, -128, + 34, -128, 36, -60, 44, -12, -2, -128, -7, -60, -60, -128, + -23, -128, 31, -33, 22, -77, -37, -43, -128, -128, 3, -128, + -23, -128, 17, -77, 43, -77, -7, -128, -20, -128, 17, -43, + 32, -128, -43, -128, -128, -77, 21, -128, -50, -128, -128, -128, + -128, -128, -128, -128, -37, -128, -16, -128, -50, -26, -6, -128, + -128, -128, -128, -128, -23, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -16, -128, 36, -7, 16, -128, -128, -128, -128, -128, + -77, -128, -37, -128, -50, -128, -128, -128, -128, -128, -18, -128, + 11, -128, -16, -77, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -26, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -20, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -50, -128, -77, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -77, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -1, -18, 5, -128, + 40, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, 4, -128, 63, 66, 75, -128, + 70, 60, 34, -128, -128, -128, -128, -128, -128, -128, -128, -128, + 87, 86, 95, 76, 91, 62, 72, -6, -50, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, 64, 83, 104, 70, + 98, 90, 111, 89, 109, 80, 71, -128, -128, -128, -128, -128, + -20, -6, 27, 33, 86, 88, 108, 75, 108, 76, 98, 64, + 75, 61, 71, 66, 85, -1, -77, -128, 46, 61, 92, 69, + 100, 93, 113, 80, 108, 93, 113, 91, 110, 80, 85, 15, + -33, -128, 12, -50, 34, 50, 70, 55, 84, 72, 108, 81, + 111, 88, 100, 80, 84, 73, 97, 86, 99, 65, 85, 43, + 96, 78, 107, 94, 118, 98, 115, 92, 118, 94, 111, 93, + 111, 86, 99, 52, 32, -16, 48, 31, 81, 74, 85, 64, + 78, 64, 98, 70, 110, 92, 96, 73, 100, 72, 94, 73, + 98, 76, 85, 67, 101, 83, 101, 83, 112, 89, 98, 85, + 105, 78, 98, 72, 102, 80, 95, 23, 19, -8, 52, 57, + 103, 91, 95, 65, 74, 8, 77, 49, 96, 76, 100, 87, + 105, 81, 94, 62, 94, 78, 81, 72, 99, 82, 101, 78, + 108, 65, 82, 70, 100, 63, 79, 58, 80, 59, 87, 48, + 50, 57, 93, 67, 86, 80, 103, 56, 77, 31, 81, 57, + 62, 41, 96, 85, 91, 71, 101, 76, 89, 78, 95, 76, + 96, 79, 103, 81, 103, 48, 70, 57, 88, 66, 84, 11, + 85, 67, 104, 37, 38, 67, 90, 54, 81, 62, 90, 52, + 78, -60, 54, -8, 68, 40, 55, 8, 77, 52, 66, 31, + 55, 13, 60, 26, 69, 42, 63, -29, 57, -128, -3, -128, + 3, -128, -29, -60, 52, -43, 63, 56, 86, 75, 95, 75, + 85, 63, 82, 10, 50, -128, 31, -77, 0, -77, -23, -128, + 12, -77, 51, -3, 58, -14, 44, 0, 48, 4, 53, 47, + 28, -128, -128, -128, -37, -128, -3, -128, 49, 61, 100, 90, + 117, 88, 107, 94, 112, 64, 96, 83, -128, -128, 7, -128, + -77, -128, -23, -128, -23, -128, 16, -37, 65, -8, 48, 20, + 14, -77, 57, -18, -43, -128, -128, -128, -128, -128, -128, -128, + 24, 12, 74, 76, 105, 76, 99, 80, 108, 79, 103, 85, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + 42, -128, -8, -128, -50, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -60, -128, -128, 5, 73, 53, 93, 70, 101, 73, + 94, 57, 86, 66, -18, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -50, -128, 36, -128, -128, -128, -128, -128, -20, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, 23, 37, + 75, 54, 97, 70, 83, 52, 85, 65, 7, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -43, -128, 23, -128, -43, -128, + -33, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -26, -37, 65, 33, 76, 37, 73, 50, 77, 47, + -12, -128, -128, -128, -128, -128, -128, -128, -128, -128, -7, -14, + -4, -128, -14, -128, 18, -60, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -26, -60, 71, 42, 68, 53, + 81, 49, 73, 36, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -18, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, 15, -26, + 44, -18, 59, 39, 57, 20, 62, 26, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, 49, -128, 30, 8, 69, 27, 62, 38, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -43, -128, 28, -37, 48, -10, + 48, 11, 74, 37, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -77, -128, 11, -128, -7, -60, -77, -4, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -8, -128, -50, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, + -128, -128, -128, -128, }; diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_micro_features_data.h b/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_micro_features_data.h index 07eccc35f4e..cd1ad10888e 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_micro_features_data.h +++ b/tensorflow/lite/micro/examples/micro_speech/micro_features/yes_micro_features_data.h @@ -18,6 +18,6 @@ limitations under the License. extern const int g_yes_micro_f2e59fea_nohash_1_width; extern const int g_yes_micro_f2e59fea_nohash_1_height; -extern const unsigned char g_yes_micro_f2e59fea_nohash_1_data[]; +extern const signed char g_yes_micro_f2e59fea_nohash_1_data[]; #endif // TENSORFLOW_LITE_MICRO_EXAMPLES_MICRO_SPEECH_MICRO_FEATURES_YES_MICRO_FEATURES_DATA_H_ diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_speech_binary_mock_test.sh b/tensorflow/lite/micro/examples/micro_speech/micro_speech_binary_mock_test.sh new file mode 100755 index 00000000000..f18b7fa2dff --- /dev/null +++ b/tensorflow/lite/micro/examples/micro_speech/micro_speech_binary_mock_test.sh @@ -0,0 +1,33 @@ +#!/bin/bash +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Bash unit tests for the example binary. + +set -e + +OUTPUT_LOG_FILE=${TEST_TMPDIR}/output_log.txt + +# Needed for copybara compatibility. +SCRIPT_BASE_DIR=/org_"tensor"flow +${TEST_SRCDIR}${SCRIPT_BASE_DIR}/tensorflow/lite/micro/examples/micro_speech/micro_speech_mock 2>&1 | head > ${OUTPUT_LOG_FILE} + +if ! grep -q 'Heard ' ${OUTPUT_LOG_FILE}; then + echo "ERROR: Expected logs not found in output '${OUTPUT_LOG_FILE}'" + exit 1 +fi + +echo +echo "SUCCESS: micro_speech_binary_mock_test PASSED" diff --git a/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc b/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc index ca090ec9524..a6e011b1224 100644 --- a/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc +++ b/tensorflow/lite/micro/examples/micro_speech/micro_speech_test.cc @@ -48,14 +48,19 @@ TF_LITE_MICRO_TEST(TestInvoke) { // needed by this graph. // // tflite::ops::micro::AllOpsResolver resolver; - tflite::MicroOpResolver<3> micro_op_resolver; - micro_op_resolver.AddBuiltin( - tflite::BuiltinOperator_DEPTHWISE_CONV_2D, - tflite::ops::micro::Register_DEPTHWISE_CONV_2D()); + tflite::MicroOpResolver<4> micro_op_resolver; + micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, + tflite::ops::micro::Register_DEPTHWISE_CONV_2D(), + tflite::MicroOpResolverAnyVersion()); micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED, - tflite::ops::micro::Register_FULLY_CONNECTED()); + tflite::ops::micro::Register_FULLY_CONNECTED(), + tflite::MicroOpResolverAnyVersion()); micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX, - tflite::ops::micro::Register_SOFTMAX()); + tflite::ops::micro::Register_SOFTMAX(), + tflite::MicroOpResolverAnyVersion()); + micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_RESHAPE, + tflite::ops::micro::Register_RESHAPE(), + tflite::MicroOpResolverAnyVersion()); // Create an area of memory to use for input, output, and intermediate arrays. const int tensor_arena_size = 10 * 1024; @@ -71,18 +76,16 @@ TF_LITE_MICRO_TEST(TestInvoke) { // Make sure the input has the properties we expect. TF_LITE_MICRO_EXPECT_NE(nullptr, input); - TF_LITE_MICRO_EXPECT_EQ(4, input->dims->size); + TF_LITE_MICRO_EXPECT_EQ(2, input->dims->size); TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]); - TF_LITE_MICRO_EXPECT_EQ(49, input->dims->data[1]); - TF_LITE_MICRO_EXPECT_EQ(40, input->dims->data[2]); - TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[3]); - TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, input->type); + TF_LITE_MICRO_EXPECT_EQ(1960, input->dims->data[1]); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, input->type); // Copy a spectrogram created from a .wav audio file of someone saying "Yes", // into the memory area used for the input. - const uint8_t* yes_features_data = g_yes_micro_f2e59fea_nohash_1_data; + const int8_t* yes_features_data = g_yes_micro_f2e59fea_nohash_1_data; for (int i = 0; i < input->bytes; ++i) { - input->data.uint8[i] = yes_features_data[i]; + input->data.int8[i] = yes_features_data[i]; } // Run the model on this input and make sure it succeeds. @@ -98,7 +101,7 @@ TF_LITE_MICRO_TEST(TestInvoke) { TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size); TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]); TF_LITE_MICRO_EXPECT_EQ(4, output->dims->data[1]); - TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, output->type); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, output->type); // There are four possible classes in the output, each with a score. const int kSilenceIndex = 0; @@ -107,18 +110,18 @@ TF_LITE_MICRO_TEST(TestInvoke) { const int kNoIndex = 3; // Make sure that the expected "Yes" score is higher than the other classes. - uint8_t silence_score = output->data.uint8[kSilenceIndex]; - uint8_t unknown_score = output->data.uint8[kUnknownIndex]; - uint8_t yes_score = output->data.uint8[kYesIndex]; - uint8_t no_score = output->data.uint8[kNoIndex]; + uint8_t silence_score = output->data.uint8[kSilenceIndex] + 128; + uint8_t unknown_score = output->data.uint8[kUnknownIndex] + 128; + uint8_t yes_score = output->data.int8[kYesIndex] + 128; + uint8_t no_score = output->data.int8[kNoIndex] + 128; TF_LITE_MICRO_EXPECT_GT(yes_score, silence_score); TF_LITE_MICRO_EXPECT_GT(yes_score, unknown_score); TF_LITE_MICRO_EXPECT_GT(yes_score, no_score); // Now test with a different input, from a recording of "No". - const uint8_t* no_features_data = g_no_micro_f9643d42_nohash_4_data; + const int8_t* no_features_data = g_no_micro_f9643d42_nohash_4_data; for (int i = 0; i < input->bytes; ++i) { - input->data.uint8[i] = no_features_data[i]; + input->data.int8[i] = no_features_data[i]; } // Run the model on this "No" input. @@ -134,13 +137,13 @@ TF_LITE_MICRO_TEST(TestInvoke) { TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size); TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]); TF_LITE_MICRO_EXPECT_EQ(4, output->dims->data[1]); - TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, output->type); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, output->type); // Make sure that the expected "No" score is higher than the other classes. - silence_score = output->data.uint8[kSilenceIndex]; - unknown_score = output->data.uint8[kUnknownIndex]; - yes_score = output->data.uint8[kYesIndex]; - no_score = output->data.uint8[kNoIndex]; + silence_score = output->data.int8[kSilenceIndex] + 128; + unknown_score = output->data.int8[kUnknownIndex] + 128; + yes_score = output->data.int8[kYesIndex] + 128; + no_score = output->data.int8[kNoIndex] + 128; TF_LITE_MICRO_EXPECT_GT(no_score, silence_score); TF_LITE_MICRO_EXPECT_GT(no_score, unknown_score); TF_LITE_MICRO_EXPECT_GT(no_score, yes_score); diff --git a/tensorflow/lite/micro/examples/micro_speech/recognize_commands.cc b/tensorflow/lite/micro/examples/micro_speech/recognize_commands.cc index 96f35984051..47bd10074d3 100644 --- a/tensorflow/lite/micro/examples/micro_speech/recognize_commands.cc +++ b/tensorflow/lite/micro/examples/micro_speech/recognize_commands.cc @@ -47,10 +47,10 @@ TfLiteStatus RecognizeCommands::ProcessLatestResults( return kTfLiteError; } - if (latest_results->type != kTfLiteUInt8) { + if (latest_results->type != kTfLiteInt8) { TF_LITE_REPORT_ERROR( error_reporter_, - "The results for recognition should be uint8 elements, but are %d", + "The results for recognition should be int8 elements, but are %d", latest_results->type); return kTfLiteError; } @@ -66,7 +66,7 @@ TfLiteStatus RecognizeCommands::ProcessLatestResults( } // Add the latest results to the head of the queue. - previous_results_.push_back({current_time_ms, latest_results->data.uint8}); + previous_results_.push_back({current_time_ms, latest_results->data.int8}); // Prune any earlier results that are too old for the averaging window. const int64_t time_limit = current_time_ms - average_window_duration_ms_; @@ -93,12 +93,12 @@ TfLiteStatus RecognizeCommands::ProcessLatestResults( for (int offset = 0; offset < previous_results_.size(); ++offset) { PreviousResultsQueue::Result previous_result = previous_results_.from_front(offset); - const uint8_t* scores = previous_result.scores_; + const int8_t* scores = previous_result.scores; for (int i = 0; i < kCategoryCount; ++i) { if (offset == 0) { - average_scores[i] = scores[i]; + average_scores[i] = scores[i] + 128; } else { - average_scores[i] += scores[i]; + average_scores[i] += scores[i] + 128; } } } diff --git a/tensorflow/lite/micro/examples/micro_speech/recognize_commands.h b/tensorflow/lite/micro/examples/micro_speech/recognize_commands.h index 059d567fb20..67bdb31bed9 100644 --- a/tensorflow/lite/micro/examples/micro_speech/recognize_commands.h +++ b/tensorflow/lite/micro/examples/micro_speech/recognize_commands.h @@ -36,14 +36,14 @@ class PreviousResultsQueue { // Data structure that holds an inference result, and the time when it // was recorded. struct Result { - Result() : time_(0), scores_() {} - Result(int32_t time, uint8_t* scores) : time_(time) { + Result() : time_(0), scores() {} + Result(int32_t time, int8_t* input_scores) : time_(time) { for (int i = 0; i < kCategoryCount; ++i) { - scores_[i] = scores[i]; + scores[i] = input_scores[i]; } } int32_t time_; - uint8_t scores_[kCategoryCount]; + int8_t scores[kCategoryCount]; }; int size() { return size_; } diff --git a/tensorflow/lite/micro/examples/micro_speech/recognize_commands_test.cc b/tensorflow/lite/micro/examples/micro_speech/recognize_commands_test.cc index 70911a81776..dcff73cf7ee 100644 --- a/tensorflow/lite/micro/examples/micro_speech/recognize_commands_test.cc +++ b/tensorflow/lite/micro/examples/micro_speech/recognize_commands_test.cc @@ -27,13 +27,13 @@ TF_LITE_MICRO_TEST(PreviousResultsQueueBasic) { PreviousResultsQueue queue(error_reporter); TF_LITE_MICRO_EXPECT_EQ(0, queue.size()); - uint8_t scores_a[4] = {0, 0, 0, 1}; + int8_t scores_a[4] = {0, 0, 0, 1}; queue.push_back({0, scores_a}); TF_LITE_MICRO_EXPECT_EQ(1, queue.size()); TF_LITE_MICRO_EXPECT_EQ(0, queue.front().time_); TF_LITE_MICRO_EXPECT_EQ(0, queue.back().time_); - uint8_t scores_b[4] = {0, 0, 1, 0}; + int8_t scores_b[4] = {0, 0, 1, 0}; queue.push_back({1, scores_b}); TF_LITE_MICRO_EXPECT_EQ(2, queue.size()); TF_LITE_MICRO_EXPECT_EQ(0, queue.front().time_); @@ -45,7 +45,7 @@ TF_LITE_MICRO_TEST(PreviousResultsQueueBasic) { TF_LITE_MICRO_EXPECT_EQ(1, queue.front().time_); TF_LITE_MICRO_EXPECT_EQ(1, queue.back().time_); - uint8_t scores_c[4] = {0, 1, 0, 0}; + int8_t scores_c[4] = {0, 1, 0, 0}; queue.push_back({2, scores_c}); TF_LITE_MICRO_EXPECT_EQ(2, queue.size()); TF_LITE_MICRO_EXPECT_EQ(1, queue.front().time_); @@ -60,7 +60,7 @@ TF_LITE_MICRO_TEST(PreviousResultsQueuePushPop) { TF_LITE_MICRO_EXPECT_EQ(0, queue.size()); for (int i = 0; i < 123; ++i) { - uint8_t scores[4] = {0, 0, 0, 1}; + int8_t scores[4] = {0, 0, 0, 1}; queue.push_back({i, scores}); TF_LITE_MICRO_EXPECT_EQ(1, queue.size()); TF_LITE_MICRO_EXPECT_EQ(i, queue.front().time_); @@ -78,11 +78,11 @@ TF_LITE_MICRO_TEST(RecognizeCommandsTestBasic) { RecognizeCommands recognize_commands(error_reporter); - std::initializer_list<uint8_t> result_data = {255, 0, 0, 0}; + std::initializer_list<int8_t> result_data = {127, -128, -128, -128}; auto result_dims = {2, 1, 4}; TfLiteTensor results = tflite::testing::CreateQuantizedTensor( result_data, tflite::testing::IntArrayFromInitializer(result_dims), - "input_tensor", 0.0f, 128.0f); + "input_tensor", -128.0f, 127.0f); const char* found_command; uint8_t score; @@ -98,11 +98,11 @@ TF_LITE_MICRO_TEST(RecognizeCommandsTestFindCommands) { RecognizeCommands recognize_commands(error_reporter, 1000, 51); - std::initializer_list<uint8_t> yes_data = {0, 0, 255, 0}; + std::initializer_list<int8_t> yes_data = {-128, -128, 127, -128}; auto yes_dims = {2, 1, 4}; TfLiteTensor yes_results = tflite::testing::CreateQuantizedTensor( yes_data, tflite::testing::IntArrayFromInitializer(yes_dims), - "input_tensor", 0.0f, 128.0f); + "input_tensor", -128.0f, 127.0f); bool has_found_new_command = false; const char* new_command; @@ -126,11 +126,11 @@ TF_LITE_MICRO_TEST(RecognizeCommandsTestFindCommands) { TF_LITE_MICRO_EXPECT_EQ(0, tflite::testing::TestStrcmp("yes", new_command)); } - std::initializer_list<uint8_t> no_data = {0, 0, 0, 255}; + std::initializer_list<int8_t> no_data = {-128, -128, -128, 127}; auto no_dims = {2, 1, 4}; TfLiteTensor no_results = tflite::testing::CreateQuantizedTensor( no_data, tflite::testing::IntArrayFromInitializer(no_dims), - "input_tensor", 0.0f, 128.0f); + "input_tensor", -128.0f, 127.0f); has_found_new_command = false; new_command = ""; uint8_t score; @@ -161,11 +161,11 @@ TF_LITE_MICRO_TEST(RecognizeCommandsTestBadInputLength) { RecognizeCommands recognize_commands(error_reporter, 1000, 51); - std::initializer_list<uint8_t> bad_data = {0, 0, 255}; + std::initializer_list<int8_t> bad_data = {-128, -128, 127}; auto bad_dims = {2, 1, 3}; TfLiteTensor bad_results = tflite::testing::CreateQuantizedTensor( bad_data, tflite::testing::IntArrayFromInitializer(bad_dims), - "input_tensor", 0.0f, 128.0f); + "input_tensor", -128.0f, 127.0f); const char* found_command; uint8_t score; @@ -181,11 +181,11 @@ TF_LITE_MICRO_TEST(RecognizeCommandsTestBadInputTimes) { RecognizeCommands recognize_commands(error_reporter, 1000, 51); - std::initializer_list<uint8_t> result_data = {0, 0, 255, 0}; + std::initializer_list<int8_t> result_data = {-128, -128, 127, -128}; auto result_dims = {2, 1, 4}; TfLiteTensor results = tflite::testing::CreateQuantizedTensor( result_data, tflite::testing::IntArrayFromInitializer(result_dims), - "input_tensor", 0.0f, 128.0f); + "input_tensor", -128.0f, 127.0f); const char* found_command; uint8_t score; @@ -204,11 +204,11 @@ TF_LITE_MICRO_TEST(RecognizeCommandsTestTooFewInputs) { RecognizeCommands recognize_commands(error_reporter, 1000, 51); - std::initializer_list<uint8_t> result_data = {0, 0, 255, 0}; + std::initializer_list<int8_t> result_data = {-128, -128, 127, -128}; auto result_dims = {2, 1, 4}; TfLiteTensor results = tflite::testing::CreateQuantizedTensor( result_data, tflite::testing::IntArrayFromInitializer(result_dims), - "input_tensor", 0.0f, 128.0f); + "input_tensor", -128.0f, 127.0f); const char* found_command; uint8_t score; diff --git a/tensorflow/lite/micro/examples/person_detection/BUILD b/tensorflow/lite/micro/examples/person_detection/BUILD index cb9fdb80c33..84eddba73d4 100644 --- a/tensorflow/lite/micro/examples/person_detection/BUILD +++ b/tensorflow/lite/micro/examples/person_detection/BUILD @@ -23,7 +23,7 @@ cc_library( cc_library( name = "person_detect_model_data", srcs = [ - "person_detect_model_data.cc", + "@person_detect_data//:person_detect_model_data", ], hdrs = [ "person_detect_model_data.h", @@ -56,7 +56,7 @@ cc_library( deps = [ ":model_settings", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", ], ) @@ -69,7 +69,7 @@ tflite_micro_cc_test( ":image_provider", ":model_settings", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -84,7 +84,7 @@ cc_library( ], deps = [ "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", ], ) @@ -112,8 +112,15 @@ cc_binary( ":model_settings", ":person_detect_model_data", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/kernels:micro_ops", "//tensorflow/lite/schema:schema_fbs", ], ) + +sh_test( + name = "person_detection_binary_test", + srcs = ["person_detection_binary_test.sh"], + data = [":person_detection"], +) diff --git a/tensorflow/lite/micro/examples/person_detection/README.md b/tensorflow/lite/micro/examples/person_detection/README.md index 5ee7bda9914..423941dcad8 100644 --- a/tensorflow/lite/micro/examples/person_detection/README.md +++ b/tensorflow/lite/micro/examples/person_detection/README.md @@ -5,7 +5,9 @@ network to recognize people in images captured by a camera. It is designed to run on systems with small amounts of memory such as microcontrollers and DSPs. ## Table of contents + - [Getting started](#getting-started) +- [Running on ARC EM SDP](#running-on-arc-em-sdp) - [Running on Arduino](#running-on-arduino) - [Running on ESP32](#running-on-esp32) - [Running on SparkFun Edge](#running-on-sparkfun-edge) @@ -13,6 +15,94 @@ run on systems with small amounts of memory such as microcontrollers and DSPs. - [Debugging image capture](#debugging-image-capture) - [Training your own model](#training-your-own-model) +## Running on ARC EM SDP + +The following instructions will help you to build and deploy this example to +[ARC EM SDP](https://www.synopsys.com/dw/ipdir.php?ds=arc-em-software-development-platform) +board. General information and instructions on using the board with TensorFlow +Lite Micro can be found in the common +[ARC targets description](/tensorflow/lite/micro/tools/make/targets/arc/README.md). + +This example is quantized with symmetric uint8 scheme. As noted in +[kernels/arc_mli/README.md](/tensorflow/lite/micro/kernels/arc_mli/README.md), +embARC MLI supports optimized kernels for int8 quantization only. Therefore, +this example will only use TFLM reference kernels. + +The ARC EM SDP board contains the reach set of extension interfaces. You can +choose any compatible camera and modify +[image_provider.cc](/tensorflow/lite/micro/examples/person_detection/image_provider.cc) +file accordingly to use input from your specific camera. By default, results of +running this example are printed to the console. If you would like to instead +implement some target-specific actions, you need to modify +[detection_responder.cc](/tensorflow/lite/micro/examples/person_detection/detection_responder.cc) +accordingly. + +The reference implementations of these files are used by default on the EM SDP. + +### Initial setup + +Follow the instructions on the +[ARC EM SDP Initial Setup](/tensorflow/lite/micro/tools/make/targets/arc/README.md#ARC-EM-Software-Development-Platform-ARC-EM-SDP) +to get and install all required tools for work with ARC EM SDP. + +### Generate Example Project + +The example project for ARC EM SDP platform can be generated with the following +command: + +``` +make -f tensorflow/lite/micro/tools/make/Makefile TARGET=arc_emsdp TAGS=no_arc_mli generate_person_detection_make_project +``` + +### Build and Run Example + +For more detailed information on building and running examples see the +appropriate sections of general descriptions of the +[ARC EM SDP usage with TFLM](/tensorflow/lite/micro/tools/make/targets/arc/README.md#ARC-EM-Software-Development-Platform-ARC-EM-SDP). +In the directory with generated project you can also find a +*README_ARC_EMSDP.md* file with instructions and options on building and +running. Here we only briefly mention main steps which are typically enough to +get it started. + +1. You need to + [connect the board](/tensorflow/lite/micro/tools/make/targets/arc/README.md#connect-the-board) + and open an serial connection. + +2. Go to the generated example project director + + ``` + cd tensorflow/lite/micro/tools/make/gen/arc_emsdp_arc/prj/person_detection/make + ``` + +3. Build the example using + + ``` + make app + ``` + +4. To generate artefacts for self-boot of example from the board use + + ``` + make flash + ``` + +5. To run application from the board using microSD card: + + * Copy the content of the created /bin folder into the root of microSD + card. Note that the card must be formatted as FAT32 with default cluster + size (but less than 32 Kbytes) + * Plug in the microSD card into the J11 connector. + * Push the RST button. If a red LED is lit beside RST button, push the CFG + button. + +6. If you have the MetaWare Debugger installed in your environment: + + * To run application from the console using it type `make run`. + * To stop the execution type `Ctrl+C` in the console several times. + +In both cases (step 5 and 6) you will see the application output in the serial +terminal. + ## Running on Arduino The following instructions will help you build and deploy this sample diff --git a/tensorflow/lite/micro/examples/person_detection/arc_emsdp/Makefile.inc b/tensorflow/lite/micro/examples/person_detection/arc_emsdp/Makefile.inc new file mode 100644 index 00000000000..29a09466e83 --- /dev/null +++ b/tensorflow/lite/micro/examples/person_detection/arc_emsdp/Makefile.inc @@ -0,0 +1,24 @@ +ifeq ($(TARGET), arc_emsdp) + +# Patch of arc make project to adjust it specifically +# for person detection example. In particular: +# - Use Linker command file with better usage of fast memory +# - In case project was generated with MLI usage, reduce scratch buffers. + + person_detection_HDRS += \ + person_detection_patch.txt + + person_detection_TEST_HDRS += \ + person_detection_patch.txt + + +%/person_detection_patch.txt: %/emsdp.lcf %/Makefile + @cp tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp_v2.lcf $< + @echo emsdp.lcf > $@ + @sed -E -i 's#MLI_ONLY *\?= *false#MLI_ONLY \?= false\n\ + CXXFLAGS += -DSCRATCH_MEM_X_SIZE=0 -DSCRATCH_MEM_Y_SIZE=0 -DSCRATCH_MEM_Z_SIZE=0\ + CCFLAGS += -DSCRATCH_MEM_X_SIZE=0 -DSCRATCH_MEM_Y_SIZE=0 -DSCRATCH_MEM_Z_SIZE=0#'\ + $(word 2, $^) + @echo Makefile >> $@ + +endif diff --git a/tensorflow/lite/micro/examples/person_detection/esp/app_camera_esp.h b/tensorflow/lite/micro/examples/person_detection/esp/app_camera_esp.h index 403fb4defb1..e8cbe2177a9 100644 --- a/tensorflow/lite/micro/examples/person_detection/esp/app_camera_esp.h +++ b/tensorflow/lite/micro/examples/person_detection/esp/app_camera_esp.h @@ -30,7 +30,7 @@ limitations under the License. #define CAMERA_PIXEL_FORMAT PIXFORMAT_GRAYSCALE /* - * FRAMESIZE_96x96, // 96x96 + * FRAMESIZE_96X96, // 96x96 * FRAMESIZE_QQVGA, // 160x120 * FRAMESIZE_QQVGA2, // 128x160 * FRAMESIZE_QCIF, // 176x144 @@ -43,7 +43,7 @@ limitations under the License. * FRAMESIZE_SXGA, // 1280x1024 * FRAMESIZE_UXGA, // 1600x1200 */ -#define CAMERA_FRAME_SIZE FRAMESIZE_96x96 +#define CAMERA_FRAME_SIZE FRAMESIZE_96X96 #if CONFIG_CAMERA_MODEL_WROVER_KIT #define PWDN_GPIO_NUM -1 diff --git a/tensorflow/lite/micro/examples/person_detection/person_detection_binary_test.sh b/tensorflow/lite/micro/examples/person_detection/person_detection_binary_test.sh new file mode 100755 index 00000000000..00d985d19bf --- /dev/null +++ b/tensorflow/lite/micro/examples/person_detection/person_detection_binary_test.sh @@ -0,0 +1,33 @@ +#!/bin/bash +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Bash unit tests for the example binary. + +set -e + +OUTPUT_LOG_FILE=${TEST_TMPDIR}/output_log.txt + +# Needed for copybara compatibility. +SCRIPT_BASE_DIR=/org_"tensor"flow +${TEST_SRCDIR}${SCRIPT_BASE_DIR}/tensorflow/lite/micro/examples/person_detection/person_detection 2>&1 | head > ${OUTPUT_LOG_FILE} + +if ! grep -q 'person score' ${OUTPUT_LOG_FILE}; then + echo "ERROR: Expected logs not found in output '${OUTPUT_LOG_FILE}'" + exit 1 +fi + +echo +echo "SUCCESS: person_detection_binary_test PASSED" diff --git a/tensorflow/lite/micro/examples/person_detection_experimental/BUILD b/tensorflow/lite/micro/examples/person_detection_experimental/BUILD index cb9fdb80c33..49f10c814cb 100644 --- a/tensorflow/lite/micro/examples/person_detection_experimental/BUILD +++ b/tensorflow/lite/micro/examples/person_detection_experimental/BUILD @@ -56,7 +56,7 @@ cc_library( deps = [ ":model_settings", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", ], ) @@ -69,6 +69,7 @@ tflite_micro_cc_test( ":image_provider", ":model_settings", "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], @@ -84,7 +85,7 @@ cc_library( ], deps = [ "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", ], ) @@ -112,6 +113,7 @@ cc_binary( ":model_settings", ":person_detect_model_data", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/kernels:micro_ops", "//tensorflow/lite/schema:schema_fbs", diff --git a/tensorflow/lite/micro/examples/person_detection_experimental/README.md b/tensorflow/lite/micro/examples/person_detection_experimental/README.md index d8aaa9ba383..bf99b40d776 100644 --- a/tensorflow/lite/micro/examples/person_detection_experimental/README.md +++ b/tensorflow/lite/micro/examples/person_detection_experimental/README.md @@ -6,13 +6,101 @@ run on systems with small amounts of memory such as microcontrollers and DSPs. This uses the experimental int8 quantized version of the person detection model. ## Table of contents + - [Getting started](#getting-started) +- [Running on ARC EM SDP](#running-on-arc-em-sdp) - [Running on Arduino](#running-on-arduino) - [Running on SparkFun Edge](#running-on-sparkfun-edge) - [Run the tests on a development machine](#run-the-tests-on-a-development-machine) - [Debugging image capture](#debugging-image-capture) - [Training your own model](#training-your-own-model) +## Running on ARC EM SDP + +The following instructions will help you to build and deploy this example to +[ARC EM SDP](https://www.synopsys.com/dw/ipdir.php?ds=arc-em-software-development-platform) +board. General information and instructions on using the board with TensorFlow +Lite Micro can be found in the common +[ARC targets description](/tensorflow/lite/micro/tools/make/targets/arc/README.md). + +This example uses asymmetric int8 quantization and can therefore leverage +optimized int8 kernels from the embARC MLI library + +The ARC EM SDP board contains a rich set of extension interfaces. You can choose +any compatible camera and modify +[image_provider.cc](/tensorflow/lite/micro/examples/person_detection_experimental/image_provider.cc) +file accordingly to use input from your specific camera. By default, results of +running this example are printed to the console. If you would like to instead +implement some target-specific actions, you need to modify +[detection_responder.cc](/tensorflow/lite/micro/examples/person_detection_experimental/detection_responder.cc) +accordingly. + +The reference implementations of these files are used by default on the EM SDP. + +### Initial setup + +Follow the instructions on the +[ARC EM SDP Initial Setup](/tensorflow/lite/micro/tools/make/targets/arc/README.md#ARC-EM-Software-Development-Platform-ARC-EM-SDP) +to get and install all required tools for work with ARC EM SDP. + +### Generate Example Project + +The example project for ARC EM SDP platform can be generated with the following +command: + +``` +make -f tensorflow/lite/micro/tools/make/Makefile TARGET=arc_emsdp generate_person_detection_int8_make_project +``` + +### Build and Run Example + +For more detailed information on building and running examples see the +appropriate sections of general descriptions of the +[ARC EM SDP usage with TFLM](/tensorflow/lite/micro/tools/make/targets/arc/README.md#ARC-EM-Software-Development-Platform-ARC-EM-SDP). +In the directory with generated project you can also find a +*README_ARC_EMSDP.md* file with instructions and options on building and +running. Here we only briefly mention main steps which are typically enough to +get it started. + +1. You need to + [connect the board](/tensorflow/lite/micro/tools/make/targets/arc/README.md#connect-the-board) + and open an serial connection. + +2. Go to the generated example project director + + ``` + cd tensorflow/lite/micro/tools/make/gen/arc_emsdp_arc/prj/person_detection_int8/make + ``` + +3. Build the example using + + ``` + make app + ``` + +4. To generate artefacts for self-boot of example from the board use + + ``` + make flash + ``` + +5. To run application from the board using microSD card: + + * Copy the content of the created /bin folder into the root of microSD + card. Note that the card must be formatted as FAT32 with default cluster + size (but less than 32 Kbytes) + * Plug in the microSD card into the J11 connector. + * Push the RST button. If a red LED is lit beside RST button, push the CFG + button. + +6. If you have the MetaWare Debugger installed in your environment: + + * To run application from the console using it type `make run`. + * To stop the execution type `Ctrl+C` in the console several times. + +In both cases (step 5 and 6) you will see the application output in the serial +terminal. + ## Running on Arduino The following instructions will help you build and deploy this sample diff --git a/tensorflow/lite/micro/examples/person_detection_experimental/arc_emsdp/Makefile.inc b/tensorflow/lite/micro/examples/person_detection_experimental/arc_emsdp/Makefile.inc new file mode 100644 index 00000000000..c00f9b89953 --- /dev/null +++ b/tensorflow/lite/micro/examples/person_detection_experimental/arc_emsdp/Makefile.inc @@ -0,0 +1,21 @@ +ifeq ($(TARGET), arc_emsdp) + +# Patch of arc make project to adjust it specifically +# for experimental person detection example. In particular: +# - Use Linker command file with better usage of fast memory +# - Stripout TFLM reference code by default. + + person_detection_HDRS += \ + person_detection_int8_patch.txt + + person_detection_TEST_HDRS += \ + person_detection_int8_patch.txt + + +%/person_detection_int8_patch.txt: %/emsdp.lcf %/Makefile + @cp tensorflow/lite/micro/examples/person_detection_experimental/arc_emsdp/emsdp.lcf $< + @echo emsdp.lcf > $@ + @sed -E -i 's#MLI_ONLY *\?= *false#MLI_ONLY \?= true#' $(word 2, $^) + @echo Makefile > $@ + +endif diff --git a/tensorflow/lite/micro/examples/person_detection_experimental/arc_emsdp/emsdp.lcf b/tensorflow/lite/micro/examples/person_detection_experimental/arc_emsdp/emsdp.lcf new file mode 100644 index 00000000000..c4150930d2b --- /dev/null +++ b/tensorflow/lite/micro/examples/person_detection_experimental/arc_emsdp/emsdp.lcf @@ -0,0 +1,74 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Difference with common EMSDP LCF file (to reduce data access time): +# - move data from external PSRAM to on-chip memory +# - move text from SRAM to ICCM +# +# CCMWRAP memory regions indicate unusable portions of the address space +# due to CCM memory wrapping into upper addresses beyond its size + +MEMORY { + PSRAM : ORIGIN = 0x10000400, LENGTH = (0x01000000 >> 1) - 0x400 + SRAM : ORIGIN = 0x20000000, LENGTH = 0x00040000 + IVT : ORIGIN = 0x60000000, LENGTH = 0x400 + ICCM0 : ORIGIN = 0x60000400, LENGTH = (0x00020000 - 0x400) +# CCMWRAP0: ORIGIN = 0x60020000, LENGTH = 0x0ffe0000 + DCCM : ORIGIN = 0x80000000, LENGTH = 0x00020000 +# CCMWRAP1: ORIGIN = 0x80020000, LENGTH = 0x0ffe0000 + XCCM : ORIGIN = 0x90000000, LENGTH = 0x00004000 +# CCMWRAP2: ORIGIN = 0x90004000, LENGTH = 0x0fffc000 + YCCM : ORIGIN = 0xa0000000, LENGTH = 0x00004000 +# CCMWRAP3: ORIGIN = 0xa0004000, LENGTH = 0x0fffc000 + } + +SECTIONS { + + GROUP BLOCK(4) : { + .vectors (TEXT) SIZE(DEFINED _IVTSIZE?_IVTSIZE:756): {} = FILL(0xa5a5a5a5,4) + } > IVT + + GROUP BLOCK(4): { + .text? : { *('.text$crt*') } + * (TEXT): {} + * (LIT): {} + } > ICCM0 + + GROUP BLOCK(4): { + .rodata_in_data? : {} + } > PSRAM + + GROUP BLOCK(4): { + /* _SDA_BASE_ computed implicitly */ + .sdata?: {} + .sbss?: {} + * (DATA): {} + * (BSS): {} + .debug_log? : {} + } > SRAM + + GROUP BLOCK(4): { + .Zdata? : {} + .heap? ALIGN(4) SIZE(DEFINED _HEAPSIZE?_HEAPSIZE:8K): {} + .stack ALIGN(4) SIZE(DEFINED _STACKSIZE?_STACKSIZE:8K): {} + } > DCCM + + GROUP BLOCK(4): { + .Xdata? : {} + } > XCCM + + GROUP BLOCK(4): { + .Ydata? : {} + } > YCCM +} + + diff --git a/tensorflow/lite/micro/examples/person_detection_experimental/training_a_model.md b/tensorflow/lite/micro/examples/person_detection_experimental/training_a_model.md index 24067fc188f..beb743a2923 100644 --- a/tensorflow/lite/micro/examples/person_detection_experimental/training_a_model.md +++ b/tensorflow/lite/micro/examples/person_detection_experimental/training_a_model.md @@ -372,6 +372,9 @@ tf.lite.TFLiteConverter.from_frozen_graph('vww_96_grayscale_frozen.pb', ['input'], ['MobilenetV1/Predictions/Reshape_1']) converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.representative_dataset = representative_dataset_gen +converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] +converter.inference_input_type = tf.int8 +converter.inference_output_type = tf.int8 tflite_quant_model = converter.convert() open("vww_96_grayscale_quantized.tflite", "wb").write(tflite_quant_model) diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 50a0a4f9190..b6c6054d604 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -201,7 +201,7 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:debug_log", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -214,7 +214,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -228,7 +227,6 @@ tflite_micro_cc_test( ":all_ops_resolver", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels/internal:tensor", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -242,7 +240,6 @@ tflite_micro_cc_test( ":portable_optimized_ops_resolver", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels/internal:tensor", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -269,7 +266,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -282,7 +278,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -295,7 +290,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -308,7 +302,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:micro_utils", "//tensorflow/lite/micro/testing:micro_test", ], @@ -322,7 +315,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -335,7 +327,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -348,7 +339,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -361,7 +351,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -374,7 +363,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -387,7 +375,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -400,7 +387,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -412,9 +398,7 @@ tflite_micro_cc_test( ], deps = [ ":all_ops_resolver", - ":micro_utils", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -426,9 +410,7 @@ tflite_micro_cc_test( ], deps = [ ":all_ops_resolver", - ":micro_utils", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -441,7 +423,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -454,7 +435,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -467,7 +447,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -480,7 +459,7 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:debug_log", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -493,7 +472,7 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:debug_log", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -506,7 +485,7 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:debug_log", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -552,7 +531,6 @@ tflite_micro_cc_test( ], deps = [ ":all_ops_resolver", - ":micro_utils", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", @@ -566,7 +544,6 @@ tflite_micro_cc_test( ], deps = [ ":all_ops_resolver", - ":micro_utils", "//tensorflow/lite/c:common", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", @@ -602,7 +579,6 @@ tflite_micro_cc_test( deps = [ ":all_ops_resolver", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -614,9 +590,7 @@ tflite_micro_cc_test( ], deps = [ ":all_ops_resolver", - ":micro_utils", "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro/testing:micro_test", ], ) @@ -666,7 +640,6 @@ tflite_micro_cc_test( ], deps = [ ":all_ops_resolver", - ":micro_ops", "//tensorflow/lite/c:common", "//tensorflow/lite/micro/testing:micro_test", ], @@ -679,7 +652,6 @@ tflite_micro_cc_test( ], deps = [ ":all_ops_resolver", - ":micro_ops", "//tensorflow/lite/c:common", "//tensorflow/lite/micro/testing:micro_test", ], diff --git a/tensorflow/lite/micro/kernels/arc/conv.cc b/tensorflow/lite/micro/kernels/arc/conv.cc deleted file mode 100644 index 69542e12e90..00000000000 --- a/tensorflow/lite/micro/kernels/arc/conv.cc +++ /dev/null @@ -1,343 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/kernels/internal/reference/conv.h" - -#include "mli_api.h" // NOLINT -#include "tensorflow/lite/c/builtin_op_data.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/kernels/padding.h" -#include "tensorflow/lite/micro/kernels/arc/mli_tf_utils.h" - -namespace tflite { -namespace ops { -namespace micro { -namespace conv { - -constexpr int kInputTensor = 0; -constexpr int kFilterTensor = 1; -constexpr int kBiasTensor = 2; -constexpr int kOutputTensor = 0; -constexpr int kMaxChannels = 256; - -// This file has 2 implementation of Conv. - -const int kTensorNotAllocated = -1; - -struct OpData { - TfLitePaddingValues padding; - // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multiplier plus a left shift. - int32_t output_multiplier; - int output_shift; - - // Per channel output multiplier and shift. - // TODO(b/141139247): Allocate these dynamically when possible. - int32_t per_channel_output_multiplier[kMaxChannels]; - int32_t per_channel_output_shift[kMaxChannels]; - - // The range of the fused activation layer. For example for kNone and - // uint8_t these would be 0 and 255. - int32_t output_activation_min; - int32_t output_activation_max; -}; - -inline PaddingType RuntimePaddingType(TfLitePadding padding) { - switch (padding) { - case TfLitePadding::kTfLitePaddingSame: - return PaddingType::kSame; - case TfLitePadding::kTfLitePaddingValid: - return PaddingType::kValid; - case TfLitePadding::kTfLitePaddingUnknown: - default: - return PaddingType::kNone; - } -} - -TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, - TfLiteConvParams* params, int width, int height, - int filter_width, int filter_height, int out_width, - int out_height, const TfLiteType data_type, - OpData* data) { - bool has_bias = node->inputs->size == 3; - // Check number of inputs/outputs - TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); - - // Matching GetWindowedOutputSize in TensorFlow. - auto padding = params->padding; - data->padding = ComputePaddingHeightWidth( - params->stride_height, params->stride_width, - params->dilation_height_factor, params->dilation_width_factor, height, - width, filter_height, filter_width, padding, &out_height, &out_width); - - // Note that quantized inference requires that all tensors have their - // parameters set. This is usually done during quantized training. - if (data_type != kTfLiteFloat32) { - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); - const TfLiteTensor* bias = - GetOptionalInputTensor(context, node, kBiasTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - - TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams( - context, input, filter, bias, output, params->activation, - &data->output_multiplier, &data->output_shift, - &data->output_activation_min, &data->output_activation_max, - data->per_channel_output_multiplier, - reinterpret_cast<int*>(data->per_channel_output_shift))); - } - return kTfLiteOk; -} - -void EvalQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteConvParams* params, OpData* data, - const TfLiteTensor* input, const TfLiteTensor* filter, - const TfLiteTensor* bias, TfLiteTensor* im2col, - TfLiteTensor* hwcn_weights, TfLiteTensor* output) { - const int32_t input_offset = -input->params.zero_point; - const int32_t filter_offset = -filter->params.zero_point; - const int32_t output_offset = output->params.zero_point; - - ConvParams op_params; - op_params.padding_type = RuntimePaddingType(params->padding); - op_params.padding_values.width = data->padding.width; - op_params.padding_values.height = data->padding.height; - op_params.stride_width = params->stride_width; - op_params.stride_height = params->stride_height; - op_params.dilation_width_factor = params->dilation_width_factor; - op_params.dilation_height_factor = params->dilation_height_factor; - op_params.input_offset = input_offset; - op_params.weights_offset = filter_offset; - op_params.output_offset = output_offset; - op_params.output_multiplier = data->output_multiplier; - op_params.output_shift = -data->output_shift; - op_params.quantized_activation_min = data->output_activation_min; - op_params.quantized_activation_max = data->output_activation_max; - reference_ops::Conv(op_params, GetTensorShape(input), - GetTensorData<uint8_t>(input), GetTensorShape(filter), - GetTensorData<uint8_t>(filter), GetTensorShape(bias), - GetTensorData<int32_t>(bias), GetTensorShape(output), - GetTensorData<uint8_t>(output), GetTensorShape(im2col), - GetTensorData<uint8_t>(im2col), nullptr); -} - -void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, - TfLiteConvParams* params, OpData* data, - const TfLiteTensor* input, - const TfLiteTensor* filter, - const TfLiteTensor* bias, TfLiteTensor* output, - TfLiteTensor* im2col) { - // Run Conv MLI kernel - // MLI optimized version only supports int8 dataype and dilation factor of 1 - if ((input->type == kTfLiteInt8) && (params->dilation_width_factor == 1) && - (params->dilation_height_factor == 1)) { - mli_tensor mli_in = {0}; - mli_tensor mli_weights = {0}; - mli_tensor mli_bias = {0}; - mli_tensor mli_out = {0}; - mli_conv2d_cfg cfg = {}; - - // reuse space allocated for OpData parameters - mli_weights.el_params.asym.scale.pi16 = - (int16_t*)data->per_channel_output_multiplier; - mli_bias.el_params.asym.scale.pi16 = - (int16_t*)data->per_channel_output_shift; - - int16_t filter_zero_point = 0; - int16_t bias_zero_point = 0; - mli_weights.el_params.asym.zero_point.pi16 = &filter_zero_point; - mli_bias.el_params.asym.zero_point.pi16 = &bias_zero_point; - - ConvertToMliTensor<int8_t>(input, &mli_in); - ConvertToMliTensorPerChannel<int8_t>(filter, &mli_weights); - ConvertToMliTensorPerChannel<int32_t>(bias, &mli_bias); - ConvertToMliTensor<int8_t>(output, &mli_out); - - if (params->activation == kTfLiteActRelu) { - cfg.relu.type = MLI_RELU_GEN; - } else if (params->activation == kTfLiteActRelu6) { - cfg.relu.type = MLI_RELU_6; - } else if (params->activation == kTfLiteActRelu1) { - cfg.relu.type = MLI_RELU_1; - } else { - cfg.relu.type = MLI_RELU_NONE; - } - - cfg.stride_width = params->stride_width; - cfg.stride_height = params->stride_height; - if (params->padding == kTfLitePaddingValid) { - cfg.padding_left = 0; - cfg.padding_right = 0; - cfg.padding_top = 0; - cfg.padding_bottom = 0; - } else { - cfg.padding_left = data->padding.width; - cfg.padding_right = data->padding.width + data->padding.width_offset; - cfg.padding_top = data->padding.height; - cfg.padding_bottom = data->padding.height + data->padding.height_offset; - } - - mli_point_to_subtsr_cfg substr_cfg_in = { - {0, 0}, 2, static_cast<uint8_t>(mli_in.shape[1])}; - mli_point_to_subtsr_cfg substr_cfg_out = { - {0, 0}, 2, static_cast<uint8_t>(mli_out.shape[1])}; - mli_tensor sub_mli_in = {0}; - mli_tensor sub_mli_out = {0}; - - const int batches = - MatchingDim(GetTensorShape(input), 0, GetTensorShape(output), 0); - - for (int i = 0; i < batches; i++) { - substr_cfg_in.start_coord[0] = i; - substr_cfg_out.start_coord[0] = i; - mli_hlp_point_to_subtensor(&mli_in, &substr_cfg_in, &sub_mli_in); - mli_hlp_point_to_subtensor(&mli_out, &substr_cfg_out, &sub_mli_out); - - mli_krn_conv2d_hwc_sa8_sa8_sa32(&sub_mli_in, &mli_weights, &mli_bias, - &cfg, &sub_mli_out); - } - } else { - ConvParams op_params; - op_params.input_offset = -input->params.zero_point; - op_params.output_offset = output->params.zero_point; - op_params.stride_height = params->stride_height; - op_params.stride_width = params->stride_width; - op_params.dilation_height_factor = params->dilation_height_factor; - op_params.dilation_width_factor = params->dilation_width_factor; - op_params.padding_values.height = data->padding.height; - op_params.padding_values.width = data->padding.width; - - reference_integer_ops::ConvPerChannel( - op_params, data->per_channel_output_multiplier, - data->per_channel_output_shift, GetTensorShape(input), - GetTensorData<int8>(input), GetTensorShape(filter), - GetTensorData<int8>(filter), GetTensorShape(bias), - GetTensorData<int32>(bias), GetTensorShape(output), - GetTensorData<int8>(output)); - } -} - -void EvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteConvParams* params, OpData* data, - const TfLiteTensor* input, const TfLiteTensor* filter, - const TfLiteTensor* bias, TfLiteTensor* im2col, - TfLiteTensor* hwcn_weights, TfLiteTensor* output) { - float output_activation_min, output_activation_max; - CalculateActivationRange(params->activation, &output_activation_min, - &output_activation_max); - - ConvParams op_params; - op_params.padding_type = RuntimePaddingType(params->padding); - op_params.padding_values.width = data->padding.width; - op_params.padding_values.height = data->padding.height; - op_params.stride_width = params->stride_width; - op_params.stride_height = params->stride_height; - op_params.dilation_width_factor = params->dilation_width_factor; - op_params.dilation_height_factor = params->dilation_height_factor; - op_params.float_activation_min = output_activation_min; - op_params.float_activation_max = output_activation_max; - - reference_ops::Conv(op_params, GetTensorShape(input), - GetTensorData<float>(input), GetTensorShape(filter), - GetTensorData<float>(filter), GetTensorShape(bias), - GetTensorData<float>(bias), GetTensorShape(output), - GetTensorData<float>(output), GetTensorShape(im2col), - GetTensorData<float>(im2col)); -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data); - - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); - - int input_width = input->dims->data[2]; - int input_height = input->dims->data[1]; - int filter_width = filter->dims->data[2]; - int filter_height = filter->dims->data[1]; - int output_width = output->dims->data[2]; - int output_height = output->dims->data[1]; - - OpData data; - - // All per-channel quantized tensors need valid zero point and scale arrays. - if (input->type == kTfLiteInt8) { - TF_LITE_ENSURE_EQ(context, filter->quantization.type, - kTfLiteAffineQuantization); - - const auto* affine_quantization = - reinterpret_cast<TfLiteAffineQuantization*>( - filter->quantization.params); - TF_LITE_ENSURE(context, affine_quantization); - TF_LITE_ENSURE(context, affine_quantization->scale); - TF_LITE_ENSURE(context, affine_quantization->zero_point); - // Conv is quantized along dimension 0: - // https://www.tensorflow.org/lite/performance/quantization_spec - TF_LITE_ENSURE_EQ(context, filter->dims->data[0], - affine_quantization->scale->size); - TF_LITE_ENSURE_EQ(context, filter->dims->data[0], - affine_quantization->zero_point->size); - } - - TF_LITE_ENSURE_STATUS(CalculateOpData( - context, node, params, input_width, input_height, filter_width, - filter_height, output_width, output_height, input->type, &data)); - - switch (input->type) { // Already know in/out types are same. - case kTfLiteFloat32: - EvalFloat(context, node, params, &data, input, filter, bias, nullptr, - nullptr, output); - break; - case kTfLiteInt8: - EvalQuantizedPerChannel(context, node, params, &data, input, filter, bias, - output, nullptr); - break; - case kTfLiteUInt8: - EvalQuantized(context, node, params, &data, input, filter, bias, nullptr, - nullptr, output); - break; - default: - TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", - TfLiteTypeGetName(input->type), input->type); - return kTfLiteError; - } - return kTfLiteOk; -} - -} // namespace conv - -TfLiteRegistration* Register_CONV_2D() { - static TfLiteRegistration r = {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/nullptr, - /*invoke=*/conv::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; - return &r; -} - -} // namespace micro -} // namespace ops -} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/arc/depthwise_conv.cc b/tensorflow/lite/micro/kernels/arc/depthwise_conv.cc deleted file mode 100644 index 6322414f5c6..00000000000 --- a/tensorflow/lite/micro/kernels/arc/depthwise_conv.cc +++ /dev/null @@ -1,344 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h" - -#include "mli_api.h" // NOLINT -#include "tensorflow/lite/c/builtin_op_data.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h" -#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/kernels/padding.h" -#include "tensorflow/lite/micro/kernels/arc/mli_tf_utils.h" - -namespace tflite { -namespace ops { -namespace micro { -namespace depthwise_conv { -namespace { - -constexpr int kInputTensor = 0; -constexpr int kFilterTensor = 1; -constexpr int kBiasTensor = 2; -constexpr int kOutputTensor = 0; -constexpr int kMaxChannels = 256; - -struct OpData { - TfLitePaddingValues padding; - // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multiplier plus a left shift. - int32_t output_multiplier; - int output_shift; - - // Per channel output multiplier and shift. - // TODO(b/141139247): Allocate these dynamically when possible. - int32_t per_channel_output_multiplier[kMaxChannels]; - int32_t per_channel_output_shift[kMaxChannels]; - - // The range of the fused activation layer. For example for kNone and - // uint8_t these would be 0 and 255. - int32_t output_activation_min; - int32_t output_activation_max; -}; - -TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, - TfLiteDepthwiseConvParams* params, int width, - int height, int filter_width, int filter_height, - const TfLiteType data_type, OpData* data) { - bool has_bias = node->inputs->size == 3; - // Check number of inputs/outputs - TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); - - int unused_output_height, unused_output_width; - data->padding = ComputePaddingHeightWidth( - params->stride_height, params->stride_width, 1, 1, height, width, - filter_height, filter_width, params->padding, &unused_output_height, - &unused_output_width); - - // Note that quantized inference requires that all tensors have their - // parameters set. This is usually done during quantized training. - if (data_type != kTfLiteFloat32) { - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); - const TfLiteTensor* bias = - GetOptionalInputTensor(context, node, kBiasTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - - // Ensure filter and bias channel count does not exceed space reserved for - // quantization metadata. - const auto filter_quantization = - reinterpret_cast<TfLiteAffineQuantization*>( - filter->quantization.params); - const auto bias_quantization = - reinterpret_cast<TfLiteAffineQuantization*>(bias->quantization.params); - TF_LITE_ENSURE(context, filter_quantization->scale->size <= kMaxChannels); - TF_LITE_ENSURE(context, bias_quantization->scale->size <= kMaxChannels); - - TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams( - context, input, filter, bias, output, params->activation, - &data->output_multiplier, &data->output_shift, - &data->output_activation_min, &data->output_activation_max, - data->per_channel_output_multiplier, - reinterpret_cast<int*>(data->per_channel_output_shift))); - } - return kTfLiteOk; -} - -} // namespace - -void EvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteDepthwiseConvParams* params, OpData* data, - const TfLiteTensor* input, const TfLiteTensor* filter, - const TfLiteTensor* bias, TfLiteTensor* output) { - float output_activation_min, output_activation_max; - CalculateActivationRange(params->activation, &output_activation_min, - &output_activation_max); - - tflite::DepthwiseParams op_params; - // Padding type is ignored, but still set. - op_params.padding_type = PaddingType::kSame; - op_params.padding_values.width = data->padding.width; - op_params.padding_values.height = data->padding.height; - op_params.stride_width = params->stride_width; - op_params.stride_height = params->stride_height; - op_params.dilation_width_factor = params->dilation_width_factor; - op_params.dilation_height_factor = params->dilation_height_factor; - op_params.depth_multiplier = params->depth_multiplier; - op_params.float_activation_min = output_activation_min; - op_params.float_activation_max = output_activation_max; - - tflite::reference_ops::DepthwiseConv( - op_params, GetTensorShape(input), GetTensorData<float>(input), - GetTensorShape(filter), GetTensorData<float>(filter), - GetTensorShape(bias), GetTensorData<float>(bias), GetTensorShape(output), - GetTensorData<float>(output)); -} - -void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, - TfLiteDepthwiseConvParams* params, OpData* data, - const TfLiteTensor* input, - const TfLiteTensor* filter, - const TfLiteTensor* bias, TfLiteTensor* output) { - // Run Depthwise Conv MLI kernel - // MLI optimized version only supports int8 dataype and dilation factor of 1 - if ((input->type == kTfLiteInt8) && (params->dilation_width_factor == 1) && - (params->dilation_height_factor == 1)) { - mli_tensor mli_in = {0}; - mli_tensor mli_weights = {0}; - mli_tensor mli_bias = {0}; - mli_tensor mli_out = {0}; - mli_conv2d_cfg cfg = {}; - - // reuse space allocated for OpData parameters - mli_weights.el_params.asym.scale.pi16 = - (int16_t*)data->per_channel_output_multiplier; - mli_bias.el_params.asym.scale.pi16 = - (int16_t*)data->per_channel_output_shift; - - int16_t filter_zero_point = 0; - int16_t bias_zero_point = 0; - mli_weights.el_params.asym.zero_point.pi16 = &filter_zero_point; - mli_bias.el_params.asym.zero_point.pi16 = &bias_zero_point; - - ConvertToMliTensor<int8_t>(input, &mli_in); - ConvertToMliTensorPerChannel<int8_t>(filter, &mli_weights); - ConvertToMliTensorPerChannel<int32_t>(bias, &mli_bias); - ConvertToMliTensor<int8_t>(output, &mli_out); - - if (params->activation == kTfLiteActRelu) { - cfg.relu.type = MLI_RELU_GEN; - } else if (params->activation == kTfLiteActRelu6) { - cfg.relu.type = MLI_RELU_6; - } else if (params->activation == kTfLiteActRelu1) { - cfg.relu.type = MLI_RELU_1; - } else { - cfg.relu.type = MLI_RELU_NONE; - } - - cfg.stride_width = params->stride_width; - cfg.stride_height = params->stride_height; - if (params->padding == kTfLitePaddingValid) { - cfg.padding_left = 0; - cfg.padding_right = 0; - cfg.padding_top = 0; - cfg.padding_bottom = 0; - } else { - cfg.padding_left = data->padding.width; - cfg.padding_right = data->padding.width + data->padding.width_offset; - cfg.padding_top = data->padding.height; - cfg.padding_bottom = data->padding.height + data->padding.height_offset; - } - - mli_point_to_subtsr_cfg substr_cfg_in = { - {0, 0}, 2, static_cast<uint8_t>(mli_in.shape[1])}; - mli_point_to_subtsr_cfg substr_cfg_out = { - {0, 0}, 2, static_cast<uint8_t>(mli_out.shape[1])}; - mli_tensor sub_mli_in = {0}; - mli_tensor sub_mli_out = {0}; - - const int batches = - MatchingDim(GetTensorShape(input), 0, GetTensorShape(output), 0); - - for (int i = 0; i < batches; i++) { - substr_cfg_in.start_coord[0] = i; - substr_cfg_out.start_coord[0] = i; - mli_hlp_point_to_subtensor(&mli_in, &substr_cfg_in, &sub_mli_in); - mli_hlp_point_to_subtensor(&mli_out, &substr_cfg_out, &sub_mli_out); - - mli_krn_depthwise_conv2d_hwc_sa8_sa8_sa32(&sub_mli_in, &mli_weights, - &mli_bias, &cfg, &sub_mli_out); - } - } else { - DepthwiseParams op_params; - op_params.padding_type = PaddingType::kSame; - op_params.padding_values.width = data->padding.width; - op_params.padding_values.height = data->padding.height; - op_params.stride_width = params->stride_width; - op_params.stride_height = params->stride_height; - op_params.dilation_width_factor = params->dilation_width_factor; - op_params.dilation_height_factor = params->dilation_height_factor; - op_params.depth_multiplier = params->depth_multiplier; - op_params.input_offset = -input->params.zero_point; - op_params.weights_offset = 0; - op_params.output_offset = output->params.zero_point; - // TODO(b/130439627): Use calculated value for clamping. - op_params.quantized_activation_min = std::numeric_limits<int8_t>::min(); - op_params.quantized_activation_max = std::numeric_limits<int8_t>::max(); - - reference_integer_ops::DepthwiseConvPerChannel( - op_params, data->per_channel_output_multiplier, - data->per_channel_output_shift, GetTensorShape(input), - GetTensorData<int8>(input), GetTensorShape(filter), - GetTensorData<int8>(filter), GetTensorShape(bias), - GetTensorData<int32>(bias), GetTensorShape(output), - GetTensorData<int8>(output)); - } -} - -void EvalQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteDepthwiseConvParams* params, OpData* data, - const TfLiteTensor* input, const TfLiteTensor* filter, - const TfLiteTensor* bias, TfLiteTensor* output) { - const int32_t input_offset = -input->params.zero_point; - const int32_t filter_offset = -filter->params.zero_point; - const int32_t output_offset = output->params.zero_point; - - tflite::DepthwiseParams op_params; - // Padding type is ignored, but still set. - op_params.padding_type = PaddingType::kSame; - op_params.padding_values.width = data->padding.width; - op_params.padding_values.height = data->padding.height; - op_params.stride_width = params->stride_width; - op_params.stride_height = params->stride_height; - op_params.dilation_width_factor = params->dilation_width_factor; - op_params.dilation_height_factor = params->dilation_height_factor; - op_params.depth_multiplier = params->depth_multiplier; - op_params.quantized_activation_min = data->output_activation_min; - op_params.quantized_activation_max = data->output_activation_max; - op_params.input_offset = input_offset; - op_params.weights_offset = filter_offset; - op_params.output_offset = output_offset; - op_params.output_multiplier = data->output_multiplier; - // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. - op_params.output_shift = -data->output_shift; - - tflite::reference_ops::DepthwiseConv( - op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), - GetTensorShape(filter), GetTensorData<uint8_t>(filter), - GetTensorShape(bias), GetTensorData<int32_t>(bias), - GetTensorShape(output), GetTensorData<uint8_t>(output)); -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = - reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data); - - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); - const TfLiteTensor* bias = - (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr; - - const TfLiteType data_type = input->type; - int width = SizeOfDimension(input, 2); - int height = SizeOfDimension(input, 1); - int filter_width = SizeOfDimension(filter, 2); - int filter_height = SizeOfDimension(filter, 1); - - OpData data; - - // All per-channel quantized tensors need valid zero point and scale arrays. - if (input->type == kTfLiteInt8) { - TF_LITE_ENSURE_EQ(context, filter->quantization.type, - kTfLiteAffineQuantization); - - const auto* affine_quantization = - reinterpret_cast<TfLiteAffineQuantization*>( - filter->quantization.params); - TF_LITE_ENSURE(context, affine_quantization); - TF_LITE_ENSURE(context, affine_quantization->scale); - TF_LITE_ENSURE(context, affine_quantization->zero_point); - // Depthwise conv is quantized along dimension 3: - // https://www.tensorflow.org/lite/performance/quantization_spec - TF_LITE_ENSURE_EQ(context, filter->dims->data[3], - affine_quantization->scale->size); - TF_LITE_ENSURE_EQ(context, filter->dims->data[3], - affine_quantization->zero_point->size); - } - - TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height, - filter_width, filter_height, data_type, - &data)); - switch (input->type) { // Already know in/out types are same. - case kTfLiteFloat32: - EvalFloat(context, node, params, &data, input, filter, bias, output); - break; - case kTfLiteInt8: - EvalQuantizedPerChannel(context, node, params, &data, input, filter, bias, - output); - break; - case kTfLiteUInt8: - EvalQuantized(context, node, params, &data, input, filter, bias, output); - break; - default: - TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", - TfLiteTypeGetName(input->type), input->type); - return kTfLiteError; - } - return kTfLiteOk; -} - -} // namespace depthwise_conv - -TfLiteRegistration* Register_DEPTHWISE_CONV_2D() { - static TfLiteRegistration r = {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/nullptr, - /*invoke=*/depthwise_conv::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; - return &r; -} - -} // namespace micro -} // namespace ops -} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/arc/fully_connected.cc b/tensorflow/lite/micro/kernels/arc/fully_connected.cc deleted file mode 100644 index 57203f10487..00000000000 --- a/tensorflow/lite/micro/kernels/arc/fully_connected.cc +++ /dev/null @@ -1,248 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" - -#include "mli_api.h" // NOLINT -#include "tensorflow/lite/c/builtin_op_data.h" -#include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/micro/kernels/arc/mli_tf_utils.h" - -namespace tflite { -namespace ops { -namespace micro { -namespace fully_connected { -namespace { - -struct OpData { - // The scaling factor from input to output (aka the 'real multiplier') can - // be represented as a fixed point multiplier plus a left shift. - int32_t output_multiplier; - int output_shift; - // The range of the fused activation layer. For example for kNone and - // uint8_t these would be 0 and 255. - int32_t output_activation_min; - int32_t output_activation_max; - // The index of the temporary tensor where the quantized inputs are cached. - int input_quantized_index; -}; - -constexpr int kInputTensor = 0; -constexpr int kWeightsTensor = 1; -constexpr int kBiasTensor = 2; -constexpr int kOutputTensor = 0; - -TfLiteStatus CalculateOpData(TfLiteContext* context, - TfLiteFullyConnectedParams* params, - TfLiteType data_type, const TfLiteTensor* input, - const TfLiteTensor* filter, - const TfLiteTensor* bias, TfLiteTensor* output, - OpData* data) { - TfLiteStatus status = kTfLiteOk; - if (data_type != kTfLiteFloat32) { - double real_multiplier = 0.0; - TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( - context, input, filter, bias, output, &real_multiplier)); - int exponent; - QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); - data->output_shift = -exponent; - TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( - context, params->activation, output, &data->output_activation_min, - &data->output_activation_max)); - } - return status; -} - -} // namespace - -TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, - TfLiteFullyConnectedParams* params, OpData* data, - const TfLiteTensor* input, - const TfLiteTensor* filter, - const TfLiteTensor* bias, TfLiteTensor* output) { - // Run Fully Connected MLI kernel - // MLI optimized version only supports int8 dataype and no fused Relu - // TODO: subject to add mli_saturate kernel - // work around for issue #35318, mli fully connect kernel only supports - // zeropoint == 0 for weights. this check can be removed once issue #35318 is - // resolved. - if ((filter->params.zero_point == 0) && - (input->type == kTfLiteInt8 && params->activation == kTfLiteActNone)) { - mli_tensor mli_in = {0}; - mli_tensor mli_weights = {0}; - mli_tensor mli_bias = {0}; - mli_tensor mli_out = {0}; - - ConvertToMliTensor<int8_t>(input, &mli_in); - ConvertToMliTensor<int8_t>(filter, &mli_weights); - ConvertToMliTensor<int32_t>(bias, &mli_bias); - ConvertToMliTensor<int8_t>(output, &mli_out); - - mli_point_to_subtsr_cfg substr_cfg_in = { - {0, 0}, 2, static_cast<uint8_t>(mli_in.shape[1])}; - mli_point_to_subtsr_cfg substr_cfg_out = { - {0, 0}, 2, static_cast<uint8_t>(mli_out.shape[1])}; - mli_tensor sub_mli_in = {0}; - mli_tensor sub_mli_out = {0}; - - const int batches = - MatchingDim(GetTensorShape(input), 0, GetTensorShape(output), 0); - - for (int i = 0; i < batches; i++) { - substr_cfg_in.start_coord[0] = i; - substr_cfg_out.start_coord[0] = i; - mli_hlp_point_to_subtensor(&mli_in, &substr_cfg_in, &sub_mli_in); - mli_hlp_point_to_subtensor(&mli_out, &substr_cfg_out, &sub_mli_out); - - mli_krn_fully_connected_sa8_sa8_sa32(&sub_mli_in, &mli_weights, &mli_bias, - &sub_mli_out); - } - } else { - FullyConnectedParams op_params; - op_params.input_offset = -input->params.zero_point; - op_params.weights_offset = -filter->params.zero_point; - op_params.output_offset = output->params.zero_point; - op_params.output_multiplier = data->output_multiplier; - // TODO(b/138810107): Figure out whether output shift should be inverted - op_params.output_shift = -data->output_shift; - op_params.quantized_activation_min = data->output_activation_min; - op_params.quantized_activation_max = data->output_activation_max; - - reference_integer_ops::FullyConnected( - op_params, GetTensorShape(input), GetTensorData<int8_t>(input), - GetTensorShape(filter), GetTensorData<int8_t>(filter), - GetTensorShape(bias), GetTensorData<int32_t>(bias), - GetTensorShape(output), GetTensorData<int8_t>(output)); - } - return kTfLiteOk; -} - -TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteFullyConnectedParams* params, OpData* data, - const TfLiteTensor* input, - const TfLiteTensor* filter, const TfLiteTensor* bias, - TfLiteTensor* output) { - const int32_t input_offset = -input->params.zero_point; - const int32_t filter_offset = -filter->params.zero_point; - const int32_t output_offset = output->params.zero_point; - - tflite::FullyConnectedParams op_params; - op_params.input_offset = input_offset; - op_params.weights_offset = filter_offset; - op_params.output_offset = output_offset; - op_params.output_multiplier = data->output_multiplier; - // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. - op_params.output_shift = -data->output_shift; - op_params.quantized_activation_min = data->output_activation_min; - op_params.quantized_activation_max = data->output_activation_max; - -#define TF_LITE_FULLY_CONNECTED(output_data_type) \ - reference_ops::FullyConnected( \ - op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \ - GetTensorShape(filter), GetTensorData<uint8_t>(filter), \ - GetTensorShape(bias), GetTensorData<int32_t>(bias), \ - GetTensorShape(output), GetTensorData<output_data_type>(output)) - switch (output->type) { - case kTfLiteUInt8: - TF_LITE_FULLY_CONNECTED(uint8_t); - break; - case kTfLiteInt16: - TF_LITE_FULLY_CONNECTED(int16_t); - break; - default: - TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", - TfLiteTypeGetName(output->type), output->type); - return kTfLiteError; - } - - return kTfLiteOk; -} - -TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteFullyConnectedParams* params, OpData* data, - const TfLiteTensor* input, const TfLiteTensor* filter, - const TfLiteTensor* bias, TfLiteTensor* output) { - float output_activation_min, output_activation_max; - CalculateActivationRange(params->activation, &output_activation_min, - &output_activation_max); - tflite::FullyConnectedParams op_params; - op_params.float_activation_min = output_activation_min; - op_params.float_activation_max = output_activation_max; - tflite::reference_ops::FullyConnected( - op_params, GetTensorShape(input), GetTensorData<float>(input), - GetTensorShape(filter), GetTensorData<float>(filter), - GetTensorShape(bias), GetTensorData<float>(bias), GetTensorShape(output), - GetTensorData<float>(output)); - return kTfLiteOk; -} - -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = - reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data); - - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); - const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - - TfLiteType data_type = input->type; - OpData local_data_object; - OpData* data = &local_data_object; - TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input, - filter, bias, output, data)); - - switch (filter->type) { // Already know in/out types are same. - case kTfLiteFloat32: - return EvalFloat(context, node, params, data, input, filter, bias, - output); - case kTfLiteInt8: - return EvalQuantizedInt8(context, node, params, data, input, filter, bias, - output); - - case kTfLiteUInt8: - return EvalQuantized(context, node, params, data, input, filter, bias, - output); - - default: - TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", - TfLiteTypeGetName(filter->type), filter->type); - return kTfLiteError; - } - return kTfLiteOk; -} - -} // namespace fully_connected - -TfLiteRegistration* Register_FULLY_CONNECTED() { - static TfLiteRegistration r = {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/nullptr, - /*invoke=*/fully_connected::Eval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; - - return &r; -} - -} // namespace micro -} // namespace ops -} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/arc/pooling.cc b/tensorflow/lite/micro/kernels/arc/pooling.cc deleted file mode 100644 index 55452013028..00000000000 --- a/tensorflow/lite/micro/kernels/arc/pooling.cc +++ /dev/null @@ -1,292 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/lite/kernels/internal/reference/pooling.h" - -#include "mli_api.h" // NOLINT -#include "tensorflow/lite/c/builtin_op_data.h" -#include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/kernels/padding.h" -#include "tensorflow/lite/micro/kernels/arc/mli_tf_utils.h" - -namespace tflite { -namespace ops { -namespace micro { -namespace pooling { - -namespace { - -constexpr int kInputTensor = 0; -constexpr int kOutputTensor = 0; - -struct OpData { - TfLitePaddingValues padding; -}; - -TfLiteStatus CalculateOpData(const TfLiteContext* context, - const TfLitePoolParams* params, - const TfLiteTensor* input, - const TfLiteTensor* output, OpData* data) { - // input: batch, height, width, channel - int height = SizeOfDimension(input, 1); - int width = SizeOfDimension(input, 2); - - int out_height, out_width; - - data->padding = ComputePaddingHeightWidth( - params->stride_height, params->stride_width, - /*dilation_rate_height=*/1, - /*dilation_rate_width=*/1, height, width, params->filter_height, - params->filter_width, params->padding, &out_height, &out_width); - - return kTfLiteOk; -} - -void AverageEvalFloat(const TfLiteContext* context, const TfLiteNode* node, - const TfLitePoolParams* params, const OpData* data, - const TfLiteTensor* input, TfLiteTensor* output) { - float activation_min, activation_max; - CalculateActivationRange(params->activation, &activation_min, - &activation_max); - - PoolParams op_params; - op_params.stride_height = params->stride_height; - op_params.stride_width = params->stride_width; - op_params.filter_height = params->filter_height; - op_params.filter_width = params->filter_width; - op_params.padding_values.height = data->padding.height; - op_params.padding_values.width = data->padding.width; - op_params.float_activation_min = activation_min; - op_params.float_activation_max = activation_max; - reference_ops::AveragePool( - op_params, GetTensorShape(input), GetTensorData<float>(input), - GetTensorShape(output), GetTensorData<float>(output)); -} - -void AverageEvalUint8(TfLiteContext* context, const TfLiteNode* node, - const TfLitePoolParams* params, const OpData* data, - const TfLiteTensor* input, TfLiteTensor* output) { - int32_t activation_min, activation_max; - (void)CalculateActivationRangeQuantized(context, params->activation, output, - &activation_min, &activation_max); - - PoolParams op_params; - op_params.stride_height = params->stride_height; - op_params.stride_width = params->stride_width; - op_params.filter_height = params->filter_height; - op_params.filter_width = params->filter_width; - op_params.padding_values.height = data->padding.height; - op_params.padding_values.width = data->padding.width; - op_params.quantized_activation_min = activation_min; - op_params.quantized_activation_max = activation_max; - reference_ops::AveragePool( - op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), - GetTensorShape(output), GetTensorData<uint8_t>(output)); -} - -void AverageEvalInt8(TfLiteContext* context, const TfLiteNode* node, - const TfLitePoolParams* params, const OpData* data, - const TfLiteTensor* input, TfLiteTensor* output) { - // Run Average Pooling MLI kernel - // MLI optimized version only supports int8 dataype and no fused Relu - // TODO: subject to add mli_saturate kernel - if (input->type == kTfLiteInt8 && params->activation == kTfLiteActNone) { - mli_tensor mli_in = {0}; - mli_tensor mli_out = {0}; - mli_pool_cfg cfg = {0}; - - ConvertToMliTensor<int8_t>(input, &mli_in); - ConvertToMliTensor<int8_t>(output, &mli_out); - - cfg.kernel_width = params->filter_width; - cfg.kernel_height = params->filter_height; - cfg.stride_width = params->stride_width; - cfg.stride_height = params->stride_height; - - if (params->padding == kTfLitePaddingValid) { - cfg.padding_left = 0; - cfg.padding_right = 0; - cfg.padding_top = 0; - cfg.padding_bottom = 0; - } else { - cfg.padding_left = data->padding.width; - cfg.padding_right = data->padding.width + data->padding.width_offset; - cfg.padding_top = data->padding.height; - cfg.padding_bottom = data->padding.height + data->padding.height_offset; - } - - mli_point_to_subtsr_cfg substr_cfg_in = { - {0, 0}, 2, static_cast<uint8_t>(mli_in.shape[1])}; - mli_point_to_subtsr_cfg substr_cfg_out = { - {0, 0}, 2, static_cast<uint8_t>(mli_out.shape[1])}; - mli_tensor sub_mli_in = {0}; - mli_tensor sub_mli_out = {0}; - - const int batches = - MatchingDim(GetTensorShape(input), 0, GetTensorShape(output), 0); - - for (int i = 0; i < batches; i++) { - substr_cfg_in.start_coord[0] = i; - substr_cfg_out.start_coord[0] = i; - mli_hlp_point_to_subtensor(&mli_in, &substr_cfg_in, &sub_mli_in); - mli_hlp_point_to_subtensor(&mli_out, &substr_cfg_out, &sub_mli_out); - - mli_krn_avepool_hwc_sa8(&sub_mli_in, &cfg, &sub_mli_out); - } - } else { - int32_t activation_min, activation_max; - (void)CalculateActivationRangeQuantized(context, params->activation, output, - &activation_min, &activation_max); - PoolParams op_params; - op_params.stride_height = params->stride_height; - op_params.stride_width = params->stride_width; - op_params.filter_height = params->filter_height; - op_params.filter_width = params->filter_width; - op_params.padding_values.height = data->padding.height; - op_params.padding_values.width = data->padding.width; - op_params.quantized_activation_min = activation_min; - op_params.quantized_activation_max = activation_max; - reference_integer_ops::AveragePool( - op_params, GetTensorShape(input), GetTensorData<int8_t>(input), - GetTensorShape(output), GetTensorData<int8_t>(output)); - } -} - -void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLitePoolParams* params, OpData* data, - const TfLiteTensor* input, TfLiteTensor* output) { - float activation_min, activation_max; - CalculateActivationRange(params->activation, &activation_min, - &activation_max); - - tflite::PoolParams op_params; - op_params.stride_height = params->stride_height; - op_params.stride_width = params->stride_width; - op_params.filter_height = params->filter_height; - op_params.filter_width = params->filter_width; - op_params.padding_values.height = data->padding.height; - op_params.padding_values.width = data->padding.width; - op_params.float_activation_min = activation_min; - op_params.float_activation_max = activation_max; - reference_ops::MaxPool(op_params, GetTensorShape(input), - GetTensorData<float>(input), GetTensorShape(output), - GetTensorData<float>(output)); -} - -void MaxEvalQuantizedUInt8(TfLiteContext* context, TfLiteNode* node, - TfLitePoolParams* params, OpData* data, - const TfLiteTensor* input, TfLiteTensor* output) { - int32_t activation_min, activation_max; - (void)CalculateActivationRangeQuantized(context, params->activation, output, - &activation_min, &activation_max); - - tflite::PoolParams op_params; - op_params.stride_height = params->stride_height; - op_params.stride_width = params->stride_width; - op_params.filter_height = params->filter_height; - op_params.filter_width = params->filter_width; - op_params.padding_values.height = data->padding.height; - op_params.padding_values.width = data->padding.width; - op_params.quantized_activation_min = activation_min; - op_params.quantized_activation_max = activation_max; - reference_ops::MaxPool(op_params, GetTensorShape(input), - GetTensorData<uint8_t>(input), GetTensorShape(output), - GetTensorData<uint8_t>(output)); -} - -} // namespace - -TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data); - OpData data; - - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - - TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data)); - - // Inputs and outputs share the same type, guarenteed by the converter. - switch (input->type) { - case kTfLiteFloat32: - AverageEvalFloat(context, node, params, &data, input, output); - break; - case kTfLiteUInt8: - AverageEvalUint8(context, node, params, &data, input, output); - break; - case kTfLiteInt8: - AverageEvalInt8(context, node, params, &data, input, output); - break; - default: - TF_LITE_KERNEL_LOG(context, "Input type %s is not currently supported", - TfLiteTypeGetName(input->type)); - return kTfLiteError; - } - return kTfLiteOk; -} - -TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data); - OpData data; - - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - - TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data)); - - switch (input->type) { - case kTfLiteFloat32: - MaxEvalFloat(context, node, params, &data, input, output); - break; - case kTfLiteUInt8: - MaxEvalQuantizedUInt8(context, node, params, &data, input, output); - break; - default: - TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.", - TfLiteTypeGetName(input->type)); - return kTfLiteError; - } - return kTfLiteOk; -} - -} // namespace pooling - -TfLiteRegistration* Register_AVERAGE_POOL_2D() { - static TfLiteRegistration r = {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/nullptr, - /*invoke=*/pooling::AverageEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; - return &r; -} - -TfLiteRegistration* Register_MAX_POOL_2D() { - static TfLiteRegistration r = {/*init=*/nullptr, - /*free=*/nullptr, - /*prepare=*/nullptr, - /*invoke=*/pooling::MaxEval, - /*profiling_string=*/nullptr, - /*builtin_code=*/0, - /*custom_name=*/nullptr, - /*version=*/0}; - return &r; -} - -} // namespace micro -} // namespace ops -} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/arc_mli/README.md b/tensorflow/lite/micro/kernels/arc_mli/README.md new file mode 100644 index 00000000000..9bd0085f373 --- /dev/null +++ b/tensorflow/lite/micro/kernels/arc_mli/README.md @@ -0,0 +1,96 @@ +# EmbARC MLI Library Based Optimizations of TensorFlow Lite Micro Kernels for ARC Platforms. + +This folder contains kernel implementations which use optimized +[embARC MLI Library](https://github.com/foss-for-synopsys-dwc-arc-processors/embarc_mli). +It allows acceleration of inference operations which use int8 (asymmetric +quantization). + +## Usage + +embARC MLI Library is used by default to speed up execution of some kernels for +asymmetrically quantized layers. This means that usual project generation for +ARC specific target implies usage of embARC MLI. + +For example: + +``` +make -f tensorflow/lite/micro/tools/make/Makefile TARGET=arc_emsdp generate_person_detection_int8_make_project +``` + +In case MLI implementation can’t be used, kernels in this folder fallback to +TFLM reference implementations. For applications which may not benefit from MLI +library, projects can be generated without these implementations by adding +`TAGS=no_arc_mli` in the command line, which can reduce overall code size: + +``` +make -f tensorflow/lite/micro/tools/make/Makefile TARGET=arc_emsdp TAGS=no_arc_mli generate_person_detection_int8_make_project +``` + +For ARC EM SDP board, a pre-compiled MLI library is downloaded and used in the +application. For a custom target ARC-based platform, MLI sources are downloaded +and compiled during project generation phase. To build library from sources for +ARC EM SDP platform, add `BUILD_ARC_MLI=true` option to make command: + +``` +make -f tensorflow/lite/micro/tools/make/Makefile TARGET=arc_emsdp BUILD_ARC_MLI=true generate_person_detection_int8_make_project +``` + +If an application exclusively uses accelerated MLI kernel implementations, one +can strip out TFLM reference kernel implementations to reduce code size of +application. Build application with `MLI_ONLY=true` option in generated project +(after the project was built): + +``` +cd tensorflow/lite/micro/tools/make/gen/arc_emsdp_arc/prj/person_detection_int8/make + +make app MLI_ONLY=true +``` + +if you try this and application execution fails, then most probably MLI can’t be +used for some nodes and you need to revert to using TFLM reference kernels. + +## Limitations + +Currently, the MLI Library provides optimized implementation only for int8 +(asymmetric) versions of the following kernels: 1. Convolution 2D – Per axis +quantization only, `dilation_ratio==1` 2. Depthwise Convolution 2D – Per axis +quantization only, `dilation_ratio==1` 3. Average Pooling 4. Max Pooling 5. +Fully Connected + +Currently only +[/tensorflow/lite/micro/examples/person_detection_experimental](/tensorflow/lite/micro/examples/person_detection_experimental) +is quantized using this specification. Other examples can be executed on +ARC-based targets, but will only use reference kernels. + +## Scratch Buffers and Slicing + +The following information applies only for ARC EM SDP and other targets with XY +memory. embARC MLI uses specific optimizations which assumes node operands are +in XY memory and/or DCCM (Data Closely Coupled Memory). As operands might be +quite big and may not fit in available XY memory, special slicing logic is +applied which allows kernel calculations to be split into multiple parts. For +this reason, internal static buffers are allocated in these X, Y and DCCM memory +banks and used to execute sub-calculations. + +All this is performed automatically and invisible to the user. Half of the DCCM +memory bank and the full XY banks are occupied for MLI specific needs. If the +user needs space in XY memory for other tasks, these arrays can be reduced by +setting specific sizes. For this, add the following option to build command +replacing **<size[a|b|c]>** with required values: + +``` +EXT_CFLAGS=”-DSCRATCH_MEM_Z_SIZE=<size_a> -DSCRATCH_MEM_X_SIZE=<size_b> -DSCRATCH_MEM_Y_SIZE=<size_c>” +``` + +For example, to reduce sizes of arrays placed in DCCM and XCCM to 32k and 8k +respectively, use next command: + +``` +make app EXT_CFLAGS=”-DSCRATCH_MEM_Z_SIZE=32*1024 -DSCRATCH_MEM_X_SIZE=8*1024” +``` + +## License + +TensorFlow's code is covered by the Apache2 License included in the repository, +and third party dependencies are covered by their respective licenses, in the +third_party folder of this package. diff --git a/tensorflow/lite/micro/kernels/arc_mli/conv.cc b/tensorflow/lite/micro/kernels/arc_mli/conv.cc new file mode 100644 index 00000000000..66880b732cc --- /dev/null +++ b/tensorflow/lite/micro/kernels/arc_mli/conv.cc @@ -0,0 +1,490 @@ +/* Copyright 2019-2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/kernels/internal/reference/conv.h" + +#include "mli_api.h" // NOLINT +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/padding.h" +#include "tensorflow/lite/micro/kernels/arc_mli/mli_slicers.h" +#include "tensorflow/lite/micro/kernels/arc_mli/mli_tf_utils.h" +#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.h" +#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace conv { + +constexpr int kInputTensor = 0; +constexpr int kFilterTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; +constexpr int kMaxChannels = 256; + +// Conv is quantized along dimension 0: +// https://www.tensorflow.org/lite/performance/quantization_spec +constexpr int kConvQuantizedDimension = 0; + +struct OpData { + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + + // Per channel output multiplier and shift. + int32_t per_channel_output_multiplier[kMaxChannels]; + int32_t per_channel_output_shift[kMaxChannels]; + + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; + +inline PaddingType RuntimePaddingType(TfLitePadding padding) { + switch (padding) { + case TfLitePadding::kTfLitePaddingSame: + return PaddingType::kSame; + case TfLitePadding::kTfLitePaddingValid: + return PaddingType::kValid; + case TfLitePadding::kTfLitePaddingUnknown: + default: + return PaddingType::kNone; + } +} + +bool IsMliApplicable(TfLiteContext* context, const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, + const TfLiteConvParams* params) { + const auto* affine_quantization = + reinterpret_cast<TfLiteAffineQuantization*>(filter->quantization.params); + // MLI optimized version only supports int8 dataype, dilation factor of 1 and + // per-axis quantization of weights (no broadcasting/per-tensor) + bool ret_val = (filter->type == kTfLiteInt8) && + (input->type == kTfLiteInt8) && (bias->type == kTfLiteInt32) && + (params->dilation_width_factor == 1) && + (params->dilation_height_factor == 1) && + (affine_quantization->scale->size == + filter->dims->data[kConvQuantizedDimension]) && + affine_quantization->scale->size <= (kMaxChannels * 2); + return ret_val; +} + +TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, int width, int height, + int filter_width, int filter_height, int out_width, + int out_height, const TfLiteType data_type, + bool mli_is_applicable, OpData* data) { + bool has_bias = node->inputs->size == 3; + // Check number of inputs/outputs + TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + + // Matching GetWindowedOutputSize in TensorFlow. + auto padding = params->padding; + data->padding = ComputePaddingHeightWidth( + params->stride_height, params->stride_width, + params->dilation_height_factor, params->dilation_width_factor, height, + width, filter_height, filter_width, padding, &out_height, &out_width); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + if (data_type != kTfLiteFloat32 && !mli_is_applicable) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = + GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + int output_channels = filter->dims->data[kConvQuantizedDimension]; + + TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams( + context, input, filter, bias, output, params->activation, + &data->output_multiplier, &data->output_shift, + &data->output_activation_min, &data->output_activation_max, + data->per_channel_output_multiplier, + reinterpret_cast<int*>(data->per_channel_output_shift), + output_channels)); + } +#endif + return kTfLiteOk; +} + +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, + TfLiteTensor* im2col, TfLiteTensor* hwcn_weights, + TfLiteTensor* output) { +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + const int32_t input_offset = -input->params.zero_point; + const int32_t filter_offset = -filter->params.zero_point; + const int32_t output_offset = output->params.zero_point; + + ConvParams op_params; + op_params.padding_type = RuntimePaddingType(params->padding); + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + op_params.output_shift = -data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + reference_ops::Conv(op_params, GetTensorShape(input), + GetTensorData<uint8_t>(input), GetTensorShape(filter), + GetTensorData<uint8_t>(filter), GetTensorShape(bias), + GetTensorData<int32_t>(bias), GetTensorShape(output), + GetTensorData<uint8_t>(output), GetTensorShape(im2col), + GetTensorData<uint8_t>(im2col), nullptr); + return kTfLiteOk; +#else + TF_LITE_KERNEL_LOG(context, + "Type %s (%d) is not supported by ARC MLI Library.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; +#endif +} + +TfLiteStatus EvalMliQuantizedPerChannel( + TfLiteContext* context, TfLiteNode* node, TfLiteConvParams* params, + OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + // Run Conv MLI kernel + // MLI optimized version only supports int8 dataype and dilation factor of 1 + if ((input->type == kTfLiteInt8) && (params->dilation_width_factor == 1) && + (params->dilation_height_factor == 1)) { + mli_tensor mli_in = {0}; + mli_tensor mli_weights = {0}; + mli_tensor mli_bias = {0}; + mli_tensor mli_out = {0}; + mli_conv2d_cfg cfg = {}; + + // reuse space allocated for OpData parameters + mli_weights.el_params.asym.scale.pi16 = + (int16_t*)data->per_channel_output_multiplier; + mli_bias.el_params.asym.scale.pi16 = + (int16_t*)data->per_channel_output_shift; + + int16_t filter_zero_point = 0; + int16_t bias_zero_point = 0; + mli_weights.el_params.asym.zero_point.pi16 = &filter_zero_point; + mli_bias.el_params.asym.zero_point.pi16 = &bias_zero_point; + + ConvertToMliTensor<int8_t>(input, &mli_in); + ConvertToMliTensorPerChannel<int8_t>(filter, &mli_weights); + ConvertToMliTensorPerChannel<int32_t>(bias, &mli_bias); + ConvertToMliTensor<int8_t>(output, &mli_out); + + if (params->activation == kTfLiteActRelu) { + cfg.relu.type = MLI_RELU_GEN; + } else if (params->activation == kTfLiteActRelu6) { + cfg.relu.type = MLI_RELU_6; + } else if (params->activation == kTfLiteActRelu1) { + cfg.relu.type = MLI_RELU_1; + } else { + cfg.relu.type = MLI_RELU_NONE; + } + + cfg.stride_width = params->stride_width; + cfg.stride_height = params->stride_height; + if (params->padding == kTfLitePaddingValid) { + cfg.padding_left = 0; + cfg.padding_right = 0; + cfg.padding_top = 0; + cfg.padding_bottom = 0; + } else { + cfg.padding_left = data->padding.width; + cfg.padding_right = data->padding.width + data->padding.width_offset; + cfg.padding_top = data->padding.height; + cfg.padding_bottom = data->padding.height + data->padding.height_offset; + } + + // for height slicing + const int height_dimension = 1; + int in_slice_height = 0; + int out_slice_height = 0; + const int kernel_height = + static_cast<int>(mli_weights.shape[KRNL_H_DIM_HWC]); + const int overlap = kernel_height - cfg.stride_height; + + // for weight slicing (on output channels) + // NHWC layout for weigths, output channel dimension is the first dimension. + const int weight_out_ch_dimension = 0; + int slice_channels = + static_cast<int>(mli_weights.shape[weight_out_ch_dimension]); + // Batch-Height-Width-Channel layout means last dimension is output + // channels. + const int out_tensor_ch_dimension = 3; + + // Tensors for data in fast (local) memory and config to copy data from + // external to local memory + mli_tensor weights_local = mli_weights; + mli_tensor bias_local = mli_bias; + mli_tensor in_local = mli_in; + mli_tensor out_local = mli_out; + mli_mov_cfg_t copy_config; + mli_mov_cfg_for_copy(©_config); + TF_LITE_ENSURE_STATUS(get_arc_scratch_buffer_for_conv_tensors( + context, &in_local, &weights_local, &bias_local, &out_local)); + TF_LITE_ENSURE_STATUS(arc_scratch_buffer_calc_slice_size_io( + &in_local, &out_local, kernel_height, cfg.stride_height, + cfg.padding_top, cfg.padding_bottom, &in_slice_height, + &out_slice_height)); + TF_LITE_ENSURE_STATUS(arc_scratch_buffer_calc_slice_size_weights( + &weights_local, &bias_local, weight_out_ch_dimension, &slice_channels)); + + /* is_local indicates that the tensor is already in local memory, + so in that case the original tensor can be used, + and there is no need to copy it to the local tensor*/ + const bool in_is_local = in_local.data == mli_in.data; + const bool out_is_local = out_local.data == mli_out.data; + const bool w_is_local = weights_local.data == mli_weights.data; + const bool b_is_local = bias_local.data == mli_bias.data; + + TensorSlicer w_slice(&mli_weights, weight_out_ch_dimension, slice_channels); + TensorSlicer b_slice(&mli_bias, weight_out_ch_dimension, slice_channels); + TensorSlicer out_ch_slice(&mli_out, out_tensor_ch_dimension, slice_channels, + 0, 0, 0, true); + + mli_tensor* w_ptr = w_is_local ? w_slice.Sub() : &weights_local; + mli_tensor* b_ptr = b_is_local ? b_slice.Sub() : &bias_local; + + void* input_buffer_ptr = NULL; + int input_buffer_size = 0; + + while (!w_slice.Done()) { + mli_mov_tensor_sync(w_slice.Sub(), ©_config, w_ptr); + mli_mov_tensor_sync(b_slice.Sub(), ©_config, b_ptr); + + /* mli_in tensor contains batches of HWC tensors. so it is a 4 dimensional + tensor. because the mli kernel will process one HWC tensor at a time, the + 4 dimensional tensor needs to be sliced into nBatch 3 dimensional tensors. + on top of that there could be a need to also slice in the Height + dimension. for that the sliceHeight has been calculated. The tensor slicer + is configured that it will completely slice the nBatch dimension (0) and + slice the height dimension (1) in chunks of 'sliceHeight' */ + TensorSlicer in_slice(&mli_in, height_dimension, in_slice_height, + cfg.padding_top, cfg.padding_bottom, overlap); + + /* output tensor is alreade sliced in the output channel dimension. + out_ch_slice.Sub() is the tensor for the amount of output channels of this + itteration of the weight slice loop. This tensor needs to be further + sliced over the batch and height dimension. */ + TensorSlicer out_slice(out_ch_slice.Sub(), height_dimension, + out_slice_height); + + /* setup the pointers to the local or remote tensor to make the code + * inside the loop easier. */ + mli_tensor* in_ptr = in_is_local ? in_slice.Sub() : &in_local; + mli_tensor* out_ptr = out_is_local ? out_slice.Sub() : &out_local; + + while (!out_slice.Done()) { + TF_LITE_ENSURE(context, !in_slice.Done()); + cfg.padding_top = in_slice.GetPaddingPre(); + cfg.padding_bottom = in_slice.GetPaddingPost(); + + // if same input copy as previous iteration, skip the copy of input + if ((in_slice.Sub()->data != input_buffer_ptr) || + (mli_hlp_count_elem_num(in_slice.Sub(), 0) != input_buffer_size)) { + mli_mov_tensor_sync(in_slice.Sub(), ©_config, in_ptr); + input_buffer_ptr = in_slice.Sub()->data; + input_buffer_size = mli_hlp_count_elem_num(in_slice.Sub(), 0); + } + mli_krn_conv2d_nhwc_sa8_sa8_sa32(in_ptr, w_ptr, b_ptr, &cfg, out_ptr); + mli_mov_tensor_sync(out_ptr, ©_config, out_slice.Sub()); + + in_slice.Next(); + out_slice.Next(); + } + w_slice.Next(); + b_slice.Next(); + out_ch_slice.Next(); + TF_LITE_ENSURE(context, in_slice.Done()); + } + } + return kTfLiteOk; +} + +TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, + TfLiteTensor* output) { +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + ConvParams op_params; + op_params.input_offset = -input->params.zero_point; + op_params.output_offset = output->params.zero_point; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.padding_values.height = data->padding.height; + op_params.padding_values.width = data->padding.width; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + + reference_integer_ops::ConvPerChannel( + op_params, data->per_channel_output_multiplier, + data->per_channel_output_shift, GetTensorShape(input), + GetTensorData<int8>(input), GetTensorShape(filter), + GetTensorData<int8>(filter), GetTensorShape(bias), + GetTensorData<int32>(bias), GetTensorShape(output), + GetTensorData<int8>(output)); + return kTfLiteOk; +#else + TF_LITE_KERNEL_LOG(context, + "Node configuration is not supported by ARC MLI Library."); + return kTfLiteError; +#endif +} + +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* im2col, + TfLiteTensor* hwcn_weights, TfLiteTensor* output) { +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + float output_activation_min, output_activation_max; + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); + + ConvParams op_params; + op_params.padding_type = RuntimePaddingType(params->padding); + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + + reference_ops::Conv(op_params, GetTensorShape(input), + GetTensorData<float>(input), GetTensorShape(filter), + GetTensorData<float>(filter), GetTensorShape(bias), + GetTensorData<float>(bias), GetTensorShape(output), + GetTensorData<float>(output), GetTensorShape(im2col), + GetTensorData<float>(im2col)); + return kTfLiteOk; +#else + TF_LITE_KERNEL_LOG(context, + "Type %s (%d) is not supported by ARC MLI Library.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; +#endif +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + + int input_width = input->dims->data[2]; + int input_height = input->dims->data[1]; + int filter_width = filter->dims->data[2]; + int filter_height = filter->dims->data[1]; + int output_width = output->dims->data[2]; + int output_height = output->dims->data[1]; + + OpData data; + + // All per-channel quantized tensors need valid zero point and scale arrays. + if (input->type == kTfLiteInt8) { + TF_LITE_ENSURE_EQ(context, filter->quantization.type, + kTfLiteAffineQuantization); + + const auto* affine_quantization = + reinterpret_cast<TfLiteAffineQuantization*>( + filter->quantization.params); + TF_LITE_ENSURE(context, affine_quantization); + TF_LITE_ENSURE(context, affine_quantization->scale); + TF_LITE_ENSURE(context, affine_quantization->zero_point); + + TF_LITE_ENSURE(context, + affine_quantization->scale->size == 1 || + affine_quantization->scale->size == + filter->dims->data[kConvQuantizedDimension]); + TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, + affine_quantization->zero_point->size); + } + bool mli_is_applicable = + IsMliApplicable(context, input, filter, bias, params); + TF_LITE_ENSURE_STATUS( + CalculateOpData(context, node, params, input_width, input_height, + filter_width, filter_height, output_width, output_height, + input->type, mli_is_applicable, &data)); + + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + return EvalFloat(context, node, params, &data, input, filter, bias, + nullptr, nullptr, output); + break; + case kTfLiteInt8: + if (mli_is_applicable) { + return EvalMliQuantizedPerChannel(context, node, params, &data, input, + filter, bias, output); + + } else { + return EvalQuantizedPerChannel(context, node, params, &data, input, + filter, bias, output); + } + break; + case kTfLiteUInt8: + return EvalQuantized(context, node, params, &data, input, filter, bias, + nullptr, nullptr, output); + break; + default: + TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace conv + +TfLiteRegistration* Register_CONV_2D() { + static TfLiteRegistration r = {/*init=*/nullptr, + /*free=*/nullptr, + /*prepare=*/nullptr, + /*invoke=*/conv::Eval, + /*profiling_string=*/nullptr, + /*builtin_code=*/0, + /*custom_name=*/nullptr, + /*version=*/0}; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/arc_mli/conv_slicing_test.cc b/tensorflow/lite/micro/kernels/arc_mli/conv_slicing_test.cc new file mode 100644 index 00000000000..1accc919dd2 --- /dev/null +++ b/tensorflow/lite/micro/kernels/arc_mli/conv_slicing_test.cc @@ -0,0 +1,506 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This test checks that slicing logic doesn`t affect result of convolution +// kernel +// +// This test doesn`t replace default convolution test +// (tensorflow/lite/micro/kernels/conv_test.cc). It is added to the whole +// testset only in case MLI for ARC platform is used during generation (which is +// handled in arc_mli.inc). So such tests won`t be generated for other +// platforms. + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/kernels/all_ops_resolver.h" +#include "tensorflow/lite/micro/micro_utils.h" +#include "tensorflow/lite/micro/testing/micro_test.h" +#include "tensorflow/lite/micro/testing/test_utils.h" + +namespace tflite { +namespace testing { +namespace { + +// Common inputs and outputs 1. +static const int kInput1Elements = 20; +static const int kInput1Shape[] = {4, 1, 5, 2, 2}; +static const float kInput1Data[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; +static const int kFilter1Elements = 36; +static const int kFilter1Shape[] = {4, 2, 3, 3, 2}; +static const float kFilter1Data[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; +static const int kBias1Elements = 2; +static const int kBias1Shape[] = {1, 2}; +static const float kBias1Data[] = {2, 2}; +static const int kOutput1Elements = 20; +static const int kOutput1Shape[] = {4, 1, 5, 2, 2}; +static const float kGolden1Data[] = {34, 34, 34, 34, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 34, 34, 34, 34}; + +// Common inputs and outputs 2. +static const int kInput2Elements = 80; +static const int kInput2Shape[] = {4, 1, 20, 2, 2}; +static const float kInput2Data[] = { + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; +static const int kFilter2Elements = 36; +static const int kFilter2Shape[] = {4, 2, 3, 3, 2}; +static const float kFilter2Data[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; +static const int kBias2Elements = 2; +static const int kBias2Shape[] = {1, 2}; +static const float kBias2Data[] = {2, 2}; +static const int kOutput2Elements = 80; +static const int kOutput2Shape[] = {4, 1, 20, 2, 2}; +static const float kGolden2Data[] = { + 34, 34, 34, 34, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 34, 34, 34, 34}; + +// Common inputs and outputs 3. +static const int kInput3Elements = 40; +static const int kInput3Shape[] = {4, 1, 2, 2, 10}; +static const float kInput3Data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; +static const int kFilter3Elements = 90; +static const int kFilter3Shape[] = {4, 1, 3, 3, 10}; // 1 3 3 10 +static const float kFilter3Data[] = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; +static const int kBias3Elements = 1; +static const int kBias3Shape[] = {1, 1}; +static const float kBias3Data[] = {1}; +static const int kOutput3Elements = 4; +static const int kOutput3Shape[] = {4, 1, 2, 2, 1}; // 2 2 1 +static const float kGolden3Data[] = {41, 41, 41, 41}; + +// Common inputs and outputs 4. +static const int kInput4Elements = 80; +static const int kInput4Shape[] = {4, 1, 4, 2, 10}; +static const float kInput4Data[] = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; +static const int kFilter4Elements = 90; +static const int kFilter4Shape[] = {4, 1, 3, 3, 10}; +static const float kFilter4Data[] = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; +static const int kBias4Elements = 1; +static const int kBias4Shape[] = {1, 1}; +static const float kBias4Data[] = {1}; +static const int kOutput4Elements = 8; +static const int kOutput4Shape[] = {4, 1, 4, 2, 1}; +static const float kGolden4Data[] = {41, 41, 61, 61, 61, 61, 41, 41}; + +static TfLiteConvParams common_conv_params = { + kTfLitePaddingSame, // padding + 1, // stride_width + 1, // stride_height + kTfLiteActNone, // activation + 1, // dilation_width_factor + 1, // dilation_height_factor +}; + +template <typename T> +TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size, + const T* expected_output_data, T* output_data, + int output_length, + TfLiteConvParams* conv_params, + float tolerance = 1e-5) { + TfLiteContext context; + PopulateContext(tensors, tensors_size, micro_test::reporter, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_CONV_2D, 1); + + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + const char* init_data = reinterpret_cast<const char*>(conv_params); + size_t init_data_size = 0; + void* user_data = nullptr; + + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + + int inputs_array_data[] = {3, 0, 1, 2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 3}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast<void*>(conv_params); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TfLiteStatus return_val = registration->invoke(&context, &node); + if (return_val != kTfLiteOk) { + return return_val; + } + + if (registration->free) { + registration->free(&context, user_data); + } + + for (int i = 0; i < output_length; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], + tolerance); + } + return kTfLiteOk; +} + +void TestConvQuantizedPerChannel( + const int* input_dims_data, const float* input_data, + int8_t* input_quantized, float input_scale, int input_zero_point, + const int* filter_dims_data, const float* filter_data, + int8_t* filter_data_quantized, const int* bias_dims_data, + const float* bias_data, int32_t* bias_data_quantized, float* bias_scales, + int* bias_zero_points, const int* output_dims_data, + const float* expected_output_data, int8_t* expected_output_data_quantized, + int8_t* output_data, float output_scale, int output_zero_point, + TfLiteConvParams* conv_params) { + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + int filter_zero_points[5]; + float filter_scales[5]; + TfLiteAffineQuantization filter_quant; + TfLiteAffineQuantization bias_quant; + TfLiteTensor input_tensor = + CreateQuantizedTensor(input_data, input_quantized, input_dims, + input_scale, input_zero_point, "input_tensor"); + TfLiteTensor filter_tensor = CreateSymmetricPerChannelQuantizedTensor( + filter_data, filter_data_quantized, filter_dims, filter_scales, + filter_zero_points, &filter_quant, 0 /* quantized dimension */, + "filter_tensor"); + + // DN: to replace scales and quantized data to avoid second quantization + int channel_count = filter_dims->data[0]; + float true_filter_scales[5] = {1.0, 1.0, 1.0, 1.0, 1.0}; + true_filter_scales[0] = static_cast<float>(channel_count); + TfLiteAffineQuantization* to_change = + (TfLiteAffineQuantization*)filter_tensor.quantization.params; + to_change->scale = FloatArrayFromFloats(true_filter_scales); + + int filter_size = filter_tensor.bytes; + for (int i = 0; i < filter_size; ++i) { + filter_tensor.data.int8[i] = filter_data[i]; + } + + TfLiteTensor bias_tensor = CreatePerChannelQuantizedBiasTensor( + bias_data, bias_data_quantized, bias_dims, input_scale, &filter_scales[1], + bias_scales, bias_zero_points, &bias_quant, 0 /* quantized dimension */, + "bias_tensor"); + TfLiteTensor output_tensor = + CreateQuantizedTensor(output_data, output_dims, output_scale, + output_zero_point, "output_tensor"); + + float input_scales[] = {1, input_scale}; + int input_zero_points[] = {1, input_zero_point}; + TfLiteAffineQuantization input_quant = {FloatArrayFromFloats(input_scales), + IntArrayFromInts(input_zero_points)}; + input_tensor.quantization = {kTfLiteAffineQuantization, &input_quant}; + + float output_scales[] = {1, output_scale}; + int output_zero_points[] = {1, output_zero_point}; + TfLiteAffineQuantization output_quant = { + FloatArrayFromFloats(output_scales), + IntArrayFromInts(output_zero_points)}; + output_tensor.quantization = {kTfLiteAffineQuantization, &output_quant}; + + constexpr int inputs_size = 3; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + input_tensor, + filter_tensor, + bias_tensor, + output_tensor, + }; + + tflite::AsymmetricQuantize(expected_output_data, + expected_output_data_quantized, output_dims_count, + output_scale, output_zero_point); + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, + ValidateConvGoldens(tensors, tensors_size, expected_output_data_quantized, + output_data, output_dims_count, conv_params, + 1.0 /* tolerance */)); +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +// Test group 1 +TF_LITE_MICRO_TEST(SystemTestQuantizedPerChannel1) { + const int output_dims_count = 20; + const float input_scale = 1.0f; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + + int8_t input_quantized[tflite::testing::kInput1Elements]; + int8_t filter_quantized[tflite::testing::kFilter1Elements]; + int32_t bias_quantized[tflite::testing::kBias1Elements]; + int8_t golden_quantized[tflite::testing::kOutput1Elements]; + int8_t output_data[output_dims_count]; + + int zero_points[tflite::testing::kBias1Elements + 1]; + float scales[tflite::testing::kBias1Elements + 1]; + + tflite::testing::TestConvQuantizedPerChannel( + tflite::testing::kInput1Shape, tflite::testing::kInput1Data, + input_quantized, input_scale, input_zero_point, + tflite::testing::kFilter1Shape, tflite::testing::kFilter1Data, + filter_quantized, tflite::testing::kBias1Shape, + tflite::testing::kBias1Data, bias_quantized, scales, zero_points, + tflite::testing::kOutput1Shape, tflite::testing::kGolden1Data, + golden_quantized, output_data, output_scale, output_zero_point, + &tflite::testing::common_conv_params); +} + +TF_LITE_MICRO_TEST(LocalTestQuantizedPerChannel1) { + const int output_dims_count = 20; + const float input_scale = 1.0f; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + +#pragma Bss(".Xdata") + static int8_t input_quantized[tflite::testing::kInput1Elements]; + static int8_t filter_quantized[tflite::testing::kFilter1Elements]; + static int32_t bias_quantized[tflite::testing::kBias1Elements]; + static int8_t output_data[output_dims_count]; +#pragma Bss() + + int8_t golden_quantized[tflite::testing::kOutput1Elements]; + int zero_points[tflite::testing::kBias1Elements + 1]; + float scales[tflite::testing::kBias1Elements + 1]; + + tflite::testing::TestConvQuantizedPerChannel( + tflite::testing::kInput1Shape, tflite::testing::kInput1Data, + input_quantized, input_scale, input_zero_point, + tflite::testing::kFilter1Shape, tflite::testing::kFilter1Data, + filter_quantized, tflite::testing::kBias1Shape, + tflite::testing::kBias1Data, bias_quantized, scales, zero_points, + tflite::testing::kOutput1Shape, tflite::testing::kGolden1Data, + golden_quantized, output_data, output_scale, output_zero_point, + &tflite::testing::common_conv_params); +} + +// Test group 2 +TF_LITE_MICRO_TEST(SystemTestQuantizedPerChannel2) { + const int output_dims_count = 80; + const float input_scale = 1.0f; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + + int8_t input_quantized[tflite::testing::kInput2Elements]; + int8_t filter_quantized[tflite::testing::kFilter2Elements]; + int32_t bias_quantized[tflite::testing::kBias2Elements]; + int8_t golden_quantized[tflite::testing::kOutput2Elements]; + int8_t output_data[output_dims_count]; + + int zero_points[tflite::testing::kBias2Elements + 1]; + float scales[tflite::testing::kBias2Elements + 1]; + + tflite::testing::TestConvQuantizedPerChannel( + tflite::testing::kInput2Shape, tflite::testing::kInput2Data, + input_quantized, input_scale, input_zero_point, + tflite::testing::kFilter2Shape, tflite::testing::kFilter2Data, + filter_quantized, tflite::testing::kBias2Shape, + tflite::testing::kBias2Data, bias_quantized, scales, zero_points, + tflite::testing::kOutput2Shape, tflite::testing::kGolden2Data, + golden_quantized, output_data, output_scale, output_zero_point, + &tflite::testing::common_conv_params); +} + +TF_LITE_MICRO_TEST(LocalTestQuantizedPerChannel2) { + const int output_dims_count = 80; + const float input_scale = 1.0f; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + +#pragma Bss(".Xdata") + static int8_t input_quantized[tflite::testing::kInput2Elements]; + static int8_t filter_quantized[tflite::testing::kFilter2Elements]; + static int32_t bias_quantized[tflite::testing::kBias2Elements]; + static int8_t output_data[output_dims_count]; +#pragma Bss() + + int8_t golden_quantized[tflite::testing::kOutput2Elements]; + int zero_points[tflite::testing::kBias2Elements + 1]; + float scales[tflite::testing::kBias2Elements + 1]; + + tflite::testing::TestConvQuantizedPerChannel( + tflite::testing::kInput2Shape, tflite::testing::kInput2Data, + input_quantized, input_scale, input_zero_point, + tflite::testing::kFilter2Shape, tflite::testing::kFilter2Data, + filter_quantized, tflite::testing::kBias2Shape, + tflite::testing::kBias2Data, bias_quantized, scales, zero_points, + tflite::testing::kOutput2Shape, tflite::testing::kGolden2Data, + golden_quantized, output_data, output_scale, output_zero_point, + &tflite::testing::common_conv_params); +} + +// Test group 3 +TF_LITE_MICRO_TEST(SystemTestQuantizedPerChannel3) { + const int output_dims_count = 4; + const float input_scale = 1.0f; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + + int8_t input_quantized[tflite::testing::kInput3Elements]; + int8_t filter_quantized[tflite::testing::kFilter3Elements]; + int32_t bias_quantized[tflite::testing::kBias3Elements]; + int8_t golden_quantized[tflite::testing::kOutput3Elements]; + int8_t output_data[output_dims_count]; + + int zero_points[tflite::testing::kBias3Elements + 1]; + float scales[tflite::testing::kBias3Elements + 1]; + + tflite::testing::TestConvQuantizedPerChannel( + tflite::testing::kInput3Shape, tflite::testing::kInput3Data, + input_quantized, input_scale, input_zero_point, + tflite::testing::kFilter3Shape, tflite::testing::kFilter3Data, + filter_quantized, tflite::testing::kBias3Shape, + tflite::testing::kBias3Data, bias_quantized, scales, zero_points, + tflite::testing::kOutput3Shape, tflite::testing::kGolden3Data, + golden_quantized, output_data, output_scale, output_zero_point, + &tflite::testing::common_conv_params); +} + +TF_LITE_MICRO_TEST(LocalTestQuantizedPerChannel3) { + const int output_dims_count = 4; + const float input_scale = 1.0f; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + +#pragma Bss(".Xdata") + static int8_t input_quantized[tflite::testing::kInput3Elements]; + static int8_t filter_quantized[tflite::testing::kFilter3Elements]; + static int32_t bias_quantized[tflite::testing::kBias3Elements]; + static int8_t output_data[output_dims_count]; +#pragma Bss() + + int8_t golden_quantized[tflite::testing::kOutput3Elements]; + int zero_points[tflite::testing::kBias3Elements + 1]; + float scales[tflite::testing::kBias3Elements + 1]; + + tflite::testing::TestConvQuantizedPerChannel( + tflite::testing::kInput3Shape, tflite::testing::kInput3Data, + input_quantized, input_scale, input_zero_point, + tflite::testing::kFilter3Shape, tflite::testing::kFilter3Data, + filter_quantized, tflite::testing::kBias3Shape, + tflite::testing::kBias3Data, bias_quantized, scales, zero_points, + tflite::testing::kOutput3Shape, tflite::testing::kGolden3Data, + golden_quantized, output_data, output_scale, output_zero_point, + &tflite::testing::common_conv_params); +} + +// Test group 4 +TF_LITE_MICRO_TEST(SystemTestQuantizedPerChannel4) { + const int output_dims_count = 8; + const float input_scale = 1.0f; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + + int8_t input_quantized[tflite::testing::kInput4Elements]; + int8_t filter_quantized[tflite::testing::kFilter4Elements]; + int32_t bias_quantized[tflite::testing::kBias4Elements]; + int8_t golden_quantized[tflite::testing::kOutput4Elements]; + int8_t output_data[output_dims_count]; + + int zero_points[tflite::testing::kBias4Elements + 1]; + float scales[tflite::testing::kBias4Elements + 1]; + + tflite::testing::TestConvQuantizedPerChannel( + tflite::testing::kInput4Shape, tflite::testing::kInput4Data, + input_quantized, input_scale, input_zero_point, + tflite::testing::kFilter4Shape, tflite::testing::kFilter4Data, + filter_quantized, tflite::testing::kBias4Shape, + tflite::testing::kBias4Data, bias_quantized, scales, zero_points, + tflite::testing::kOutput4Shape, tflite::testing::kGolden4Data, + golden_quantized, output_data, output_scale, output_zero_point, + &tflite::testing::common_conv_params); +} + +TF_LITE_MICRO_TEST(LocalTestQuantizedPerChannel4) { + const int output_dims_count = 8; + const float input_scale = 1.0f; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + +#pragma Bss(".Xdata") + static int8_t input_quantized[tflite::testing::kInput4Elements]; + static int8_t filter_quantized[tflite::testing::kFilter4Elements]; + static int32_t bias_quantized[tflite::testing::kBias4Elements]; + static int8_t output_data[output_dims_count]; +#pragma Bss() + + int8_t golden_quantized[tflite::testing::kOutput4Elements]; + int zero_points[tflite::testing::kBias4Elements + 1]; + float scales[tflite::testing::kBias4Elements + 1]; + + tflite::testing::TestConvQuantizedPerChannel( + tflite::testing::kInput4Shape, tflite::testing::kInput4Data, + input_quantized, input_scale, input_zero_point, + tflite::testing::kFilter4Shape, tflite::testing::kFilter4Data, + filter_quantized, tflite::testing::kBias4Shape, + tflite::testing::kBias4Data, bias_quantized, scales, zero_points, + tflite::testing::kOutput4Shape, tflite::testing::kGolden4Data, + golden_quantized, output_data, output_scale, output_zero_point, + &tflite::testing::common_conv_params); +} +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc b/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc new file mode 100644 index 00000000000..b1a26a6a10e --- /dev/null +++ b/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv.cc @@ -0,0 +1,515 @@ +/* Copyright 2017-2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h" + +#include "mli_api.h" // NOLINT +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h" +#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/padding.h" +#include "tensorflow/lite/micro/kernels/arc_mli/mli_slicers.h" +#include "tensorflow/lite/micro/kernels/arc_mli/mli_tf_utils.h" +#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.h" +#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace depthwise_conv { +namespace { + +constexpr int kInputTensor = 0; +constexpr int kFilterTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; +constexpr int kMaxChannels = 256; + +// Depthwise conv is quantized along dimension 3: +// https://www.tensorflow.org/lite/performance/quantization_spec +constexpr int kDepthwiseConvQuantizedDimension = 3; + +struct OpData { + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + + // Per channel output multiplier and shift. + int32_t per_channel_output_multiplier[kMaxChannels]; + int32_t per_channel_output_shift[kMaxChannels]; + + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; + +bool IsMliApplicable(TfLiteContext* context, const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, + const TfLiteDepthwiseConvParams* params) { + const auto* affine_quantization = + reinterpret_cast<TfLiteAffineQuantization*>(filter->quantization.params); + const int in_ch = SizeOfDimension(input, 3); + const int filters_num = SizeOfDimension(filter, 3); + + // MLI optimized version only supports int8 dataype, dilation factor of 1 and + // per-axis quantization of weights (no broadcasting/per-tensor) + // (in_ch == filters_num) || (in_ch == 1)) is a forbidding of + // channel multiplier logic for multichannel input. + bool ret_val = (filter->type == kTfLiteInt8) && + (input->type == kTfLiteInt8) && (bias->type == kTfLiteInt32) && + (params->dilation_width_factor == 1) && + (params->dilation_height_factor == 1) && + (affine_quantization->scale->size == + filter->dims->data[kDepthwiseConvQuantizedDimension]) && + ((in_ch == filters_num) || (in_ch == 1)) && + affine_quantization->scale->size <= (kMaxChannels * 2); + return ret_val; +} + +TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, int width, + int height, int filter_width, int filter_height, + const TfLiteType data_type, bool mli_is_applicable, + OpData* data) { + bool has_bias = node->inputs->size == 3; + // Check number of inputs/outputs + TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + + int unused_output_height, unused_output_width; + data->padding = ComputePaddingHeightWidth( + params->stride_height, params->stride_width, 1, 1, height, width, + filter_height, filter_width, params->padding, &unused_output_height, + &unused_output_width); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + if (data_type != kTfLiteFloat32 && !mli_is_applicable) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = + GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + int num_channels = filter->dims->data[kDepthwiseConvQuantizedDimension]; + + // Ensure filter and bias channel count does not exceed space reserved for + // quantization metadata. + const auto filter_quantization = + reinterpret_cast<TfLiteAffineQuantization*>( + filter->quantization.params); + const auto bias_quantization = + reinterpret_cast<TfLiteAffineQuantization*>(bias->quantization.params); + TF_LITE_ENSURE(context, filter_quantization->scale->size <= kMaxChannels); + TF_LITE_ENSURE(context, bias_quantization->scale->size <= kMaxChannels); + + TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams( + context, input, filter, bias, output, params->activation, + &data->output_multiplier, &data->output_shift, + &data->output_activation_min, &data->output_activation_max, + data->per_channel_output_multiplier, + reinterpret_cast<int*>(data->per_channel_output_shift), num_channels)); + } +#endif + return kTfLiteOk; +} + +} // namespace + +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + float output_activation_min, output_activation_max; + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); + + tflite::DepthwiseParams op_params; + // Padding type is ignored, but still set. + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.depth_multiplier = params->depth_multiplier; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + + tflite::reference_ops::DepthwiseConv( + op_params, GetTensorShape(input), GetTensorData<float>(input), + GetTensorShape(filter), GetTensorData<float>(filter), + GetTensorShape(bias), GetTensorData<float>(bias), GetTensorShape(output), + GetTensorData<float>(output)); + return kTfLiteOk; +#else + TF_LITE_KERNEL_LOG(context, + "Type %s (%d) is not supported by ARC MLI Library.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; +#endif +} + +TfLiteStatus EvalMliQuantizedPerChannel( + TfLiteContext* context, TfLiteNode* node, TfLiteDepthwiseConvParams* params, + OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { + // Run Depthwise Conv MLI kernel + mli_tensor mli_in = {0}; + mli_tensor mli_weights = {0}; + mli_tensor mli_bias = {0}; + mli_tensor mli_out = {0}; + mli_conv2d_cfg cfg = {}; + + // reuse space allocated for OpData parameters + mli_weights.el_params.asym.scale.pi16 = + (int16_t*)data->per_channel_output_multiplier; + mli_bias.el_params.asym.scale.pi16 = (int16_t*)data->per_channel_output_shift; + + int16_t filter_zero_point = 0; + int16_t bias_zero_point = 0; + mli_weights.el_params.asym.zero_point.pi16 = &filter_zero_point; + mli_bias.el_params.asym.zero_point.pi16 = &bias_zero_point; + + ConvertToMliTensor<int8_t>(input, &mli_in); + ConvertToMliTensorPerChannel<int8_t>(filter, &mli_weights); + ConvertToMliTensorPerChannel<int32_t>(bias, &mli_bias); + ConvertToMliTensor<int8_t>(output, &mli_out); + + if (params->activation == kTfLiteActRelu) { + cfg.relu.type = MLI_RELU_GEN; + } else if (params->activation == kTfLiteActRelu6) { + cfg.relu.type = MLI_RELU_6; + } else if (params->activation == kTfLiteActRelu1) { + cfg.relu.type = MLI_RELU_1; + } else { + cfg.relu.type = MLI_RELU_NONE; + } + + cfg.stride_width = params->stride_width; + cfg.stride_height = params->stride_height; + if (params->padding == kTfLitePaddingValid) { + cfg.padding_left = 0; + cfg.padding_right = 0; + cfg.padding_top = 0; + cfg.padding_bottom = 0; + } else { + cfg.padding_left = data->padding.width; + cfg.padding_right = data->padding.width + data->padding.width_offset; + cfg.padding_top = data->padding.height; + cfg.padding_bottom = data->padding.height + data->padding.height_offset; + } + + // for height slicing + const int heightDimension = 1; + int inSliceHeight = 0; + int outSliceHeight = 0; + const int kernelHeight = + static_cast<int>(mli_weights.shape[KRNL_DW_H_DIM_HWC]); + const int overlap = kernelHeight - cfg.stride_height; + + // for weight slicing (on output channels) + // HWCN layout for weigths, output channel dimension is the first dimension. + const int weight_out_ch_dimension = 3; + // bias has only 1 dimension + const int bias_out_ch_dimension = 0; + // Batch-Height-Width-Channel layout means last dimension is output channels. + const int out_tensor_ch_dimension = 3; + const int32_t in_channels = mli_in.shape[out_tensor_ch_dimension]; + const int32_t out_channels = mli_out.shape[out_tensor_ch_dimension]; + int slice_channels = + static_cast<int>(mli_weights.shape[weight_out_ch_dimension]); + + // Tensors for data in fast (local) memory + // and config to copy data from external to local memory + mli_tensor weights_local = mli_weights; + mli_tensor bias_local = mli_bias; + mli_tensor in_local = mli_in; + mli_tensor out_local = mli_out; // this assumes that output shape + // is already filled in the tensor struct. + mli_mov_cfg_t copy_config; + mli_mov_cfg_for_copy(©_config); + + TF_LITE_ENSURE_STATUS(get_arc_scratch_buffer_for_conv_tensors( + context, &in_local, &weights_local, &bias_local, &out_local)); + /* is_local indicates that the tensor is already in local memory, + so in that case the original tensor can be used, + and there is no need to copy it to the local tensor*/ + const bool in_is_local = in_local.data == mli_in.data; + const bool out_is_local = out_local.data == mli_out.data; + const bool w_is_local = weights_local.data == mli_weights.data; + const bool b_is_local = bias_local.data == mli_bias.data; + + TF_LITE_ENSURE_STATUS(arc_scratch_buffer_calc_slice_size_io( + &in_local, &out_local, kernelHeight, cfg.stride_height, cfg.padding_top, + cfg.padding_bottom, &inSliceHeight, &outSliceHeight)); + TF_LITE_ENSURE_STATUS(arc_scratch_buffer_calc_slice_size_weights( + &weights_local, &bias_local, weight_out_ch_dimension, &slice_channels)); + + /* if input channels is not equal to output channels, a channel multiplier + is used. in this case the slice channels needs to be rounded down to a + multiple of the input channels */ + if (in_channels != out_channels) { + slice_channels = (slice_channels / in_channels) * in_channels; + } + + TensorSlicer b_slice(&mli_bias, bias_out_ch_dimension, slice_channels); + TensorSlicer w_slice(&mli_weights, weight_out_ch_dimension, slice_channels, 0, + 0, 0, true); + TensorSlicer out_ch_slice(&mli_out, out_tensor_ch_dimension, slice_channels, + 0, 0, 0, true); + TensorSlicer in_ch_slice(&mli_in, out_tensor_ch_dimension, slice_channels, 0, + 0, 0, true); + + mli_tensor* w_ptr = w_is_local ? w_slice.Sub() : &weights_local; + mli_tensor* b_ptr = b_is_local ? b_slice.Sub() : &bias_local; + + void* input_buffer_ptr = NULL; + int input_buffer_size = 0; + int padding_top = cfg.padding_top; + int padding_bottom = cfg.padding_bottom; + + while (!w_slice.Done()) { + mli_mov_tensor_sync(w_slice.Sub(), ©_config, w_ptr); + mli_mov_tensor_sync(b_slice.Sub(), ©_config, b_ptr); + + /* input tensor is alreade sliced in the channel dimension. + out_ch_slice.Sub() is the tensor for the amount of channels of this + itteration of the weight slice loop. This tensor needs to be further + sliced over the batch and height dimension. in_ch_slice.Sub() tensor + contains batches of HWC tensors. so it is a 4 dimensional tensor. because + the mli kernel will process one HWC tensor at a time, the 4 dimensional + tensor needs to be sliced into nBatch 3 dimensional tensors. on top of + that there could be a need to also slice in the Height dimension. for that + the sliceHeight has been calculated. The tensor slicer is configured that + it will completely slice the nBatch dimension (0) and slice the height + dimension (1) in chunks of 'sliceHeight' */ + TensorSlicer in_slice(in_ch_slice.Sub(), heightDimension, inSliceHeight, + padding_top, padding_bottom, overlap); + + /* output tensor is alreade sliced in the output channel dimension. + out_ch_slice.Sub() is the tensor for the amount of output channels of this + itteration of the weight slice loop. This tensor needs to be further + sliced over the batch and height dimension. */ + TensorSlicer out_slice(out_ch_slice.Sub(), heightDimension, outSliceHeight); + + /* setup the pointers to the local or remote tensor to make the code + * inside the loop easier. */ + mli_tensor* in_ptr = in_is_local ? in_slice.Sub() : &in_local; + mli_tensor* out_ptr = out_is_local ? out_slice.Sub() : &out_local; + + while (!out_slice.Done()) { + TF_LITE_ENSURE(context, !in_slice.Done()); + cfg.padding_top = in_slice.GetPaddingPre(); + cfg.padding_bottom = in_slice.GetPaddingPost(); + + // if same input copy as previous iteration, skip the copy of input + if ((in_slice.Sub()->data != input_buffer_ptr) || + (mli_hlp_count_elem_num(in_slice.Sub(), 0) != input_buffer_size)) { + mli_mov_tensor_sync(in_slice.Sub(), ©_config, in_ptr); + input_buffer_ptr = in_slice.Sub()->data; + input_buffer_size = mli_hlp_count_elem_num(in_slice.Sub(), 0); + } + mli_krn_depthwise_conv2d_hwcn_sa8_sa8_sa32(in_ptr, w_ptr, b_ptr, &cfg, + out_ptr); + mli_mov_tensor_sync(out_ptr, ©_config, out_slice.Sub()); + + in_slice.Next(); + out_slice.Next(); + } + w_slice.Next(); + b_slice.Next(); + out_ch_slice.Next(); + in_ch_slice.Next(); + TF_LITE_ENSURE(context, in_slice.Done()); + } + return kTfLiteOk; +} + +TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, + OpData* data, const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, + TfLiteTensor* output) { +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + DepthwiseParams op_params; + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.depth_multiplier = params->depth_multiplier; + op_params.input_offset = -input->params.zero_point; + op_params.weights_offset = 0; + op_params.output_offset = output->params.zero_point; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + + reference_integer_ops::DepthwiseConvPerChannel( + op_params, data->per_channel_output_multiplier, + data->per_channel_output_shift, GetTensorShape(input), + GetTensorData<int8>(input), GetTensorShape(filter), + GetTensorData<int8>(filter), GetTensorShape(bias), + GetTensorData<int32>(bias), GetTensorShape(output), + GetTensorData<int8>(output)); + return kTfLiteOk; +#else + TF_LITE_KERNEL_LOG(context, + "Node configuration is not supported by ARC MLI Library."); + return kTfLiteError; +#endif +} + +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, + TfLiteTensor* output) { +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + const int32_t input_offset = -input->params.zero_point; + const int32_t filter_offset = -filter->params.zero_point; + const int32_t output_offset = output->params.zero_point; + + tflite::DepthwiseParams op_params; + // Padding type is ignored, but still set. + op_params.padding_type = PaddingType::kSame; + op_params.padding_values.width = data->padding.width; + op_params.padding_values.height = data->padding.height; + op_params.stride_width = params->stride_width; + op_params.stride_height = params->stride_height; + op_params.dilation_width_factor = params->dilation_width_factor; + op_params.dilation_height_factor = params->dilation_height_factor; + op_params.depth_multiplier = params->depth_multiplier; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. + op_params.output_shift = -data->output_shift; + + tflite::reference_ops::DepthwiseConv( + op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), + GetTensorShape(filter), GetTensorData<uint8_t>(filter), + GetTensorShape(bias), GetTensorData<int32_t>(bias), + GetTensorShape(output), GetTensorData<uint8_t>(output)); + return kTfLiteOk; +#else + TF_LITE_KERNEL_LOG(context, + "Type %s (%d) is not supported by ARC MLI Library.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; +#endif +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + const TfLiteTensor* bias = + (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr; + + const TfLiteType data_type = input->type; + int width = SizeOfDimension(input, 2); + int height = SizeOfDimension(input, 1); + int filter_width = SizeOfDimension(filter, 2); + int filter_height = SizeOfDimension(filter, 1); + + OpData data; + + // All per-channel quantized tensors need valid zero point and scale arrays. + if (input->type == kTfLiteInt8) { + TF_LITE_ENSURE_EQ(context, filter->quantization.type, + kTfLiteAffineQuantization); + + const auto* affine_quantization = + reinterpret_cast<TfLiteAffineQuantization*>( + filter->quantization.params); + TF_LITE_ENSURE(context, affine_quantization); + TF_LITE_ENSURE(context, affine_quantization->scale); + TF_LITE_ENSURE(context, affine_quantization->zero_point); + TF_LITE_ENSURE( + context, affine_quantization->scale->size == 1 || + affine_quantization->scale->size == + filter->dims->data[kDepthwiseConvQuantizedDimension]); + TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, + affine_quantization->zero_point->size); + } + + bool mli_is_applicable = + IsMliApplicable(context, input, filter, bias, params); + TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height, + filter_width, filter_height, data_type, + mli_is_applicable, &data)); + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + return EvalFloat(context, node, params, &data, input, filter, bias, + output); + break; + case kTfLiteInt8: + if (mli_is_applicable) { + return EvalMliQuantizedPerChannel(context, node, params, &data, input, + filter, bias, output); + } else { + return EvalQuantizedPerChannel(context, node, params, &data, input, + filter, bias, output); + } + break; + case kTfLiteUInt8: + return EvalQuantized(context, node, params, &data, input, filter, bias, + output); + break; + default: + TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace depthwise_conv + +TfLiteRegistration* Register_DEPTHWISE_CONV_2D() { + static TfLiteRegistration r = {/*init=*/nullptr, + /*free=*/nullptr, + /*prepare=*/nullptr, + /*invoke=*/depthwise_conv::Eval, + /*profiling_string=*/nullptr, + /*builtin_code=*/0, + /*custom_name=*/nullptr, + /*version=*/0}; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv_slicing_test.cc b/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv_slicing_test.cc new file mode 100644 index 00000000000..2f528ea4e79 --- /dev/null +++ b/tensorflow/lite/micro/kernels/arc_mli/depthwise_conv_slicing_test.cc @@ -0,0 +1,550 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This test checks that slicing logic doesn`t affect result of depthwise +// convolution kernel +// +// This test doesn`t replace default depthwise convolution test +// (tensorflow/lite/micro/kernels/depthwise_conv_test.cc). It is added to the +// whole testset only in case MLI for ARC platform is used during generation +// (which is handled in arc_mli.inc). So such tests won`t be generated for other +// platforms. + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/micro/kernels/all_ops_resolver.h" +#include "tensorflow/lite/micro/testing/micro_test.h" +#include "tensorflow/lite/micro/testing/test_utils.h" + +namespace tflite { +namespace testing { +namespace { + +constexpr int kMaxFilterChannels = 64; +constexpr int kMaxBiasChannels = 64; + +// Index of the output tensor in context->tensors, specific to +// DepthwiseConv. +constexpr int kOutputTensorIndex = 3; + +// Creates a DepthwiseConv opeerator, calls it with the provided input tensors +// and some defaults parameters, and compares the output with +// expected_output_data. +// +// The tensors parameter contains both the input tensors as well as a +// preallocated output tensor into which the output is stored. +template <typename T> +TfLiteStatus ValidateDepthwiseConvGoldens(const T* expected_output_data, + int output_length, + TfLiteFusedActivation activation, + float tolerance, int tensors_size, + TfLiteTensor* tensors) { + TfLiteContext context; + PopulateContext(tensors, tensors_size, micro_test::reporter, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, 1); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + int input_depth = tensors[0].dims->data[3]; + int output_depth = tensors[1].dims->data[3]; + int depth_mul = output_depth / input_depth; + TfLiteDepthwiseConvParams builtin_data; + builtin_data.padding = kTfLitePaddingValid; + builtin_data.activation = activation; + builtin_data.stride_height = 1; + builtin_data.stride_width = 1; + builtin_data.dilation_height_factor = 1; + builtin_data.dilation_width_factor = 1; + builtin_data.depth_multiplier = depth_mul; + + const char* init_data = reinterpret_cast<const char*>(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + int inputs_array_data[] = {3, 0, 1, 2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 3}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast<void*>(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TfLiteStatus invoke_status = registration->invoke(&context, &node); + if (invoke_status != kTfLiteOk) { + return invoke_status; + } + + if (registration->free) { + registration->free(&context, user_data); + } + + const T* output_data = tflite::GetTensorData<T>(&tensors[kOutputTensorIndex]); + for (int i = 0; i < output_length; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], + tolerance); + } + return kTfLiteOk; +} + +void TestDepthwiseConvQuantizedPerChannel( + const int* input_dims_data, const float* input_data, + int8_t* input_quantized, float input_scale, int input_zero_point, + const int* filter_dims_data, const float* filter_data, + int8_t* filter_data_quantized, const int* bias_dims_data, + const float* bias_data, int32_t* bias_data_quantized, + const int* output_dims_data, const float* expected_output_data, + int8_t* expected_output_data_quantized, int8_t* output_data, + float output_scale, int output_zero_point, + TfLiteFusedActivation activation) { + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + int filter_zero_points[kMaxFilterChannels]; + float filter_scales[kMaxFilterChannels]; + int bias_zero_points[kMaxBiasChannels]; + float bias_scales[kMaxBiasChannels]; + TfLiteAffineQuantization filter_quant; + TfLiteAffineQuantization bias_quant; + TfLiteTensor input_tensor = + CreateQuantizedTensor(input_data, input_quantized, input_dims, + input_scale, input_zero_point, "input_tensor"); + TfLiteTensor filter_tensor = CreateSymmetricPerChannelQuantizedTensor( + filter_data, filter_data_quantized, filter_dims, filter_scales, + filter_zero_points, &filter_quant, 3 /* quantized dimension */, + "filter_tensor"); + TfLiteTensor bias_tensor = CreatePerChannelQuantizedBiasTensor( + bias_data, bias_data_quantized, bias_dims, input_scale, &filter_scales[1], + bias_scales, bias_zero_points, &bias_quant, 3 /* quantized dimension */, + "bias_tensor"); + TfLiteTensor output_tensor = + CreateQuantizedTensor(output_data, output_dims, output_scale, + input_zero_point, "output_tensor"); + + float input_scales[] = {1, input_scale}; + int input_zero_points[] = {1, input_zero_point}; + TfLiteAffineQuantization input_quant = {FloatArrayFromFloats(input_scales), + IntArrayFromInts(input_zero_points)}; + input_tensor.quantization = {kTfLiteAffineQuantization, &input_quant}; + + float output_scales[] = {1, output_scale}; + int output_zero_points[] = {1, output_zero_point}; + TfLiteAffineQuantization output_quant = { + FloatArrayFromFloats(output_scales), + IntArrayFromInts(output_zero_points)}; + output_tensor.quantization = {kTfLiteAffineQuantization, &output_quant}; + + constexpr int inputs_size = 3; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + input_tensor, + filter_tensor, + bias_tensor, + output_tensor, + }; + + AsymmetricQuantize(expected_output_data, expected_output_data_quantized, + output_dims_count, output_scale, output_zero_point); + + TF_LITE_MICRO_EXPECT_EQ( + kTfLiteOk, ValidateDepthwiseConvGoldens(expected_output_data_quantized, + output_dims_count, activation, + 1.0, tensors_size, tensors)); +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +// Test group 1 +TF_LITE_MICRO_TEST(SystemTestQuantizedPerChannel1) { + const int input_elements = 20; + const int input_shape[] = {4, 1, 5, 2, 2}; + const float input_values[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int filter_elements = 36; + const int filter_shape[] = {4, 2, 3, 3, 2}; + const float filter_values[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int bias_elements = 2; + const int bias_shape[] = {4, 1, 1, 1, 2}; + const int output_elements = 20; + const float bias_values[] = {2, 2}; + const float golden[] = {34, 34, 34, 34, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 34, 34, 34, 34}; + const int output_shape[] = {4, 1, 5, 2, 2}; + const int output_dims_count = 20; + int8_t output_data[output_dims_count]; + + const float input_scale = 1.0; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + + int8_t input_quantized[input_elements]; + int8_t filter_quantized[filter_elements]; + int32_t bias_quantized[bias_elements]; + int8_t golden_quantized[output_elements]; + int zero_points[bias_elements + 1]; + float scales[bias_elements + 1]; + + tflite::testing::TestDepthwiseConvQuantizedPerChannel( + input_shape, input_values, input_quantized, input_scale, input_zero_point, + filter_shape, filter_values, filter_quantized, bias_shape, bias_values, + bias_quantized, output_shape, golden, golden_quantized, output_data, + output_scale, output_zero_point, kTfLiteActNone); +} + +TF_LITE_MICRO_TEST(LocalTestQuantizedPerChannel1) { + const int input_elements = 20; + const int input_shape[] = {4, 1, 5, 2, 2}; + const int filter_elements = 36; + const int filter_shape[] = {4, 2, 3, 3, 2}; + const int bias_elements = 2; + const int bias_shape[] = {4, 1, 1, 1, 2}; + const int output_elements = 20; + const int output_shape[] = {4, 1, 5, 2, 2}; + const int output_dims_count = 20; + +#pragma Bss(".Zdata") + const float input_values[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const float filter_values[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const float bias_values[] = {2, 2}; + int8_t output_data[output_dims_count]; +#pragma Bss() + + const float golden[] = {34, 34, 34, 34, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 34, 34, 34, 34}; + + const float input_scale = 1.0; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + + int8_t input_quantized[input_elements]; + int8_t filter_quantized[filter_elements]; + int32_t bias_quantized[bias_elements]; + int8_t golden_quantized[output_elements]; + int zero_points[bias_elements + 1]; + float scales[bias_elements + 1]; + + tflite::testing::TestDepthwiseConvQuantizedPerChannel( + input_shape, input_values, input_quantized, input_scale, input_zero_point, + filter_shape, filter_values, filter_quantized, bias_shape, bias_values, + bias_quantized, output_shape, golden, golden_quantized, output_data, + output_scale, output_zero_point, kTfLiteActNone); +} + +// Test group 2 +TF_LITE_MICRO_TEST(SystemTestQuantizedPerChannel2) { + const int input_elements = 80; + const int input_shape[] = {4, 1, 20, 2, 2}; + const float input_values[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int filter_elements = 36; + const int filter_shape[] = {4, 2, 3, 3, 2}; + const float filter_values[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int bias_elements = 2; + const int bias_shape[] = {4, 1, 1, 1, 2}; + const int output_elements = 80; + const float bias_values[] = {2, 2}; + const float golden[] = { + 34, 34, 34, 34, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 34, 34, 34, 34}; + const int output_shape[] = {4, 1, 20, 2, 2}; + const int output_dims_count = 80; + int8_t output_data[output_dims_count]; + + const float input_scale = 1.0; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + + int8_t input_quantized[input_elements]; + int8_t filter_quantized[filter_elements]; + int32_t bias_quantized[bias_elements]; + int8_t golden_quantized[output_elements]; + int zero_points[bias_elements + 1]; + float scales[bias_elements + 1]; + + tflite::testing::TestDepthwiseConvQuantizedPerChannel( + input_shape, input_values, input_quantized, input_scale, input_zero_point, + filter_shape, filter_values, filter_quantized, bias_shape, bias_values, + bias_quantized, output_shape, golden, golden_quantized, output_data, + output_scale, output_zero_point, kTfLiteActNone); +} + +TF_LITE_MICRO_TEST(LocalTestQuantizedPerChannel2) { + const int input_elements = 80; + const int input_shape[] = {4, 1, 20, 2, 2}; + const int filter_elements = 36; + const int filter_shape[] = {4, 2, 3, 3, 2}; + const int bias_elements = 2; + const int bias_shape[] = {4, 1, 1, 1, 2}; + const int output_elements = 80; + const int output_shape[] = {4, 1, 20, 2, 2}; + const int output_dims_count = 80; + +#pragma Bss(".Zdata") + float input_values[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + float filter_values[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + float bias_values[] = {2, 2}; + int8_t output_data[output_dims_count]; +#pragma Bss() + + const float golden[] = { + 34, 34, 34, 34, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, + 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 34, 34, 34, 34}; + + const float input_scale = 1.0; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + + int8_t input_quantized[input_elements]; + int8_t filter_quantized[filter_elements]; + int32_t bias_quantized[bias_elements]; + int8_t golden_quantized[output_elements]; + int zero_points[bias_elements + 1]; + float scales[bias_elements + 1]; + + tflite::testing::TestDepthwiseConvQuantizedPerChannel( + input_shape, input_values, input_quantized, input_scale, input_zero_point, + filter_shape, filter_values, filter_quantized, bias_shape, bias_values, + bias_quantized, output_shape, golden, golden_quantized, output_data, + output_scale, output_zero_point, kTfLiteActNone); +} + +// Test group 3 +TF_LITE_MICRO_TEST(SystemTestQuantizedPerChannel3) { + const int input_elements = 40; + const int input_shape[] = {4, 1, 2, 2, 10}; + const float input_values[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + const int filter_elements = 90; + const int filter_shape[] = {4, 1, 3, 3, 10}; + const float filter_values[] = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + const int bias_elements = 1; + const int bias_shape[] = {4, 1, 1, 1, 1}; + const int output_elements = 4; + const float bias_values[] = {1}; + const float golden[] = {41, 41, 41, 41}; + const int output_shape[] = {4, 1, 2, 2, 1}; + const int output_dims_count = 4; + int8_t output_data[output_dims_count]; + + const float input_scale = 1.0; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + + int8_t input_quantized[input_elements]; + int8_t filter_quantized[filter_elements]; + int32_t bias_quantized[bias_elements]; + int8_t golden_quantized[output_elements]; + int zero_points[bias_elements + 1]; + float scales[bias_elements + 1]; + + tflite::testing::TestDepthwiseConvQuantizedPerChannel( + input_shape, input_values, input_quantized, input_scale, input_zero_point, + filter_shape, filter_values, filter_quantized, bias_shape, bias_values, + bias_quantized, output_shape, golden, golden_quantized, output_data, + output_scale, output_zero_point, kTfLiteActNone); +} + +TF_LITE_MICRO_TEST(LocalTestQuantizedPerChannel3) { + const int input_elements = 40; + const int input_shape[] = {4, 1, 2, 2, 10}; + const int filter_elements = 90; + const int filter_shape[] = {4, 1, 3, 3, 10}; + const int bias_elements = 1; + const int bias_shape[] = {4, 1, 1, 1, 1}; + const int output_elements = 4; + const int output_shape[] = {4, 1, 2, 2, 1}; + const int output_dims_count = 4; + +#pragma Bss(".Zdata") + float input_values[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + float filter_values[] = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + float bias_values[] = {1}; + int8_t output_data[output_dims_count]; +#pragma Bss() + + const float golden[] = {41, 41, 41, 41}; + + const float input_scale = 1.0; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + + int8_t input_quantized[input_elements]; + int8_t filter_quantized[filter_elements]; + int32_t bias_quantized[bias_elements]; + int8_t golden_quantized[output_elements]; + int zero_points[bias_elements + 1]; + float scales[bias_elements + 1]; + + tflite::testing::TestDepthwiseConvQuantizedPerChannel( + input_shape, input_values, input_quantized, input_scale, input_zero_point, + filter_shape, filter_values, filter_quantized, bias_shape, bias_values, + bias_quantized, output_shape, golden, golden_quantized, output_data, + output_scale, output_zero_point, kTfLiteActNone); +} + +// Test group 4 +TF_LITE_MICRO_TEST(SystemTestQuantizedPerChannel4) { + const int input_elements = 80; + const int input_shape[] = {4, 1, 4, 2, 10}; + const float input_values[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + const int filter_elements = 90; + const int filter_shape[] = {4, 1, 3, 3, 10}; + const float filter_values[] = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + const int bias_elements = 1; + const int bias_shape[] = {4, 1, 1, 1, 1}; + const int output_elements = 8; + const float bias_values[] = {1}; + const float golden[] = {41, 41, 61, 61, 61, 61, 41, 41}; + const int output_shape[] = {4, 1, 4, 2, 1}; + const int output_dims_count = 8; + int8_t output_data[output_dims_count]; + + const float input_scale = 1.0; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + + int8_t input_quantized[input_elements]; + int8_t filter_quantized[filter_elements]; + int32_t bias_quantized[bias_elements]; + int8_t golden_quantized[output_elements]; + int zero_points[bias_elements + 1]; + float scales[bias_elements + 1]; + + tflite::testing::TestDepthwiseConvQuantizedPerChannel( + input_shape, input_values, input_quantized, input_scale, input_zero_point, + filter_shape, filter_values, filter_quantized, bias_shape, bias_values, + bias_quantized, output_shape, golden, golden_quantized, output_data, + output_scale, output_zero_point, kTfLiteActNone); +} + +TF_LITE_MICRO_TEST(LocalTestQuantizedPerChannel4) { + const int input_elements = 80; + const int input_shape[] = {4, 1, 4, 2, 10}; + const int filter_elements = 90; + const int filter_shape[] = {4, 1, 3, 3, 10}; + const int bias_elements = 1; + const int bias_shape[] = {4, 1, 1, 1, 1}; + const int output_elements = 8; + const int output_shape[] = {4, 1, 4, 2, 1}; + const int output_dims_count = 8; + +#pragma Bss(".Zdata") + float input_values[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + float filter_values[] = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + float bias_values[] = {1}; + int8_t output_data[output_dims_count]; +#pragma Bss() + + const float golden[] = {41, 41, 61, 61, 61, 61, 41, 41}; + + const float input_scale = 1.0; + const float output_scale = 1.0f; + const int input_zero_point = 0; + const int output_zero_point = 0; + + int8_t input_quantized[input_elements]; + int8_t filter_quantized[filter_elements]; + int32_t bias_quantized[bias_elements]; + int8_t golden_quantized[output_elements]; + int zero_points[bias_elements + 1]; + float scales[bias_elements + 1]; + + tflite::testing::TestDepthwiseConvQuantizedPerChannel( + input_shape, input_values, input_quantized, input_scale, input_zero_point, + filter_shape, filter_values, filter_quantized, bias_shape, bias_values, + bias_quantized, output_shape, golden, golden_quantized, output_data, + output_scale, output_zero_point, kTfLiteActNone); +} +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/arc_mli/fully_connected.cc b/tensorflow/lite/micro/kernels/arc_mli/fully_connected.cc new file mode 100644 index 00000000000..fe077c99fac --- /dev/null +++ b/tensorflow/lite/micro/kernels/arc_mli/fully_connected.cc @@ -0,0 +1,385 @@ +/* Copyright 2017-2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/kernels/internal/reference/fully_connected.h" + +#include "mli_api.h" // NOLINT +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/micro/kernels/arc_mli/mli_slicers.h" +#include "tensorflow/lite/micro/kernels/arc_mli/mli_tf_utils.h" +#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.h" +#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace fully_connected { +namespace { + +struct OpData { + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; + // The index of the temporary tensor where the quantized inputs are cached. + int input_quantized_index; +}; + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +bool IsMliApplicable(TfLiteContext* context, const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, + const TfLiteFullyConnectedParams* params) { + // MLI optimized version only supports int8 dataype and no fused Relu and + // symmetric per-tensor quantization of weights (not per-axis) + bool ret_val = (filter->type == kTfLiteInt8) && + (input->type == kTfLiteInt8) && (bias->type == kTfLiteInt32) && + (params->activation == kTfLiteActNone) && + (filter->params.zero_point == 0); + return ret_val; +} + +TfLiteStatus CalculateOpData(TfLiteContext* context, + TfLiteFullyConnectedParams* params, + TfLiteType data_type, const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output, + OpData* data) { + TfLiteStatus status = kTfLiteOk; +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + if (data_type != kTfLiteFloat32 && + !IsMliApplicable(context, input, filter, bias, params)) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; + TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( + context, params->activation, output, &data->output_activation_min, + &data->output_activation_max)); + } +#endif + return status; +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + OpData* data = nullptr; + TfLiteStatus status = context->AllocatePersistentBuffer( + context, sizeof(OpData), reinterpret_cast<void**>(&data)); + if (status != kTfLiteOk || data == nullptr) { + return nullptr; + } + return data; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast<OpData*>(node->user_data); + auto* params = + reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE(context, data != nullptr); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + TF_LITE_ENSURE_MSG(context, input->type == filter->type, + "Hybrid models are not supported on TFLite Micro."); + + TfLiteType data_type = input->type; + TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input, + filter, bias, output, data)); + + return kTfLiteOk; +} + +TfLiteStatus EvalMliQuantizedInt8(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, + OpData* data, const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, + TfLiteTensor* output) { + mli_tensor mli_in = {0}; + mli_tensor mli_weights = {0}; + mli_tensor mli_bias = {0}; + mli_tensor mli_out = {0}; + + ConvertToMliTensor<int8_t>(input, &mli_in); + ConvertToMliTensor<int8_t>(filter, &mli_weights); + ConvertToMliTensor<int32_t>(bias, &mli_bias); + ConvertToMliTensor<int8_t>(output, &mli_out); + + /* The input tensor can have more than 2 dimensions. for the compute this + doesn't make any difference because all the inputs or a batch entry will + be used anyway. because the MLI kernel doesn't recognize the multiple + dimensions, the tensor shape is casted to a {batchnum, inputsize} shape. */ + mli_in.shape[0] = mli_out.shape[0]; + mli_in.shape[1] = mli_weights.shape[1]; + mli_in.shape[2] = 0; + mli_in.shape[3] = 0; + mli_in.rank = 2; + + // Tensors for data in fast (local) memory and config to copy data from + // external to local memory + mli_tensor weights_local = mli_weights; + mli_tensor bias_local = mli_bias; + mli_tensor in_local = mli_in; + mli_tensor out_local = mli_out; + mli_mov_cfg_t copy_config; + mli_mov_cfg_for_copy(©_config); + const int weight_out_dimension = 0; + const int out_tensor_dimension = 1; + const int input_size_dimension = 1; + int slice_size = mli_weights.shape[weight_out_dimension]; + + /* allocate the local buffers, and compute the slice size */ + TF_LITE_ENSURE_STATUS(get_arc_scratch_buffer_for_fully_connect_tensors( + context, &in_local, &weights_local, &bias_local, &out_local)); + TF_LITE_ENSURE_STATUS(arc_scratch_buffer_calc_slice_size_weights( + &weights_local, &bias_local, weight_out_dimension, &slice_size)); + int max_out_slice_size = + out_local.capacity / mli_hlp_tensor_element_size(&out_local); + if (slice_size > max_out_slice_size) slice_size = max_out_slice_size; + + /* is_local indicates that the tensor is already in local memory, + so in that case the original tensor can be used, + and there is no need to copy it to the local tensor*/ + const bool in_is_local = in_local.data == mli_in.data; + const bool out_is_local = out_local.data == mli_out.data; + const bool w_is_local = weights_local.data == mli_weights.data; + const bool b_is_local = bias_local.data == mli_bias.data; + + TensorSlicer w_slice(&mli_weights, weight_out_dimension, slice_size); + TensorSlicer b_slice(&mli_bias, weight_out_dimension, slice_size); + TensorSlicer out_ch_slice(&mli_out, out_tensor_dimension, slice_size, 0, 0, 0, + true); + + mli_tensor* w_ptr = w_is_local ? w_slice.Sub() : &weights_local; + mli_tensor* b_ptr = b_is_local ? b_slice.Sub() : &bias_local; + + void* input_buffer_ptr = NULL; + + while (!w_slice.Done()) { + mli_mov_tensor_sync(w_slice.Sub(), ©_config, w_ptr); + mli_mov_tensor_sync(b_slice.Sub(), ©_config, b_ptr); + + // Slice the input over the batches (one at a time with the size of a + // complete input) + TensorSlicer in_slice(&mli_in, input_size_dimension, + mli_in.shape[input_size_dimension]); + + /* output tensor is alreade sliced in the output size dimension. + out_ch_slice.Sub() is the tensor for the amount of output size of this + itteration of the weight slice loop. This tensor needs to be further + sliced over the batch */ + TensorSlicer out_slice(out_ch_slice.Sub(), out_tensor_dimension, + slice_size); + + /* setup the pointers to the local or remote tensor to make the code + * inside the loop easier. */ + mli_tensor* in_ptr = in_is_local ? in_slice.Sub() : &in_local; + mli_tensor* out_ptr = out_is_local ? out_slice.Sub() : &out_local; + + while (!out_slice.Done()) { + // if same input copy as previous iteration, skip the copy of input + if (in_slice.Sub()->data != input_buffer_ptr) { + mli_mov_tensor_sync(in_slice.Sub(), ©_config, in_ptr); + input_buffer_ptr = in_slice.Sub()->data; + } + mli_krn_fully_connected_sa8_sa8_sa32(in_ptr, w_ptr, b_ptr, out_ptr); + mli_mov_tensor_sync(out_ptr, ©_config, out_slice.Sub()); + + in_slice.Next(); + out_slice.Next(); + } + w_slice.Next(); + b_slice.Next(); + out_ch_slice.Next(); + } + return kTfLiteOk; +} + +TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + FullyConnectedParams op_params; + op_params.input_offset = -input->params.zero_point; + op_params.weights_offset = -filter->params.zero_point; + op_params.output_offset = output->params.zero_point; + op_params.output_multiplier = data->output_multiplier; + op_params.output_shift = -data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + + reference_integer_ops::FullyConnected( + op_params, GetTensorShape(input), GetTensorData<int8_t>(input), + GetTensorShape(filter), GetTensorData<int8_t>(filter), + GetTensorShape(bias), GetTensorData<int32_t>(bias), + GetTensorShape(output), GetTensorData<int8_t>(output)); + return kTfLiteOk; +#else + TF_LITE_KERNEL_LOG(context, + "Node configuration is not supported by ARC MLI Library."); + return kTfLiteError; +#endif +} + +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + const TfLiteTensor* input, + const TfLiteTensor* filter, const TfLiteTensor* bias, + TfLiteTensor* output) { +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + const int32_t input_offset = -input->params.zero_point; + const int32_t filter_offset = -filter->params.zero_point; + const int32_t output_offset = output->params.zero_point; + + tflite::FullyConnectedParams op_params; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + // Legacy ops used mixed left and right shifts. Now all are +ve-means-left. + op_params.output_shift = -data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + +#define TF_LITE_FULLY_CONNECTED(output_data_type) \ + reference_ops::FullyConnected( \ + op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \ + GetTensorShape(filter), GetTensorData<uint8_t>(filter), \ + GetTensorShape(bias), GetTensorData<int32_t>(bias), \ + GetTensorShape(output), GetTensorData<output_data_type>(output)) + switch (output->type) { + case kTfLiteUInt8: + TF_LITE_FULLY_CONNECTED(uint8_t); + break; + case kTfLiteInt16: + TF_LITE_FULLY_CONNECTED(int16_t); + break; + default: + TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", + TfLiteTypeGetName(output->type), output->type); + return kTfLiteError; + } + return kTfLiteOk; +#else + TF_LITE_KERNEL_LOG(context, + "Type %s (%d) is not supported by ARC MLI Library.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; +#endif +} + +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + const TfLiteTensor* input, const TfLiteTensor* filter, + const TfLiteTensor* bias, TfLiteTensor* output) { +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + float output_activation_min, output_activation_max; + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); + tflite::FullyConnectedParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + tflite::reference_ops::FullyConnected( + op_params, GetTensorShape(input), GetTensorData<float>(input), + GetTensorShape(filter), GetTensorData<float>(filter), + GetTensorShape(bias), GetTensorData<float>(bias), GetTensorShape(output), + GetTensorData<float>(output)); + return kTfLiteOk; +#else + TF_LITE_KERNEL_LOG(context, + "Type %s (%d) is not supported by ARC MLI Library.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; +#endif +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + OpData* data = reinterpret_cast<OpData*>(node->user_data); + TF_LITE_ENSURE(context, data != nullptr); + + // Checks in Prepare ensure input, output and filter types are all the same. + switch (input->type) { + case kTfLiteFloat32: + return EvalFloat(context, node, params, data, input, filter, bias, + output); + case kTfLiteInt8: + if (IsMliApplicable(context, input, filter, bias, params)) { + return EvalMliQuantizedInt8(context, node, params, data, input, filter, + bias, output); + } else { + return EvalQuantizedInt8(context, node, params, data, input, filter, + bias, output); + } + + case kTfLiteUInt8: + return EvalQuantized(context, node, params, data, input, filter, bias, + output); + + default: + TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.", + TfLiteTypeGetName(filter->type), filter->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace fully_connected + +TfLiteRegistration* Register_FULLY_CONNECTED() { + static TfLiteRegistration r = {/*init=*/fully_connected::Init, + /*free=*/nullptr, + /*prepare=*/fully_connected::Prepare, + /*invoke=*/fully_connected::Eval, + /*profiling_string=*/nullptr, + /*builtin_code=*/0, + /*custom_name=*/nullptr, + /*version=*/0}; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/arc_mli/fully_connected_slicing_test.cc b/tensorflow/lite/micro/kernels/arc_mli/fully_connected_slicing_test.cc new file mode 100644 index 00000000000..a64e7bdff4a --- /dev/null +++ b/tensorflow/lite/micro/kernels/arc_mli/fully_connected_slicing_test.cc @@ -0,0 +1,425 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This test checks that slicing logic doesn`t affect result of fully +// connected kernel +// +// This test doesn`t replace default fully connected test +// (tensorflow/lite/micro/kernels/fully_connected_test.cc). It is added to the +// whole testset only in case MLI for ARC platform is used during generation +// (which is handled in arc_mli.inc). So such tests won`t be generated for other +// platforms. + +#include <cstdint> + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/kernels/all_ops_resolver.h" +#include "tensorflow/lite/micro/testing/micro_test.h" +#include "tensorflow/lite/micro/testing/test_utils.h" + +namespace tflite { +namespace testing { +namespace { + +template <typename T> +void TestFullyConnectedQuantized( + const int* input_dims_data, const T* input_data, const float input_min, + const float input_max, const int* weights_dims_data, const T* weights_data, + const float weights_min, const float weights_max, const int* bias_dims_data, + const int32_t* bias_data, const float bias_scale, + const T* expected_output_data, const int* output_dims_data, + const float output_min, const float output_max, + TfLiteFusedActivation activation, T* output_data) { + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* weights_dims = IntArrayFromInts(weights_dims_data); + TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 3; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min, + input_max), + CreateQuantizedTensor(weights_data, weights_dims, "weights_tensor", + weights_min, weights_max), + CreateQuantized32Tensor(bias_data, bias_dims, "bias_tensor", bias_scale), + CreateQuantizedTensor(output_data, output_dims, "output_tensor", + output_min, output_max), + }; + + tensors[0].params.zero_point = 0; + tensors[1].params.zero_point = 0; + tensors[3].params.zero_point = 0; + + TfLiteContext context; + PopulateContext(tensors, tensors_size, micro_test::reporter, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 4); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + TfLiteFullyConnectedParams builtin_data = { + activation, + kTfLiteFullyConnectedWeightsFormatDefault, + }; + const char* init_data = reinterpret_cast<const char*>(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + + int inputs_array_data[] = {3, 0, 1, 2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 3}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast<void*>(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + if (registration->free) { + registration->free(&context, user_data); + } + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]); + } +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +// Test group 1 +TF_LITE_MICRO_TEST(SystemSimpleTestQuantized1) { + const float input_min = -128.0f; + const float input_max = 127.0f; + const float weights_min = -128.0f; + const float weights_max = 127.0f; + const float bias_scale = 1.0f; + const float output_min = -128.0f; + const float output_max = 127.0f; + + const int input_dims_data[] = {2, 2, 10}; + const int8_t input_data[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int weights_dims_data[] = {2, 3, 10}; + const int8_t weights_data[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int bias_dims_data[] = {1, 3}; + const int32_t bias_data[] = {1, 1, 1}; + const int8_t expected_output_data[] = {41, 41, 41, 41, 41, 41}; + const int output_dims_data[] = {2, 2, 3}; + + const int output_dims_count = 6; + int8_t output_data[output_dims_count]; + tflite::testing::TestFullyConnectedQuantized<int8_t>( + input_dims_data, input_data, input_min, input_max, weights_dims_data, + weights_data, weights_min, weights_max, bias_dims_data, bias_data, + bias_scale, expected_output_data, output_dims_data, output_min, + output_max, kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(LocalSimpleTestQuantized1) { + const float input_min = -128.0f; + const float input_max = 127.0f; + const float weights_min = -128.0f; + const float weights_max = 127.0f; + const float bias_scale = 1.0f; + const float output_min = -128.0f; + const float output_max = 127.0f; + + const int input_dims_data_local[] = {2, 2, 10}; + const int weights_dims_data_local[] = {2, 3, 10}; + const int bias_dims_data_local[] = {1, 3}; + const int output_dims_data_local[] = {2, 2, 3}; + + const int output_dims_count = 6; + +#pragma Bss(".Zdata") + const int8_t input_data_local[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int8_t weights_data_local[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int32_t bias_data_local[] = {1, 1, 1}; + int8_t output_data_local[output_dims_count]; +#pragma Bss() + + const int8_t expected_output_data[] = {41, 41, 41, 41, 41, 41}; + + tflite::testing::TestFullyConnectedQuantized<int8_t>( + input_dims_data_local, input_data_local, input_min, input_max, + weights_dims_data_local, weights_data_local, weights_min, weights_max, + bias_dims_data_local, bias_data_local, bias_scale, expected_output_data, + output_dims_data_local, output_min, output_max, kTfLiteActNone, + output_data_local); +} + +// Test group 2 +TF_LITE_MICRO_TEST(SystemSimpleTestQuantized2) { + const float input_min = -128.0f; + const float input_max = 127.0f; + const float weights_min = -128.0f; + const float weights_max = 127.0f; + const float bias_scale = 1.0f; + const float output_min = -128.0f; + const float output_max = 127.0f; + + const int input_dims_data_2[] = {2, 10, 4}; + const int8_t input_data_2[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int weights_dims_data_2[] = {2, 6, 4}; + const int8_t weights_data_2[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int bias_dims_data_2[] = {1, 6}; + const int32_t bias_data_2[] = {1, 1, 1, 1, 1, 1}; + const int8_t expected_output_data_2[] = { + 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, + 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, + 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, + 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17}; + const int output_dims_data_2[] = {2, 10, 6}; + + const int output_dims_count_2 = 60; + int8_t output_data_2[output_dims_count_2]; + tflite::testing::TestFullyConnectedQuantized<int8_t>( + input_dims_data_2, input_data_2, input_min, input_max, + weights_dims_data_2, weights_data_2, weights_min, weights_max, + bias_dims_data_2, bias_data_2, bias_scale, expected_output_data_2, + output_dims_data_2, output_min, output_max, kTfLiteActNone, + output_data_2); +} + +TF_LITE_MICRO_TEST(LocalSimpleTestQuantized2) { + const float input_min = -128.0f; + const float input_max = 127.0f; + const float weights_min = -128.0f; + const float weights_max = 127.0f; + const float bias_scale = 1.0f; + const float output_min = -128.0f; + const float output_max = 127.0f; + + const int input_dims_data_local_2[] = {2, 10, 4}; + const int weights_dims_data_local_2[] = {2, 6, 4}; + const int bias_dims_data_local_2[] = {1, 6}; + const int output_dims_data_local_2[] = {2, 10, 6}; + + const int output_dims_count_local_2 = 60; + +#pragma Bss(".Zdata") + const int8_t input_data_local_2[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int8_t weights_data_local_2[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int32_t bias_data_local_2[] = {1, 1, 1, 1, 1, 1}; + int8_t output_data_local_2[output_dims_count_local_2]; +#pragma Bss() + + const int8_t expected_output_data_local_2[] = {41, 41, 41, 41, 41, 41}; + + tflite::testing::TestFullyConnectedQuantized<int8_t>( + input_dims_data_local_2, input_data_local_2, input_min, input_max, + weights_dims_data_local_2, weights_data_local_2, weights_min, weights_max, + bias_dims_data_local_2, bias_data_local_2, bias_scale, + expected_output_data_local_2, output_dims_data_local_2, output_min, + output_max, kTfLiteActNone, output_data_local_2); +} + +// Test group 3 +TF_LITE_MICRO_TEST(SystemSimpleTestQuantized3) { + const float input_min = -128.0f; + const float input_max = 127.0f; + const float weights_min = -128.0f; + const float weights_max = 127.0f; + const float bias_scale = 1.0f; + const float output_min = -128.0f; + const float output_max = 127.0f; + + const int input_dims_data_3[] = {2, 2, 5}; + const int8_t input_data_3[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int weights_dims_data_3[] = {2, 10, 5}; + const int8_t weights_data_3[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int bias_dims_data_3[] = {1, 10}; + const int32_t bias_data_3[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + const int8_t expected_output_data_3[] = {21, 21, 21, 21, 21, 21, 21, + 21, 21, 21, 21, 21, 21, 21, + 21, 21, 21, 21, 21, 21}; + const int output_dims_data_3[] = {2, 2, 10}; + + const int output_dims_count_3 = 20; + int8_t output_data_3[output_dims_count_3]; + tflite::testing::TestFullyConnectedQuantized<int8_t>( + input_dims_data_3, input_data_3, input_min, input_max, + weights_dims_data_3, weights_data_3, weights_min, weights_max, + bias_dims_data_3, bias_data_3, bias_scale, expected_output_data_3, + output_dims_data_3, output_min, output_max, kTfLiteActNone, + output_data_3); +} + +TF_LITE_MICRO_TEST(LocalSimpleTestQuantized3) { + const float input_min = -128.0f; + const float input_max = 127.0f; + const float weights_min = -128.0f; + const float weights_max = 127.0f; + const float bias_scale = 1.0f; + const float output_min = -128.0f; + const float output_max = 127.0f; + + const int input_dims_data_local_3[] = {2, 2, 5}; + const int weights_dims_data_local_3[] = {2, 10, 5}; + const int bias_dims_data_local_3[] = {1, 10}; + const int output_dims_data_local_3[] = {2, 2, 10}; + + const int output_dims_count_local_3 = 20; + +#pragma Bss(".Zdata") + static int8_t input_data_local_3[10]; + static int8_t weights_data_local_3[50]; + static int32_t bias_data_local_3[10]; + static int8_t output_data_local_3[output_dims_count_local_3]; +#pragma Bss() + + for (int i = 0; i < 10; ++i) { + input_data_local_3[i] = 2; + } + + for (int i = 0; i < 50; ++i) { + weights_data_local_3[i] = 2; + } + + for (int i = 0; i < 10; ++i) { + bias_data_local_3[i] = 1; + } + + for (int i = 0; i < 20; ++i) { + output_data_local_3[i] = 0; + } + + const int8_t expected_output_data_local_3[] = {21, 21, 21, 21, 21, 21, 21, + 21, 21, 21, 21, 21, 21, 21, + 21, 21, 21, 21, 21, 21}; + + tflite::testing::TestFullyConnectedQuantized<int8_t>( + input_dims_data_local_3, input_data_local_3, input_min, input_max, + weights_dims_data_local_3, weights_data_local_3, weights_min, weights_max, + bias_dims_data_local_3, bias_data_local_3, bias_scale, + expected_output_data_local_3, output_dims_data_local_3, output_min, + output_max, kTfLiteActNone, output_data_local_3); +} + +// Test group 4 +TF_LITE_MICRO_TEST(SystemSimpleTestQuantized4) { + const float input_min = -128.0f; + const float input_max = 127.0f; + const float weights_min = -128.0f; + const float weights_max = 127.0f; + const float bias_scale = 1.0f; + const float output_min = -128.0f; + const float output_max = 127.0f; + + const int input_dims_data_4[] = {2, 5, 10}; + const int8_t input_data_4[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int weights_dims_data_4[] = {2, 5, 10}; + const int8_t weights_data_4[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int bias_dims_data_4[] = {1, 5}; + const int32_t bias_data_4[] = {1, 1, 1, 1, 1}; + const int8_t expected_output_data_4[] = {41, 41, 41, 41, 41, 41, 41, 41, 41, + 41, 41, 41, 41, 41, 41, 41, 41, 41, + 41, 41, 41, 41, 41, 41, 41}; + const int output_dims_data_4[] = {2, 5, 5}; + + const int output_dims_count_4 = 25; + int8_t output_data_4[output_dims_count_4]; + tflite::testing::TestFullyConnectedQuantized<int8_t>( + input_dims_data_4, input_data_4, input_min, input_max, + weights_dims_data_4, weights_data_4, weights_min, weights_max, + bias_dims_data_4, bias_data_4, bias_scale, expected_output_data_4, + output_dims_data_4, output_min, output_max, kTfLiteActNone, + output_data_4); +} + +TF_LITE_MICRO_TEST(LocalSimpleTestQuantized4) { + const float input_min = -128.0f; + const float input_max = 127.0f; + const float weights_min = -128.0f; + const float weights_max = 127.0f; + const float bias_scale = 1.0f; + const float output_min = -128.0f; + const float output_max = 127.0f; + + const int input_dims_data_local_4[] = {2, 5, 10}; + const int weights_dims_data_local_4[] = {2, 5, 10}; + const int bias_dims_data_local_4[] = {1, 5}; + const int output_dims_data_local_4[] = {2, 5, 5}; + + const int output_dims_count_local_4 = 25; + +#pragma Bss(".Zdata") + const int8_t input_data_local_4[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int8_t weights_data_local_4[] = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; + const int32_t bias_data_local_4[] = {1, 1, 1, 1, 1}; + int8_t output_data_local_4[output_dims_count_local_4]; +#pragma Bss() + + const int8_t expected_output_data_local_4[] = { + 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, + 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41}; + + tflite::testing::TestFullyConnectedQuantized<int8_t>( + input_dims_data_local_4, input_data_local_4, input_min, input_max, + weights_dims_data_local_4, weights_data_local_4, weights_min, weights_max, + bias_dims_data_local_4, bias_data_local_4, bias_scale, + expected_output_data_local_4, output_dims_data_local_4, output_min, + output_max, kTfLiteActNone, output_data_local_4); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/arc_mli/mli_slicers.cc b/tensorflow/lite/micro/kernels/arc_mli/mli_slicers.cc new file mode 100644 index 00000000000..4637470f62e --- /dev/null +++ b/tensorflow/lite/micro/kernels/arc_mli/mli_slicers.cc @@ -0,0 +1,126 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mli_slicers.h" // NOLINT + +#include <algorithm> + +namespace tflite { +namespace ops { +namespace micro { + +TensorSlicer::TensorSlicer(const mli_tensor* full_tensor, int slice_dim, + int slice_size, int padding_pre, int padding_post, + int overlap, bool interleave_mode) + : full_tensor_(full_tensor), + sliceDim_(slice_dim), + pad_pre_(padding_pre), + pad_post_(padding_post), + overlap_(overlap), + sub_cfg_{0}, + sub_tensor_{0}, + done_(false) { + /* In the interleave mode, the slicing happens from the deepest dimension up + to the slice_dim for example in an HWC layout this can mode can be used to + slice in the C dimenstion. in this mode the data is not contiguous in memory + anymore */ + if (interleave_mode) { + for (int i = 0; i < full_tensor->rank; i++) { + if (i > slice_dim) { + sub_cfg_.size[i] = 1; + } else if (i == slice_dim) { + sub_cfg_.size[i] = slice_size; + } else { + sub_cfg_.size[i] = full_tensor->shape[i]; + } + } + sub_cfg_.sub_tensor_rank = full_tensor->rank; + + } else { + /* In the not interleaved mode, the slicing happens from the outer most + dimension up to the slice_dim for example in an HWC layout this mode can be + used to slice in the H dimension. in this mode the data of the slice is + still contiguous in memory (if that was the case in the input tensor */ + for (int i = 0; i < full_tensor->rank; i++) { + if (i < slice_dim) { + sub_cfg_.size[i] = 1; + } else if (i == slice_dim) { + sub_cfg_.size[i] = slice_size; + } else { + sub_cfg_.size[i] = full_tensor->shape[i]; + } + } + sub_cfg_.sub_tensor_rank = full_tensor->rank - slice_dim; + } + + ComputeSubTensor(); +} + +void TensorSlicer::ComputeSubTensor(void) { + // subtsr_cfg_ is used to keep track of the iteration. + // A copy is created to update it with the correct clipping and padding for + // the current slice + mli_sub_tensor_cfg cfg_new = sub_cfg_; + + // begin and end spans the complete input region including padding areas. + const int begin = (int)sub_cfg_.offset[sliceDim_] - pad_pre_; + // end is clipped to the end of the full input region. this is needed for + // cases where the last slice is smaller than the rest. + const int end = std::min(begin + sub_cfg_.size[sliceDim_] + overlap_, + full_tensor_->shape[sliceDim_] + pad_post_); + // The start coordinate of the subtensor is clipped to zero + cfg_new.offset[sliceDim_] = std::max(begin, 0); + // and the stop coordinate is clipped to the size of the full tensor + const int stop_coord = + std::min(end, static_cast<int>(full_tensor_->shape[sliceDim_])); + // compute the size of the subtensor + cfg_new.size[sliceDim_] = stop_coord - cfg_new.offset[sliceDim_]; + + // compute the padding configuration for the current slice. + actual_padding_pre = cfg_new.offset[sliceDim_] - begin; + actual_padding_post = end - stop_coord; + + mli_hlp_create_subtensor(full_tensor_, &cfg_new, &sub_tensor_); +} + +void TensorSlicer::Next(void) { + for (int i = full_tensor_->rank - 1; i >= 0; i--) { + sub_cfg_.offset[i] += sub_cfg_.size[i]; + if (sub_cfg_.offset[i] >= full_tensor_->shape[i]) { + // wrap + sub_cfg_.offset[i] = 0; + // and continue to the next dimension, if no next dimension we are done. + if (i == 0) done_ = true; + continue; + } else { + // carry is false, so break from the loop + break; + } + } + + if (!done_) ComputeSubTensor(); +} + +bool TensorSlicer::Done(void) { return done_; } + +int TensorSlicer::GetPaddingPre(void) { return actual_padding_pre; } + +int TensorSlicer::GetPaddingPost(void) { return actual_padding_post; } + +mli_tensor* TensorSlicer::Sub(void) { return &sub_tensor_; } + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/arc_mli/mli_slicers.h b/tensorflow/lite/micro/kernels/arc_mli/mli_slicers.h new file mode 100644 index 00000000000..b21a5b68054 --- /dev/null +++ b/tensorflow/lite/micro/kernels/arc_mli/mli_slicers.h @@ -0,0 +1,56 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_MICRO_KERNELS_ARC_MLI_SLICERS_H_ +#define TENSORFLOW_LITE_MICRO_KERNELS_ARC_MLI_SLICERS_H_ + +#include "mli_api.h" // NOLINT +namespace tflite { +namespace ops { +namespace micro { + +class TensorSlicer { + public: + TensorSlicer(const mli_tensor* full_tensor, int slice_dim, int slice_size, + int padding_pre = 0, int padding_post = 0, int overlap = 0, + bool interleave_mode = false); + ~TensorSlicer() = default; + + void Next(); + bool Done(); + int GetPaddingPre(); + int GetPaddingPost(); + + mli_tensor* Sub(); + + // Default constructor is deleted + TensorSlicer() = delete; + + private: + const mli_tensor* full_tensor_; + mli_tensor sub_tensor_; + mli_sub_tensor_cfg sub_cfg_; + bool done_; + int sliceDim_; + int pad_pre_, pad_post_, overlap_; + int actual_padding_pre, actual_padding_post; + + void ComputeSubTensor(); +}; + +} // namespace micro +} // namespace ops +} // namespace tflite +#endif // TENSORFLOW_LITE_MICRO_KERNELS_ARC_MLI_SLICERS_H_ diff --git a/tensorflow/lite/micro/kernels/arc/mli_tf_utils.h b/tensorflow/lite/micro/kernels/arc_mli/mli_tf_utils.h similarity index 100% rename from tensorflow/lite/micro/kernels/arc/mli_tf_utils.h rename to tensorflow/lite/micro/kernels/arc_mli/mli_tf_utils.h diff --git a/tensorflow/lite/micro/kernels/arc_mli/pooling.cc b/tensorflow/lite/micro/kernels/arc_mli/pooling.cc new file mode 100644 index 00000000000..babde709bae --- /dev/null +++ b/tensorflow/lite/micro/kernels/arc_mli/pooling.cc @@ -0,0 +1,376 @@ +/* Copyright 2019-2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/kernels/internal/reference/pooling.h" + +#include "mli_api.h" // NOLINT +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/padding.h" +#include "tensorflow/lite/micro/kernels/arc_mli/mli_slicers.h" +#include "tensorflow/lite/micro/kernels/arc_mli/mli_tf_utils.h" +#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.h" +#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h" + +namespace tflite { +namespace ops { +namespace micro { +namespace pooling { + +namespace { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +struct OpData { + TfLitePaddingValues padding; +}; + +enum MliPoolingType { AveragePooling = 0, MaxPooling = 1 }; + +bool IsMliApplicable(TfLiteContext* context, const TfLiteTensor* input, + const TfLitePoolParams* params) { + // MLI optimized version only supports int8 dataype and no fused Relu + return (input->type == kTfLiteInt8 && params->activation == kTfLiteActNone); +} + +TfLiteStatus CalculateOpData(const TfLiteContext* context, + const TfLitePoolParams* params, + const TfLiteTensor* input, + const TfLiteTensor* output, OpData* data) { + // input: batch, height, width, channel + int height = SizeOfDimension(input, 1); + int width = SizeOfDimension(input, 2); + + int out_height, out_width; + + data->padding = ComputePaddingHeightWidth( + params->stride_height, params->stride_width, + /*dilation_rate_height=*/1, + /*dilation_rate_width=*/1, height, width, params->filter_height, + params->filter_width, params->padding, &out_height, &out_width); + + return kTfLiteOk; +} + +TfLiteStatus AverageEvalFloat(TfLiteContext* context, const TfLiteNode* node, + const TfLitePoolParams* params, + const OpData* data, const TfLiteTensor* input, + TfLiteTensor* output) { +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + float activation_min, activation_max; + CalculateActivationRange(params->activation, &activation_min, + &activation_max); + + PoolParams op_params; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.filter_height = params->filter_height; + op_params.filter_width = params->filter_width; + op_params.padding_values.height = data->padding.height; + op_params.padding_values.width = data->padding.width; + op_params.float_activation_min = activation_min; + op_params.float_activation_max = activation_max; + reference_ops::AveragePool( + op_params, GetTensorShape(input), GetTensorData<float>(input), + GetTensorShape(output), GetTensorData<float>(output)); + return kTfLiteOk; +#else + TF_LITE_KERNEL_LOG(context, + "Type %s (%d) is not supported by ARC MLI Library.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; +#endif +} + +// Prepare MLI tensors and run Average or Max Pooling +TfLiteStatus EvalMli(TfLiteContext* context, const TfLitePoolParams* params, + const OpData* data, const TfLiteTensor* input, + TfLiteTensor* output, const MliPoolingType pooling_type) { + mli_tensor mli_in = {0}; + mli_tensor mli_out = {0}; + mli_pool_cfg cfg = {0}; + + ConvertToMliTensor<int8_t>(input, &mli_in); + ConvertToMliTensor<int8_t>(output, &mli_out); + + cfg.kernel_width = params->filter_width; + cfg.kernel_height = params->filter_height; + cfg.stride_width = params->stride_width; + cfg.stride_height = params->stride_height; + + if (params->padding == kTfLitePaddingValid) { + cfg.padding_left = 0; + cfg.padding_right = 0; + cfg.padding_top = 0; + cfg.padding_bottom = 0; + } else { + cfg.padding_left = data->padding.width; + cfg.padding_right = data->padding.width + data->padding.width_offset; + cfg.padding_top = data->padding.height; + cfg.padding_bottom = data->padding.height + data->padding.height_offset; + } + + const int height_dimension = 1; + int in_slice_height = 0; + int out_slice_height = 0; + const int overlap = cfg.kernel_height - cfg.stride_height; + + // Tensors for data in fast (local) memory and config to copy data from + // external to local memory + mli_tensor in_local = mli_in; + mli_tensor out_local = mli_out; + mli_mov_cfg_t copy_config; + mli_mov_cfg_for_copy(©_config); + TF_LITE_ENSURE_STATUS(get_arc_scratch_buffer_for_pooling_tensors( + context, &in_local, &out_local)); + bool in_is_local = in_local.data == mli_in.data; + bool out_is_local = out_local.data == mli_out.data; + TF_LITE_ENSURE_STATUS(arc_scratch_buffer_calc_slice_size_io( + &in_local, &out_local, cfg.kernel_height, cfg.stride_height, + cfg.padding_top, cfg.padding_bottom, &in_slice_height, + &out_slice_height)); + + /* mli_in tensor contains batches of HWC tensors. so it is a 4 dimensional + tensor. because the mli kernel will process one HWC tensor at a time, the 4 + dimensional tensor needs to be sliced into nBatch 3 dimensional tensors. on + top of that there could be a need to also slice in the Height dimension. + for that the sliceHeight has been calculated. The tensor slicer is + configured that it will completely slice the nBatch dimension (0) and slice + the height dimension (1) in chunks of 'sliceHeight' */ + TensorSlicer in_slice(&mli_in, height_dimension, in_slice_height, + cfg.padding_top, cfg.padding_bottom, overlap); + TensorSlicer out_slice(&mli_out, height_dimension, out_slice_height); + + /* is_local indicates that the tensor is already in local memory, + so in that case the original tensor can be used, + and there is no need to copy it to the local tensor*/ + mli_tensor* in_ptr = in_is_local ? in_slice.Sub() : &in_local; + mli_tensor* out_ptr = out_is_local ? out_slice.Sub() : &out_local; + + while (!out_slice.Done()) { + cfg.padding_top = in_slice.GetPaddingPre(); + cfg.padding_bottom = in_slice.GetPaddingPost(); + + mli_mov_tensor_sync(in_slice.Sub(), ©_config, in_ptr); + if (pooling_type == AveragePooling) + mli_krn_avepool_hwc_sa8(in_ptr, &cfg, out_ptr); + else if (pooling_type == MaxPooling) + mli_krn_maxpool_hwc_sa8(in_ptr, &cfg, out_ptr); + mli_mov_tensor_sync(out_ptr, ©_config, out_slice.Sub()); + + in_slice.Next(); + out_slice.Next(); + } + return kTfLiteOk; +} + +TfLiteStatus AverageEvalQuantized(TfLiteContext* context, + const TfLiteNode* node, + const TfLitePoolParams* params, + const OpData* data, const TfLiteTensor* input, + TfLiteTensor* output) { +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + TFLITE_DCHECK(input->type == kTfLiteUInt8 || input->type == kTfLiteInt8); + int32_t activation_min, activation_max; + (void)CalculateActivationRangeQuantized(context, params->activation, output, + &activation_min, &activation_max); + PoolParams op_params; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.filter_height = params->filter_height; + op_params.filter_width = params->filter_width; + op_params.padding_values.height = data->padding.height; + op_params.padding_values.width = data->padding.width; + op_params.quantized_activation_min = activation_min; + op_params.quantized_activation_max = activation_max; + + if (input->type == kTfLiteUInt8) { + reference_ops::AveragePool( + op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), + GetTensorShape(output), GetTensorData<uint8_t>(output)); + } else { + reference_integer_ops::AveragePool( + op_params, GetTensorShape(input), GetTensorData<int8_t>(input), + GetTensorShape(output), GetTensorData<int8_t>(output)); + } + return kTfLiteOk; +#else + TF_LITE_KERNEL_LOG( + context, + "Node configuration or type %s (%d) is not supported by ARC MLI Library.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; +#endif +} + +TfLiteStatus MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + const TfLiteTensor* input, TfLiteTensor* output) { +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + float activation_min, activation_max; + CalculateActivationRange(params->activation, &activation_min, + &activation_max); + + tflite::PoolParams op_params; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.filter_height = params->filter_height; + op_params.filter_width = params->filter_width; + op_params.padding_values.height = data->padding.height; + op_params.padding_values.width = data->padding.width; + op_params.float_activation_min = activation_min; + op_params.float_activation_max = activation_max; + reference_ops::MaxPool(op_params, GetTensorShape(input), + GetTensorData<float>(input), GetTensorShape(output), + GetTensorData<float>(output)); + return kTfLiteOk; +#else + TF_LITE_KERNEL_LOG(context, + "Type %s (%d) is not supported by ARC MLI Library.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; +#endif +} + +TfLiteStatus MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + const TfLiteTensor* input, TfLiteTensor* output) { +#if !defined(TF_LITE_STRIP_REFERENCE_IMPL) + TFLITE_DCHECK(input->type == kTfLiteUInt8 || input->type == kTfLiteInt8); + int32_t activation_min, activation_max; + (void)CalculateActivationRangeQuantized(context, params->activation, output, + &activation_min, &activation_max); + + tflite::PoolParams op_params; + op_params.stride_height = params->stride_height; + op_params.stride_width = params->stride_width; + op_params.filter_height = params->filter_height; + op_params.filter_width = params->filter_width; + op_params.padding_values.height = data->padding.height; + op_params.padding_values.width = data->padding.width; + op_params.quantized_activation_min = activation_min; + op_params.quantized_activation_max = activation_max; + + if (input->type == kTfLiteUInt8) { + reference_ops::MaxPool( + op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), + GetTensorShape(output), GetTensorData<uint8_t>(output)); + } else { + reference_integer_ops::MaxPool( + op_params, GetTensorShape(input), GetTensorData<int8_t>(input), + GetTensorShape(output), GetTensorData<int8_t>(output)); + } + return kTfLiteOk; +#else + TF_LITE_KERNEL_LOG( + context, + "Node configuration or type %s (%d) is not supported by ARC MLI Library.", + TfLiteTypeGetName(input->type), input->type); + return kTfLiteError; +#endif +} +} // namespace + +TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data); + OpData data; + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data)); + + // Inputs and outputs share the same type, guaranteed by the converter. + switch (input->type) { + case kTfLiteFloat32: + return AverageEvalFloat(context, node, params, &data, input, output); + break; + case kTfLiteUInt8: + case kTfLiteInt8: + if (IsMliApplicable(context, input, params)) { + return EvalMli(context, params, &data, input, output, AveragePooling); + } else { + return AverageEvalQuantized(context, node, params, &data, input, + output); + } + break; + default: + TF_LITE_KERNEL_LOG(context, "Input type %s is not currently supported", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast<TfLitePoolParams*>(node->builtin_data); + OpData data; + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, input, output, &data)); + + switch (input->type) { + case kTfLiteFloat32: + return MaxEvalFloat(context, node, params, &data, input, output); + break; + case kTfLiteUInt8: + case kTfLiteInt8: + if (IsMliApplicable(context, input, params)) { + return EvalMli(context, params, &data, input, output, MaxPooling); + } else { + return MaxEvalQuantized(context, node, params, &data, input, output); + } + break; + default: + TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.", + TfLiteTypeGetName(input->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace pooling + +TfLiteRegistration* Register_AVERAGE_POOL_2D() { + static TfLiteRegistration r = {/*init=*/nullptr, + /*free=*/nullptr, + /*prepare=*/nullptr, + /*invoke=*/pooling::AverageEval, + /*profiling_string=*/nullptr, + /*builtin_code=*/0, + /*custom_name=*/nullptr, + /*version=*/0}; + return &r; +} + +TfLiteRegistration* Register_MAX_POOL_2D() { + static TfLiteRegistration r = {/*init=*/nullptr, + /*free=*/nullptr, + /*prepare=*/nullptr, + /*invoke=*/pooling::MaxEval, + /*profiling_string=*/nullptr, + /*builtin_code=*/0, + /*custom_name=*/nullptr, + /*version=*/0}; + return &r; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/arc_mli/pooling_slicing_test.cc b/tensorflow/lite/micro/kernels/arc_mli/pooling_slicing_test.cc new file mode 100644 index 00000000000..7cf5c9b607e --- /dev/null +++ b/tensorflow/lite/micro/kernels/arc_mli/pooling_slicing_test.cc @@ -0,0 +1,422 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This test checks that slicing logic doesn`t affect result of pooling kernels +// +// This test doesn`t replace default pooling test +// (tensorflow/lite/micro/kernels/pooling.cc). It is added to the +// whole testset only in case MLI for ARC platform is used during generation +// (which is handled in arc_mli.inc). So such tests won`t be generated for other +// platforms. + +#include <cstdint> + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/kernels/all_ops_resolver.h" +#include "tensorflow/lite/micro/testing/micro_test.h" +#include "tensorflow/lite/micro/testing/test_utils.h" + +namespace tflite { +namespace testing { +namespace { + +template <typename T> +void TestAveragePoolingQuantized( + const int* input_dims_data, const T* input_data, const float input_min, + const float input_max, const int filter_height, const int filter_width, + const int stride_height, const int stride_width, + const T* expected_output_data, const int* output_dims_data, + float output_min, float output_max, TfLitePadding padding, + TfLiteFusedActivation activation, T* output_data) { + static_assert(sizeof(T) == 1, "Only int8/uint8 data types allowed."); + + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 1; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min, + input_max), + CreateQuantizedTensor(output_data, output_dims, "output_tensor", + output_min, output_max), + }; + + TfLiteContext context; + PopulateContext(tensors, tensors_size, micro_test::reporter, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_AVERAGE_POOL_2D, 1); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + TfLitePoolParams builtin_data = {padding, stride_width, stride_height, + filter_width, filter_height, activation}; + const char* init_data = reinterpret_cast<const char*>(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + int inputs_array_data[] = {1, 0}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 1}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast<void*>(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + if (registration->free) { + registration->free(&context, user_data); + } + + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], 1e-5f); + } +} + +template <typename T> +void TestMaxPoolQuantized(const int* input_dims_data, const T* input_data, + float input_min, float input_max, int filter_width, + int filter_height, int stride_width, + int stride_height, const T* expected_output_data, + float output_min, float output_max, + const int* output_dims_data, TfLitePadding padding, + TfLiteFusedActivation activation, T* output_data) { + static_assert(sizeof(T) == 1, "Only int8/uint8 data types allowed."); + + TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); + TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); + const int output_dims_count = ElementCount(*output_dims); + + constexpr int inputs_size = 1; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min, + input_max), + CreateQuantizedTensor(output_data, output_dims, "output_tensor", + output_min, output_max), + }; + + TfLiteContext context; + PopulateContext(tensors, tensors_size, micro_test::reporter, &context); + + ::tflite::ops::micro::AllOpsResolver resolver; + const TfLiteRegistration* registration = + resolver.FindOp(tflite::BuiltinOperator_MAX_POOL_2D, 1); + TF_LITE_MICRO_EXPECT_NE(nullptr, registration); + + TfLitePoolParams builtin_data = { + padding, stride_width, stride_height, + filter_width, filter_height, activation, + }; + + const char* init_data = reinterpret_cast<const char*>(&builtin_data); + size_t init_data_size = 0; + void* user_data = nullptr; + if (registration->init) { + user_data = registration->init(&context, init_data, init_data_size); + } + + int inputs_array_data[] = {1, 0}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 1}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + int temporaries_array_data[] = {0}; + TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data); + + TfLiteNode node; + node.inputs = inputs_array; + node.outputs = outputs_array; + node.temporaries = temporaries_array; + node.user_data = user_data; + node.builtin_data = reinterpret_cast<void*>(&builtin_data); + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + node.delegate = nullptr; + if (registration->prepare) { + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); + } + TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); + if (registration->free) { + registration->free(&context, user_data); + } + for (int i = 0; i < output_dims_count; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]); + } +} + +} // namespace + +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(SystemAveragePoolTestInt1) { + using tflite::testing::F2QS; + + const float input_min = -128; + const float input_max = 127; + const float output_min = -128; + const float output_max = 127; + int8_t output_data[3]; + + const int kInput1Shape[] = {4, 1, 2, 4, 1}; + const int8_t kInput1Data[] = {1, 1, 1, 1, 1, 1, 1, 1}; + const int kOutput1Shape[] = {4, 1, 1, 3, 1}; + const int8_t kGolden1Data[] = {1, 1, 1}; + + tflite::testing::TestAveragePoolingQuantized( + kInput1Shape, // Input shape + kInput1Data, input_min, input_max, // input quantization range + 2, 2, // filter height, filter width + 1, 1, // stride height, stride width + kGolden1Data, + kOutput1Shape, // Output shape + output_min, output_max, // output quantization range + kTfLitePaddingValid, kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(LocalAveragePoolTestInt1) { + using tflite::testing::F2QS; + + const float input_min = -128; + const float input_max = 127; + const float output_min = -128; + const float output_max = 127; + int8_t output_data[3]; + +#pragma Bss(".Zdata") + const int kInput1Shape[] = {4, 1, 2, 4, 1}; + const int8_t kInput1Data[] = {1, 1, 1, 1, 1, 1, 1, 1}; + const int kOutput1Shape[] = {4, 1, 1, 3, 1}; + const int8_t kGolden1Data[] = {1, 1, 1}; +#pragma Bss() + + tflite::testing::TestAveragePoolingQuantized( + kInput1Shape, // Input shape + kInput1Data, input_min, input_max, // input quantization range + 2, 2, // filter height, filter width + 1, 1, // stride height, stride width + kGolden1Data, + kOutput1Shape, // Output shape + output_min, output_max, // output quantization range + kTfLitePaddingValid, kTfLiteActNone, output_data); +} + +// Test group AVG 2 +TF_LITE_MICRO_TEST(SystemAveragePoolTestInt2) { + using tflite::testing::F2QS; + + const float input_min = -128; + const float input_max = 127; + const float output_min = -128; + const float output_max = 127; + int8_t output_data[45]; + + const int kInput2Shape[] = {4, 1, 6, 10, 1}; + const int8_t kInput2Data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + const int kOutput2Shape[] = {4, 1, 5, 9, 1}; + const int8_t kGolden2Data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + + tflite::testing::TestAveragePoolingQuantized( + kInput2Shape, // Input shape + kInput2Data, input_min, input_max, // input quantization range + 2, 2, // filter height, filter width + 1, 1, // stride height, stride width + kGolden2Data, + kOutput2Shape, // Output shape + output_min, output_max, // output quantization range + kTfLitePaddingValid, kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(LocalAveragePoolTestInt2) { + using tflite::testing::F2QS; + + const float input_min = -128; + const float input_max = 127; + const float output_min = -128; + const float output_max = 127; + int8_t output_data[45]; + +#pragma Bss(".Zdata") + const int kInput2Shape[] = {4, 1, 6, 10, 1}; + const int8_t kInput2Data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + const int kOutput2Shape[] = {4, 1, 5, 9, 1}; + const int8_t kGolden2Data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; +#pragma Bss() + + tflite::testing::TestAveragePoolingQuantized( + kInput2Shape, // Input shape + kInput2Data, input_min, input_max, // input quantization range + 2, 2, // filter height, filter width + 1, 1, // stride height, stride width + kGolden2Data, + kOutput2Shape, // Output shape + output_min, output_max, // output quantization range + kTfLitePaddingValid, kTfLiteActNone, output_data); +} + +// Test group MAX 1 +TF_LITE_MICRO_TEST(SystemMaxPoolTestInt1) { + using tflite::testing::F2QS; + + int8_t output_data[3]; + const float input_min = -128; + const float input_max = 127; + const float output_min = -128; + const float output_max = 127; + int filter_width = 2; + int filter_height = 2; + int stride_width = 1; + int stride_height = 1; + + const int kInput1Shape[] = {4, 1, 2, 4, 1}; + const int8_t kInput1Data[] = {1, 1, 1, 1, 1, 1, 1, 1}; + const int kOutput1Shape[] = {4, 1, 1, 3, 1}; + const int8_t kGolden1Data[] = {1, 1, 1}; + + tflite::testing::TestMaxPoolQuantized( + kInput1Shape, // Input shape + kInput1Data, input_min, input_max, filter_width, filter_height, + stride_width, stride_height, kGolden1Data, output_min, output_max, + kOutput1Shape, // Output shape + kTfLitePaddingValid, kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(LocalMaxPoolTestInt1) { + using tflite::testing::F2QS; + + int8_t output_data[3]; + const float input_min = -128; + const float input_max = 127; + const float output_min = -128; + const float output_max = 127; + int filter_width = 2; + int filter_height = 2; + int stride_width = 1; + int stride_height = 1; + +#pragma Bss(".Zdata") + const int kInput1Shape[] = {4, 1, 2, 4, 1}; + const int8_t kInput1Data[] = {1, 1, 1, 1, 1, 1, 1, 1}; + const int kOutput1Shape[] = {4, 1, 1, 3, 1}; + const int8_t kGolden1Data[] = {1, 1, 1}; +#pragma Bss() + + tflite::testing::TestMaxPoolQuantized( + kInput1Shape, // Input shape + kInput1Data, input_min, input_max, filter_width, filter_height, + stride_width, stride_height, kGolden1Data, output_min, output_max, + kOutput1Shape, // Output shape + kTfLitePaddingValid, kTfLiteActNone, output_data); +} + +// Test group MAX 2 +TF_LITE_MICRO_TEST(SystemMaxPoolTestInt2) { + using tflite::testing::F2QS; + + int8_t output_data[45]; + const float input_min = -128; + const float input_max = 127; + const float output_min = -128; + const float output_max = 127; + int filter_width = 2; + int filter_height = 2; + int stride_width = 1; + int stride_height = 1; + + const int kInput2Shape[] = {4, 1, 6, 10, 1}; + const int8_t kInput2Data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + const int kOutput2Shape[] = {4, 1, 5, 9, 1}; + const int8_t kGolden2Data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + + tflite::testing::TestMaxPoolQuantized( + kInput2Shape, // Input shape + kInput2Data, input_min, input_max, filter_width, filter_height, + stride_width, stride_height, kGolden2Data, output_min, output_max, + kOutput2Shape, // Output shape + kTfLitePaddingValid, kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TEST(LocalMaxPoolTestInt2) { + using tflite::testing::F2QS; + + int8_t output_data[45]; + const float input_min = -128; + const float input_max = 127; + const float output_min = -128; + const float output_max = 127; + int filter_width = 2; + int filter_height = 2; + int stride_width = 1; + int stride_height = 1; + +#pragma Bss(".Zdata") + const int kInput2Shape[] = {4, 1, 6, 10, 1}; + const int8_t kInput2Data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + const int kOutput2Shape[] = {4, 1, 5, 9, 1}; + const int8_t kGolden2Data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; +#pragma Bss() + + tflite::testing::TestMaxPoolQuantized( + kInput2Shape, // Input shape + kInput2Data, input_min, input_max, filter_width, filter_height, + stride_width, stride_height, kGolden2Data, output_min, output_max, + kOutput2Shape, // Output shape + kTfLitePaddingValid, kTfLiteActNone, output_data); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.cc b/tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.cc new file mode 100644 index 00000000000..534d5ef3230 --- /dev/null +++ b/tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.cc @@ -0,0 +1,338 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.h" + +#include <limits.h> + +#include <algorithm> + +#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h" + +namespace tflite { +namespace ops { +namespace micro { + +static void get_arc_two_buffer_sizes(int request_size_1, int request_size_2, + int* grant_size_1, int* grant_size_2) { + int maxrequest = 0; + int secondrequest = 0; + int maxavailable = 0; + int secondavail = 0; + + // determine the largest requested buffer. + if (request_size_1 > request_size_2) { + maxrequest = request_size_1; + secondrequest = request_size_2; + } else { + maxrequest = request_size_2; + secondrequest = request_size_1; + } + + // find the two largest available buffers. + get_arc_scratch_buffer_two_max_sizes(&maxavailable, &secondavail); + + // in case two buffers are available, the largest buffer can go to the largest + // request. + if (secondavail > 0) { // this condition can be enhanced to prevent cases + // where the second buffer is so small that it is + // better to use one buffer and split it. + if (request_size_1 > request_size_2) { + *grant_size_1 = maxavailable; + *grant_size_2 = secondavail; + } else { + *grant_size_1 = secondavail; + *grant_size_2 = maxavailable; + } + } else { + // In case only one buffer is available, + // use only the max buffer, and split it. + *grant_size_1 = maxavailable / 2; + *grant_size_2 = maxavailable / 2; + } +} + +static TfLiteStatus get_arc_scratch_buffer_for_io_tensors( + TfLiteContext* context, mli_tensor* in, mli_tensor* out) { +#ifdef __Xxy + int request_size_in = 0; + int request_size_out = 0; + int grant_size_in = 0; + int grant_size_out = 0; + if (!inside_arc_ccm(in->data)) { + // In case the input tensor contains multiple batches, it has rank 4 + // because the mli kernel cannot operate on batches, we need to have the + // size of a single HWC tensor. that is why the start_rank is 1 in case of + // input rank 4 + int start_rank = in->rank - 3; + request_size_in = mli_hlp_count_elem_num(in, start_rank) * + mli_hlp_tensor_element_size(in); + } + if (!inside_arc_ccm(out->data)) { + // In case the input tensor contains multiple batches, it has rank 4 + // because the mli kernel cannot operate on batches, we need to have the + // size of a single batch. that is why the start_rank is 1 in case of input + // rank 4 + int start_rank = out->rank - 3; + request_size_out = mli_hlp_count_elem_num(out, start_rank) * + mli_hlp_tensor_element_size(out); + } + + get_arc_two_buffer_sizes(request_size_in, request_size_out, &grant_size_in, + &grant_size_out); + + if (!inside_arc_ccm(in->data)) { + in->data = get_arc_scratch_buffer(grant_size_in); + in->capacity = grant_size_in; + if (in->data == NULL) return kTfLiteError; + } + if (!inside_arc_ccm(out->data)) { + out->data = get_arc_scratch_buffer(grant_size_out); + out->capacity = grant_size_out; + if (out->data == NULL) return kTfLiteError; + } +#endif + return kTfLiteOk; +} + +TfLiteStatus get_arc_scratch_buffer_for_conv_tensors(TfLiteContext* context, + mli_tensor* in, + mli_tensor* weights, + mli_tensor* bias, + mli_tensor* out) { + TfLiteStatus ret_val = kTfLiteOk; +#ifdef __Xxy + init_arc_scratch_buffers(); + if (!inside_arc_ccm(weights->data)) { + int weights_size = mli_hlp_count_elem_num(weights, 0) * + mli_hlp_tensor_element_size(weights); + int max_weights_size = 0; + weights->data = get_arc_scratch_buffer(weights_size); + weights->capacity = weights_size; + if (weights->data == NULL) { + get_arc_scratch_buffer_max_size(&max_weights_size); + weights->data = get_arc_scratch_buffer(max_weights_size); + weights->capacity = max_weights_size; + if (max_weights_size == 0) ret_val = kTfLiteError; + } + if (weights->data == NULL) ret_val = kTfLiteError; + } + + if (!inside_arc_ccm(bias->data)) { + uint32_t bias_mem_requirements = + mli_hlp_count_elem_num(bias, 0) * mli_hlp_tensor_element_size(bias); + bias->data = get_arc_scratch_buffer(bias_mem_requirements); + bias->capacity = bias_mem_requirements; + } + + if (ret_val == kTfLiteOk) { + ret_val = get_arc_scratch_buffer_for_io_tensors(context, in, out); + } + + if (bias->data == NULL) { + int max_bias_size = 0; + get_arc_scratch_buffer_max_size(&max_bias_size); + bias->data = get_arc_scratch_buffer(max_bias_size); + bias->capacity = max_bias_size; + if (max_bias_size == 0) ret_val = kTfLiteError; + } + if (bias->data == NULL) ret_val = kTfLiteError; + +#endif + return ret_val; +} + +TfLiteStatus get_arc_scratch_buffer_for_fully_connect_tensors( + TfLiteContext* context, mli_tensor* in, mli_tensor* weights, + mli_tensor* bias, mli_tensor* out) { + TfLiteStatus ret_val = kTfLiteOk; +#ifdef __Xxy + init_arc_scratch_buffers(); + /* strategy for FC kernels: + first allocate input, because this cannot be sliced. (in case of batch + processing, only a single input needs to be allocated) then weigths & bias + because if fully loaded, they can be reused over batches. then output. + The number of output channels (for weights slicing) depends on size of + output and size of weights&bias */ + + if (!inside_arc_ccm(in->data)) { + /* In case the input tensor contains multiple batches, + only count the size if the inner most dimension */ + int size_in = mli_hlp_count_elem_num(in, in->rank - 1) * + mli_hlp_tensor_element_size(in); + in->data = get_arc_scratch_buffer(size_in); + in->capacity = size_in; + if (in->data == NULL) { + in->capacity = 0; + ret_val = kTfLiteError; + } + } + + if (!inside_arc_ccm(weights->data)) { + int weights_size = mli_hlp_count_elem_num(weights, 0) * + mli_hlp_tensor_element_size(weights); + int max_weights_size = 0; + weights->data = get_arc_scratch_buffer(weights_size); + weights->capacity = weights_size; + if (weights->data == NULL) { + get_arc_scratch_buffer_max_size(&max_weights_size); + weights->data = get_arc_scratch_buffer(max_weights_size); + weights->capacity = max_weights_size; + if (max_weights_size == 0) ret_val = kTfLiteError; + } + if (weights->data == NULL) ret_val = kTfLiteError; + } + + if (!inside_arc_ccm(bias->data)) { + int bias_mem_requirements = + mli_hlp_count_elem_num(bias, 0) * mli_hlp_tensor_element_size(bias); + bias->data = get_arc_scratch_buffer(bias_mem_requirements); + bias->capacity = bias_mem_requirements; + } + + if (!inside_arc_ccm(out->data)) { + /* In case the input tensor contains multiple batches, + only count the size if the inner most dimension */ + int out_size = mli_hlp_count_elem_num(out, out->rank - 1) * + mli_hlp_tensor_element_size(out); + int max_out_size = 0; + out->data = get_arc_scratch_buffer(out_size); + out->capacity = out_size; + if (out->data == NULL) { + get_arc_scratch_buffer_max_size(&max_out_size); + out->data = get_arc_scratch_buffer(max_out_size); + out->capacity = max_out_size; + if (max_out_size == 0) ret_val = kTfLiteError; + } + if (out->data == NULL) ret_val = kTfLiteError; + } + + if (bias->data == NULL) { + int max_bias_size = 0; + get_arc_scratch_buffer_max_size(&max_bias_size); + bias->data = get_arc_scratch_buffer(max_bias_size); + bias->capacity = max_bias_size; + if (max_bias_size == 0) ret_val = kTfLiteError; + } + if (bias->data == NULL) ret_val = kTfLiteError; + +#endif + return ret_val; +} + +TfLiteStatus arc_scratch_buffer_calc_slice_size_io( + const mli_tensor* in, const mli_tensor* out, const int kernel_height, + const int stride_height, const int padding_top, const int padding_bot, + int* in_slice_height, int* out_slice_height) { + const int height_dimension = 1; + const int in_height = in->shape[height_dimension]; + const int out_height = out->shape[height_dimension]; + const int line_size_in = mli_hlp_count_elem_num(in, height_dimension + 1) * + mli_hlp_tensor_element_size(in); + const int line_size_out = mli_hlp_count_elem_num(out, height_dimension + 1) * + mli_hlp_tensor_element_size(out); + int max_lines_in = 0; + int max_lines_out = 0; + int max_out_lines_for_input = 0; + bool fit = (in->capacity >= in_height * line_size_in) && + (out->capacity >= out_height * line_size_out); + if (fit) { + // in case both tensors completely fit in the capacity, there is no need for + // slicing + *in_slice_height = in_height; + *out_slice_height = out_height; + } else { + // First compute how many lines fit into the input tensor, and compute how + // many output lines can be computed with that. + max_lines_in = + std::min(in_height, static_cast<int>(in->capacity) / line_size_in); + if (max_lines_in >= in_height) { + max_out_lines_for_input = out_height; + } else if (2 * max_lines_in >= in_height) { + // in this case only two slices are needed, so both could benefit from + // padding. take the MIN to get the worst case. + max_out_lines_for_input = + (max_lines_in + std::min(padding_top, padding_bot) - kernel_height + + 1) / + stride_height; + } else { + max_out_lines_for_input = + (max_lines_in - kernel_height + 1) / stride_height; + } + // Ten compute how many ouput lines fit into the output tensor. + max_lines_out = + std::min(out_height, static_cast<int>(out->capacity) / line_size_out); + // the smallest of the two determines the slice height for the output, and + // the derived sliceheight for the input. + *out_slice_height = std::min(max_out_lines_for_input, max_lines_out); + *in_slice_height = *out_slice_height * stride_height; + } + + if ((*in_slice_height > 0) && (*out_slice_height > 0)) { + return kTfLiteOk; + } else { + return kTfLiteError; + } +} + +TfLiteStatus arc_scratch_buffer_calc_slice_size_weights( + const mli_tensor* weights, const mli_tensor* bias, + const int weight_out_ch_dimension, int* slice_channels) { + const int channels = weights->shape[weight_out_ch_dimension]; + const int ch_size_w = (mli_hlp_count_elem_num(weights, 0) / channels) * + mli_hlp_tensor_element_size(weights); + const int ch_size_b = (mli_hlp_count_elem_num(bias, 0) / channels) * + mli_hlp_tensor_element_size(bias); + int max_ch_weigths = 0; + int max_ch_bias = 0; + + bool fit = (weights->capacity >= channels * ch_size_w) && + (bias->capacity >= channels * ch_size_b); + if (fit) { + // in case both tensors completely fit in the capacity, there is no need for + // slicing + *slice_channels = channels; + } else { + // First compute how many channels fit into the weights tensor + max_ch_weigths = + std::min(channels, static_cast<int>(weights->capacity) / ch_size_w); + // Ten compute how many channels fit into the bias tensor. + max_ch_bias = + std::min(channels, static_cast<int>(bias->capacity) / ch_size_b); + // the smallest of the two determines the slice size + *slice_channels = std::min(max_ch_weigths, max_ch_bias); + } + + if (*slice_channels > 0) { + return kTfLiteOk; + } else { + return kTfLiteError; + } +} + +TfLiteStatus get_arc_scratch_buffer_for_pooling_tensors(TfLiteContext* context, + mli_tensor* in, + mli_tensor* out) { +#ifdef __Xxy + init_arc_scratch_buffers(); + return get_arc_scratch_buffer_for_io_tensors(context, in, out); +#else + return kTfLiteOk; +#endif +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.h b/tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.h new file mode 100644 index 00000000000..0db2db558ee --- /dev/null +++ b/tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.h @@ -0,0 +1,129 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_MICRO_ARC_SCRATCH_BUF_MGR_H_ +#define TENSORFLOW_LITE_MICRO_ARC_SCRATCH_BUF_MGR_H_ + +#include "mli_api.h" // NOLINT +#include "tensorflow/lite/c/common.h" + +namespace tflite { +namespace ops { +namespace micro { + +/** + * @brief Function to allocate scratch buffers for the convolution tensors + * + * @detail This function will update the data pointers in the 4 tensors with + * pointers to scratch buffers in fast local memory. + * + * @param context [I] pointer to TfLite context (needed for error handling) + * @param in [IO] pointer to the input tensor + * @param weights [IO] pointer to the weights tensor + * @param bias [IO] pointer to the bias tensor + * @param output [IO] pointer to the output tensor + * + * @return Tf Lite status code + */ +TfLiteStatus get_arc_scratch_buffer_for_conv_tensors(TfLiteContext* context, + mli_tensor* in, + mli_tensor* weights, + mli_tensor* bias, + mli_tensor* out); + +/** + * @brief Function to allocate scratch buffers for pooling kernels with only + * input and output buffers + * + * @detail This function will update the data pointers in the 2 tensors with + * pointers to scratch buffers in fast local memory. + * + * @param context [I] pointer to TfLite context (needed for error handling) + * @param in [IO] pointer to the input tensor + * @param output [IO] pointer to the output tensor + * + * @return Tf Lite status code + */ +TfLiteStatus get_arc_scratch_buffer_for_pooling_tensors(TfLiteContext* context, + mli_tensor* in, + mli_tensor* out); + +/** + * @brief Function to allocate scratch buffers for the fully connect tensors + * + * @detail This function will update the data pointers in the 4 tensors with + * pointers to scratch buffers in fast local memory. + * + * @param context [I] pointer to TfLite context (needed for error handling) + * @param in [IO] pointer to the input tensor + * @param weights [IO] pointer to the weights tensor + * @param bias [IO] pointer to the bias tensor + * @param output [IO] pointer to the output tensor + * + * @return Tf Lite status code + */ +TfLiteStatus get_arc_scratch_buffer_for_fully_connect_tensors( + TfLiteContext* context, mli_tensor* in, mli_tensor* weights, + mli_tensor* bias, mli_tensor* out); + +/** + * @brief Function to calculate slice size for io tensors + * + * @detail This function will calculate the slice size in the height dimension + * for input and output tensors. it takes into account the kernel size and the + * padding. the function will look at the capacity filed in the in and out + * tensor to determine the available buffersize. + * + * @param in [I] pointer to the input tensor + * @param out [I] pointer to the output tensor + * @param kernelHeight [I] size of the kernel in height dimension + * @param strideHeight [I] input stride in height dimension + * @param padding_top [I] number of lines with zeros at the top + * @param padding_bot [I] number of lines with zeros at the bottom + * @param inSliceHeight [O] slice size in height dimension for the input tensor + * @param outSliceHeight [O] slice size in height dimension for the output + * tensor + * + * @return Tf Lite status code + */ +TfLiteStatus arc_scratch_buffer_calc_slice_size_io( + const mli_tensor* in, const mli_tensor* out, const int kernelHeight, + const int strideHeight, const int padding_top, const int padding_bot, + int* in_slice_height, int* out_slice_height); + +/** + * @brief Function to calculate slice size for weight slicing + * + * @detail This function will calculate the slice size in the output channel + * dimension for weight and bias tensors. the function will look at the capacity + * filed in the weights and bias tensor to determine the available buffersize. + * + * @param weights [I] pointer to the input tensor + * @param bias [I] pointer to the output tensor + * @param weightOutChDimension [I] dimension of the output channels in the + * weights tensor + * @param sliceChannels [O] slice size in output channel dimension + * + * @return Tf Lite status code + */ +TfLiteStatus arc_scratch_buffer_calc_slice_size_weights( + const mli_tensor* weights, const mli_tensor* bias, + const int weight_out_ch_dimension, int* slice_channels); + +} // namespace micro +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_ARC_SCRATCH_BUF_MGR_H_ diff --git a/tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.cc b/tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.cc new file mode 100644 index 00000000000..2ee91da5eb7 --- /dev/null +++ b/tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.cc @@ -0,0 +1,135 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h" + +#include <limits.h> + +namespace tflite { +namespace ops { +namespace micro { + +/* by default use all the XY memory, and half of the DCCM because DCCM is also + * used for the data section and the stack. the values can be overruled by + * adding a -D option to the makefile of the application + */ +#ifndef SCRATCH_MEM_X_SIZE +#ifdef core_config_xy_size +#define SCRATCH_MEM_X_SIZE (core_config_xy_size) +#else +#define SCRATCH_MEM_X_SIZE (0) +#endif +#endif + +#ifndef SCRATCH_MEM_Y_SIZE +#ifdef core_config_xy_size +#define SCRATCH_MEM_Y_SIZE (core_config_xy_size) +#else +#define SCRATCH_MEM_Y_SIZE (0) +#endif +#endif + +#ifndef SCRATCH_MEM_Z_SIZE +#ifdef core_config_dccm_size +#define SCRATCH_MEM_Z_SIZE ((core_config_dccm_size) / 2) +#else +#define SCRATCH_MEM_Z_SIZE (0) +#endif +#endif + +namespace { +#pragma Bss(".Xdata") +static int8_t scratch_mem_x[SCRATCH_MEM_X_SIZE]; +#pragma Bss() + +#pragma Bss(".Ydata") +static int8_t scratch_mem_y[SCRATCH_MEM_Y_SIZE]; +#pragma Bss() + +#pragma Bss(".Zdata") +static int8_t scratch_mem_z[SCRATCH_MEM_Z_SIZE]; +#pragma Bss() +} // namespace + +static int8_t *scratch_mem[] = {scratch_mem_x, scratch_mem_y, scratch_mem_z}; +static uint32_t scratch_sizes[] = {SCRATCH_MEM_X_SIZE, SCRATCH_MEM_Y_SIZE, + SCRATCH_MEM_Z_SIZE}; + +void *get_arc_scratch_buffer(int size) { + // Function to asign fast memory from one of 3 scratch buffers. + // Best Fit strategy - memory is allocated from that memory bank that leaves + // the least unused memory. + void *buf = NULL; + int best_mem_idx = -1; + int best_mem_delta = INT_MAX; + const int num_mem = sizeof(scratch_mem) / sizeof(scratch_mem[0]); + // find a local memory that fits the data size. + for (int mem_idx = 0; mem_idx < num_mem; ++mem_idx) { + // Best Fit + if ((size <= scratch_sizes[mem_idx]) && + (scratch_sizes[mem_idx] - size < best_mem_delta)) { + best_mem_idx = mem_idx; + best_mem_delta = scratch_sizes[mem_idx] - size; + } + } + if (best_mem_idx >= 0) { + buf = static_cast<void *>(scratch_mem[best_mem_idx]); + scratch_mem[best_mem_idx] += size; + scratch_sizes[best_mem_idx] -= size; + } + return buf; +} + +void get_arc_scratch_buffer_max_size(int *size) { + int maxavailable = 0; + const int num_mem = sizeof(scratch_mem) / sizeof(scratch_mem[0]); + // find the largest available buffer. + for (int i = 0; i < num_mem; i++) { + if (scratch_sizes[i] > maxavailable) { + maxavailable = scratch_sizes[i]; + } + } + *size = maxavailable; +} + +void get_arc_scratch_buffer_two_max_sizes(int *size1, int *size2) { + int maxavailable = 0; + int secondavail = 0; + const int num_mem = sizeof(scratch_mem) / sizeof(scratch_mem[0]); + // find the two largest available buffers. + for (int i = 0; i < num_mem; i++) { + if (scratch_sizes[i] > maxavailable) { + secondavail = maxavailable; + maxavailable = scratch_sizes[i]; + } else if (scratch_sizes[i] > secondavail) { + secondavail = scratch_sizes[i]; + } + } + *size1 = maxavailable; + *size2 = secondavail; +} + +void init_arc_scratch_buffers(void) { + scratch_mem[0] = scratch_mem_x; + scratch_mem[1] = scratch_mem_y; + scratch_mem[2] = scratch_mem_z; + scratch_sizes[0] = SCRATCH_MEM_X_SIZE; + scratch_sizes[1] = SCRATCH_MEM_Y_SIZE; + scratch_sizes[2] = SCRATCH_MEM_Z_SIZE; +} + +} // namespace micro +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h b/tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h new file mode 100644 index 00000000000..f139659960e --- /dev/null +++ b/tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h @@ -0,0 +1,68 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_MICRO_ARC_SCRATCH_BUFFERS_H_ +#define TENSORFLOW_LITE_MICRO_ARC_SCRATCH_BUFFERS_H_ + +#include "mli_api.h" // NOLINT +#include "tensorflow/lite/c/common.h" + +namespace tflite { +namespace ops { +namespace micro { + +void init_arc_scratch_buffers(void); +void* get_arc_scratch_buffer( + int size); // Function to assign fast memory from one of 3 scratch buffers. + +void get_arc_scratch_buffer_max_size(int* size); +void get_arc_scratch_buffer_two_max_sizes(int* size1, int* size2); + +static inline bool inside_arc_dccm(void* p) { +#if core_config_dccm_present + return ((unsigned)p >= core_config_dccm_base) && + ((unsigned)p < core_config_dccm_base + core_config_dccm_size); +#else + return false; +#endif +} + +static inline bool inside_arc_xccm(void* p) { +#if core_config_xy + return ((unsigned)p >= core_config_xy_x_base) && + ((unsigned)p < core_config_xy_x_base + core_config_xy_size); +#else + return false; +#endif +} + +static inline bool inside_arc_yccm(void* p) { +#if core_config_xy + return ((unsigned)p >= core_config_xy_y_base) && + ((unsigned)p < core_config_xy_y_base + core_config_xy_size); +#else + return false; +#endif +} + +static inline bool inside_arc_ccm(void* p) { + return inside_arc_dccm(p) || inside_arc_xccm(p) || inside_arc_yccm(p); +} + +} // namespace micro +} // namespace ops +} // namespace tflite + +#endif // TENSORFLOW_LITE_MICRO_ARC_SCRATCH_BUFFERS_H_ diff --git a/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc b/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc index 34d4e837f65..6e8272b221a 100644 --- a/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc +++ b/tensorflow/lite/micro/kernels/cmsis-nn/conv.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/conv.h" +#include "arm_nn_types.h" #include "arm_nnfunctions.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" @@ -116,7 +117,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { #if defined(__ARM_FEATURE_DSP) OpData data; - int32_t buf_size; + int32_t buf_size = 0; auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data); @@ -127,32 +128,49 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { RuntimeShape input_shape = GetTensorShape(input); RuntimeShape output_shape = GetTensorShape(output); - const int input_depth = input_shape.Dims(3); - const int input_width = input->dims->data[2]; - const int input_height = input->dims->data[1]; - const int filter_width = filter->dims->data[2]; - const int filter_height = filter->dims->data[1]; - const int output_width = output->dims->data[2]; - const int output_height = output->dims->data[1]; - const int batches = MatchingDim(input_shape, 0, output_shape, 0); + // Initialize cmsis-nn input dimensions + cmsis_nn_dims input_dims; + input_dims.n = MatchingDim(input_shape, 0, output_shape, 0); + input_dims.h = input->dims->data[1]; + input_dims.w = input->dims->data[2]; + input_dims.c = input_shape.Dims(3); + + // Initialize cmsis-nn filter dimensions + cmsis_nn_dims filter_dims; + filter_dims.n = output_shape.Dims(3); + filter_dims.h = filter->dims->data[1]; + filter_dims.w = filter->dims->data[2]; + filter_dims.c = input_dims.c; + + // Initialize cmsis-nn output dimensions + cmsis_nn_dims output_dims; + output_dims.n = input_dims.n; + output_dims.h = output->dims->data[1]; + output_dims.w = output->dims->data[2]; + output_dims.c = output_shape.Dims(3); int* buffer_idx = reinterpret_cast<int*>(node->user_data); TF_LITE_ENSURE_STATUS(CalculateOpData( - context, node, params, input_width, input_height, filter_width, - filter_height, output_width, output_height, input->type, &data)); + context, node, params, input_dims.w, input_dims.h, filter_dims.w, + filter_dims.h, output_dims.w, output_dims.h, input->type, &data)); - if (data.padding.width == 0 && data.padding.height == 0 && - (input_depth % 4 == 0) && params->stride_width == 1 && - params->stride_height == 1 && filter_width == 1 && filter_height == 1) { - buf_size = arm_convolve_1x1_s8_fast_get_buffer_size(input_depth); - } else if (output_height == 1 && input_height == 1 && filter_height == 1 && - (output_width % 4 == 0) && batches == 1) { - buf_size = arm_convolve_1_x_n_s8_get_buffer_size(input_depth, filter_width, - filter_height); - } else { - buf_size = arm_convolve_s8_get_buffer_size(input_depth, filter_width, - filter_height); + if (input->type == kTfLiteInt8) { + // Initialize cmsis-nn convolution parameters + cmsis_nn_conv_params conv_params; + conv_params.input_offset = -input->params.zero_point; + conv_params.output_offset = output->params.zero_point; + conv_params.stride.h = params->stride_height; + conv_params.stride.w = params->stride_width; + conv_params.dilation.h = params->dilation_height_factor; + conv_params.dilation.w = params->dilation_width_factor; + conv_params.padding.h = data.padding.height; + conv_params.padding.w = data.padding.width; + conv_params.activation.min = data.output_activation_min; + conv_params.activation.max = data.output_activation_max; + + buf_size = arm_convolve_wrapper_s8_get_buffer_size( + &conv_params, &input_dims, &filter_dims, &output_dims); } node->user_data = buffer_idx; @@ -204,6 +222,102 @@ TfLiteStatus EvalQuantizedPerChannel( TfLiteContext* context, TfLiteNode* node, TfLiteConvParams* params, OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output, TfLiteTensor* im2col) { + // Initialize cmsis-nn convolution parameters + cmsis_nn_conv_params conv_params; + conv_params.input_offset = -input->params.zero_point; + conv_params.output_offset = output->params.zero_point; + conv_params.stride.h = params->stride_height; + conv_params.stride.w = params->stride_width; + conv_params.dilation.h = params->dilation_height_factor; + conv_params.dilation.w = params->dilation_width_factor; + conv_params.padding.h = data->padding.height; + conv_params.padding.w = data->padding.width; + conv_params.activation.min = data->output_activation_min; + conv_params.activation.max = data->output_activation_max; + + // Initialize cmsis-nn per channel quantization parameters + cmsis_nn_per_channel_quant_params quant_params; + quant_params.multiplier = data->per_channel_output_multiplier; + quant_params.shift = data->per_channel_output_shift; + +#if defined(__ARM_FEATURE_DSP) + RuntimeShape filter_shape = GetTensorShape(filter); + RuntimeShape input_shape = GetTensorShape(input); + RuntimeShape output_shape = GetTensorShape(output); + RuntimeShape bias_shape = GetTensorShape(bias); + + // Sanity check. + TFLITE_DCHECK_LE(conv_params.activation.min, conv_params.activation.max); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batch_size = MatchingDim(input_shape, 0, output_shape, 0); + const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); + const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3); + if (GetTensorData<int8_t>(bias)) { + TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); + } + + // Initialize cmsis-nn dimensions + // Input + cmsis_nn_dims input_dims; + input_dims.n = batch_size; + input_dims.h = input_shape.Dims(1); + input_dims.w = input_shape.Dims(2); + input_dims.c = input_depth; + + // Filter + cmsis_nn_dims filter_dims; + filter_dims.n = output_depth; + filter_dims.h = filter_shape.Dims(1); + filter_dims.w = filter_shape.Dims(2); + filter_dims.c = input_depth; + + // Bias + cmsis_nn_dims bias_dims; + bias_dims.n = 1; + bias_dims.h = 1; + bias_dims.w = 1; + bias_dims.c = output_depth; + + // Output + cmsis_nn_dims output_dims; + output_dims.n = batch_size; + output_dims.h = output_shape.Dims(1); + output_dims.w = output_shape.Dims(2); + output_dims.c = output_depth; + + // Initialize cmsis-nn context + cmsis_nn_context ctx; + ctx.buf = nullptr; + ctx.size = 0; + + auto* buffer_idx = reinterpret_cast<int*>(node->user_data); + if (*buffer_idx > -1) { + ctx.buf = context->GetScratchBuffer(context, *buffer_idx); + // Note: ctx.size is currently not used in cmsis-nn. + // The buffer should be allocated in the Prepare function through + // arm_convolve_wrapper_s8_get_buffer_size + } + + // arm_convolve_wrapper_s8 dispatches the optimized kernel accordingly with + // the parameters passed + arm_status status = arm_convolve_wrapper_s8( + &ctx, &conv_params, &quant_params, &input_dims, + GetTensorData<int8_t>(input), &filter_dims, GetTensorData<int8_t>(filter), + &bias_dims, GetTensorData<int32>(bias), &output_dims, + GetTensorData<int8_t>(output)); + + if (status == ARM_MATH_SUCCESS) { + return kTfLiteOk; + } else { + return kTfLiteError; + } + +#else +#pragma message( \ + "CMSIS-NN optimization for conv not available for this target. Using reference kernel.") + ConvParams op_params; op_params.input_offset = -input->params.zero_point; op_params.output_offset = output->params.zero_point; @@ -216,91 +330,6 @@ TfLiteStatus EvalQuantizedPerChannel( op_params.quantized_activation_min = data->output_activation_min; op_params.quantized_activation_max = data->output_activation_max; -#if defined(__ARM_FEATURE_DSP) - RuntimeShape filter_shape = GetTensorShape(filter); - RuntimeShape input_shape = GetTensorShape(input); - RuntimeShape output_shape = GetTensorShape(output); - RuntimeShape bias_shape = GetTensorShape(bias); - - // Set min and max value of the output. - const int32 output_activation_min = std::numeric_limits<int8_t>::min(); - const int32 output_activation_max = std::numeric_limits<int8_t>::max(); - - // Sanity check. - TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); - TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); - TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); - const int batches = MatchingDim(input_shape, 0, output_shape, 0); - const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3); - const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3); - if (GetTensorData<int8_t>(bias)) { - TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); - } - - const int input_height = input_shape.Dims(1); - const int input_width = input_shape.Dims(2); - const int filter_height = filter_shape.Dims(1); - const int filter_width = filter_shape.Dims(2); - const int output_height = output_shape.Dims(1); - const int output_width = output_shape.Dims(2); - int16_t* buf = nullptr; - - auto* buffer_idx = reinterpret_cast<int*>(node->user_data); - if (*buffer_idx > -1) { - void* raw = context->GetScratchBuffer(context, *buffer_idx); - buf = reinterpret_cast<int16_t*>(raw); - } - - if (op_params.padding_values.width == 0 && - op_params.padding_values.height == 0 && (input_depth % 4 == 0) && - op_params.stride_width == 1 && op_params.stride_height == 1 && - filter_width == 1 && filter_height == 1) { - if (arm_convolve_1x1_s8_fast( - GetTensorData<int8_t>(input), input_width, input_height, - input_depth, batches, GetTensorData<int8_t>(filter), output_depth, - op_params.padding_values.width, op_params.padding_values.height, - op_params.stride_width, op_params.stride_height, - GetTensorData<int32>(bias), GetTensorData<int8_t>(output), - data->per_channel_output_shift, data->per_channel_output_multiplier, - op_params.output_offset, op_params.input_offset, - output_activation_min, output_activation_max, output_width, - output_height, buf) != ARM_MATH_SUCCESS) { - return kTfLiteError; - } - - } else if (output_height == 1 && input_height == 1 && filter_height == 1 && - (output_width % 4 == 0) && batches == 1) { - if (arm_convolve_1_x_n_s8( - GetTensorData<int8_t>(input), input_width, input_depth, batches, - GetTensorData<int8_t>(filter), output_depth, filter_width, - op_params.padding_values.width, op_params.stride_width, - GetTensorData<int32_t>(bias), GetTensorData<int8_t>(output), - data->per_channel_output_shift, data->per_channel_output_multiplier, - op_params.output_offset, op_params.input_offset, - output_activation_min, output_activation_max, output_width, - buf) != ARM_MATH_SUCCESS) { - return kTfLiteError; - } - } else { - if (arm_convolve_s8( - GetTensorData<int8_t>(input), input_width, input_height, - input_depth, batches, GetTensorData<int8_t>(filter), output_depth, - filter_width, filter_height, op_params.padding_values.width, - op_params.padding_values.height, op_params.stride_width, - op_params.stride_height, GetTensorData<int32>(bias), - GetTensorData<int8_t>(output), data->per_channel_output_shift, - data->per_channel_output_multiplier, op_params.output_offset, - op_params.input_offset, output_activation_min, - output_activation_max, output_width, output_height, - buf) != ARM_MATH_SUCCESS) { - return kTfLiteError; - } - } -#else -#pragma message( \ - "CMSIS-NN optimization for conv not available for this target. Using reference kernel.") - reference_integer_ops::ConvPerChannel( op_params, data->per_channel_output_multiplier, data->per_channel_output_shift, GetTensorShape(input), diff --git a/tensorflow/lite/micro/kernels/conv_test.cc b/tensorflow/lite/micro/kernels/conv_test.cc index 4cc2a80c3ea..6d5a6f55814 100644 --- a/tensorflow/lite/micro/kernels/conv_test.cc +++ b/tensorflow/lite/micro/kernels/conv_test.cc @@ -409,8 +409,9 @@ TF_LITE_MICRO_TEST(Kernel1x1QuantizedPerChannel) { TF_LITE_MICRO_TEST(Kernel1x1QuantizedPerChannelRelu6) { // conv params: - // padding, stride_<width,height>, dilation_<width, height>, activation - TfLiteConvParams conv_params = {kTfLitePaddingValid, 1, 1, kTfLiteActRelu6}; + // padding, stride_<width,height>, activation, dilation_<width, height> + TfLiteConvParams conv_params = {kTfLitePaddingValid, 1, 1, + kTfLiteActRelu6, 1, 1}; const int kInputShape[] = {4, 1, 2, 2, 4}; // [len,N,H,W,C] const int kInputElements = kInputShape[1] * kInputShape[2] * kInputShape[3] * kInputShape[4]; diff --git a/tensorflow/lite/micro/kernels/pooling_test.cc b/tensorflow/lite/micro/kernels/pooling_test.cc index 8bfeb718a1b..96dff421d53 100644 --- a/tensorflow/lite/micro/kernels/pooling_test.cc +++ b/tensorflow/lite/micro/kernels/pooling_test.cc @@ -496,7 +496,7 @@ TF_LITE_MICRO_TEST(SimpleAveragePoolTestInt8PaddingSameStride1ActNone) { F2QS(8.5, output_min, output_max), F2QS(7., output_min, output_max)}, {4, 1, 2, 4, 1}, // Output shape output_min, output_max, // output quantization range - kTfLitePaddingValid, kTfLiteActNone, output_data); + kTfLitePaddingSame, kTfLiteActNone, output_data); } TF_LITE_MICRO_TEST(SimpleMaxPoolTestFloat) { diff --git a/tensorflow/lite/micro/kernels/prelu.cc b/tensorflow/lite/micro/kernels/prelu.cc index 2c575269cca..801181abba4 100644 --- a/tensorflow/lite/micro/kernels/prelu.cc +++ b/tensorflow/lite/micro/kernels/prelu.cc @@ -102,6 +102,21 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) { GetTensorShape(output), GetTensorData<uint8_t>(output)); return kTfLiteOk; } break; + case kTfLiteInt8: { + PreluParams op_params; + op_params.input_offset = -input->params.zero_point; + op_params.alpha_offset = -alpha->params.zero_point; + op_params.output_offset = output->params.zero_point; + op_params.output_multiplier_1 = output_multiplier_1; + op_params.output_shift_1 = output_shift_1; + op_params.output_multiplier_2 = output_multiplier_2; + op_params.output_shift_2 = output_shift_2; + reference_ops::BroadcastPrelu4DSlow( + op_params, GetTensorShape(input), GetTensorData<int8_t>(input), + GetTensorShape(alpha), GetTensorData<int8_t>(alpha), + GetTensorShape(output), GetTensorData<int8_t>(output)); + return kTfLiteOk; + } break; default: TF_LITE_KERNEL_LOG( context, "Only float32 and uint8 are supported currently, got %d.", diff --git a/tensorflow/lite/micro/kernels/prelu_test.cc b/tensorflow/lite/micro/kernels/prelu_test.cc index d6c851a2726..66c0a609e8a 100644 --- a/tensorflow/lite/micro/kernels/prelu_test.cc +++ b/tensorflow/lite/micro/kernels/prelu_test.cc @@ -82,16 +82,18 @@ void TestPreluFloat(std::initializer_list<int> input_dims_data, } } +// Template argument T can be either uint8_t or int8_t depending on which type +// of quantization required to be tested. +template <typename T> void TestPreluQuantized(std::initializer_list<int> input_dims_data, - std::initializer_list<uint8_t> input_data, - float input_min, float input_max, + std::initializer_list<T> input_data, float input_min, + float input_max, std::initializer_list<int> alpha_dims_data, - std::initializer_list<uint8_t> alpha_data, - float alpha_min, float alpha_max, - std::initializer_list<uint8_t> expected_output_data, + std::initializer_list<T> alpha_data, float alpha_min, + float alpha_max, + std::initializer_list<T> expected_output_data, std::initializer_list<int> output_dims_data, - float output_min, float output_max, - uint8_t* output_data) { + float output_min, float output_max, T* output_data) { TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data); TfLiteIntArray* alpha_dims = IntArrayFromInitializer(alpha_dims_data); TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data); @@ -173,7 +175,7 @@ TF_LITE_MICRO_TEST(FloatPreluActivationsOpTest) { output_data); } -TF_LITE_MICRO_TEST(QuantizedPreluActivationsOpTest) { +TF_LITE_MICRO_TEST(QuantizedUint8PreluActivationsOpTest) { using tflite::testing::F2Q; const float kMin = -1; const float kMax = 127.f / 128.f; @@ -200,4 +202,30 @@ TF_LITE_MICRO_TEST(QuantizedPreluActivationsOpTest) { kMin, kMax, output_data); } +TF_LITE_MICRO_TEST(QuantizedInt8PreluActivationsOpTest) { + using tflite::testing::F2QS; + const float kMin = -1; + const float kMax = 127.f / 128.f; + const float kAlphaMin = -0.5f; + const float kAlphaMax = 0.5f; + const int output_dims_count = 12; + int8_t output_data[output_dims_count]; + tflite::testing::TestPreluQuantized( + {1, 2, 2, 3}, // input shape + {F2QS(0.0f, kMin, kMax), F2QS(0.0f, kMin, kMax), F2QS(0.0f, kMin, kMax), + F2QS(0.5f, kMin, kMax), F2QS(0.5f, kMin, kMax), F2QS(0.5f, kMin, kMax), + F2QS(-1.0f, kMin, kMax), F2QS(-1.0f, kMin, kMax), + F2QS(-1.0f, kMin, kMax), F2QS(-0.25f, kMin, kMax), + F2QS(-0.25f, kMin, kMax), F2QS(-0.25f, kMin, kMax)}, + kMin, kMax, {1, 1, 1, 3}, // alpha shape + {F2QS(0.0f, kMin, kMax), F2QS(0.5f, kMin, kMax), F2QS(-0.5f, kMin, kMax)}, + kMin, kMax, + {F2QS(0.0f, kMin, kMax), F2QS(0.0f, kMin, kMax), F2QS(0.0f, kMin, kMax), + F2QS(0.5f, kMin, kMax), F2QS(0.5f, kMin, kMax), F2QS(0.5f, kMin, kMax), + F2QS(0.0f, kMin, kMax), F2QS(-0.5f, kMin, kMax), F2QS(0.5f, kMin, kMax), + F2QS(0.0f, kMin, kMax), F2QS(-0.125f, kMin, kMax), + F2QS(0.125f, kMin, kMax)}, + {1, 2, 2, 3}, // output shape + kMin, kMax, output_data); +} TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc index c8bba633de7..39f07862753 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/fully_connected.cc @@ -192,7 +192,11 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node, const OpData& data, const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* bias, TfLiteTensor* output) { - // TODO(b/154032858): Investigate removing extra copies. + // TODO(b/154032858): Investigate removing extra copies, and also passing by + // value. TODO(b/155656675): Consider passing OpData by value once it is also + // passed to the FullyConnected function. Until it is copied to a local + // op_param variable, we do not get any latency improvements from passing by + // value. FullyConnectedParams op_params; op_params.input_offset = -input->params.zero_point; op_params.weights_offset = -filter->params.zero_point; diff --git a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc index a7c5604ef64..da75118b598 100644 --- a/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc +++ b/tensorflow/lite/micro/kernels/xtensa_hifimini/softmax.cc @@ -48,12 +48,12 @@ constexpr int kExpFractionalBits = 16; constexpr int kMaxExponentValue = (1 << kExpFractionalBits); // Quantized softmax with int8 input and int16 output. -// TODO(b/155656675): Investigate removing const ref params. -inline TfLiteStatus Softmax(const OpData& op_data, - const RuntimeShape& input_shape, - const int8_t* input_data, - const RuntimeShape& output_shape, - int16_t* output_data) { +// Passing OpData by value does not have much savings in this op, but following +// that as a best practice, at least for the xtensa kernels. See b/155656675 for +// more details. +TfLiteStatus Softmax(OpData op_data, const RuntimeShape& input_shape, + const int8_t* input_data, const RuntimeShape& output_shape, + int16_t* output_data) { // The last dimension is depth. Outer size is the the total input size // divided by depth. const int trailing_dim = input_shape.DimensionsCount() - 1; @@ -190,7 +190,6 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, 0); if (input->type == kTfLiteInt8 && output->type == kTfLiteInt16) { - // TODO(b/155656675): Const ref params can be slow on xtensa. return Softmax(*op_data, GetTensorShape(input), GetTensorData<int8_t>(input), GetTensorShape(output), GetTensorData<int16_t>(output)); diff --git a/tensorflow/lite/micro/memory_helpers.cc b/tensorflow/lite/micro/memory_helpers.cc index 302f160a235..c1b761bf088 100644 --- a/tensorflow/lite/micro/memory_helpers.cc +++ b/tensorflow/lite/micro/memory_helpers.cc @@ -15,8 +15,12 @@ limitations under the License. #include "tensorflow/lite/micro/memory_helpers.h" +#include <cstddef> #include <cstdint> +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/flatbuffer_conversions.h" namespace tflite { diff --git a/tensorflow/lite/micro/memory_helpers.h b/tensorflow/lite/micro/memory_helpers.h index ef8205c8038..f52da062271 100644 --- a/tensorflow/lite/micro/memory_helpers.h +++ b/tensorflow/lite/micro/memory_helpers.h @@ -15,6 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MICRO_MEMORY_HELPERS_H_ #define TENSORFLOW_LITE_MICRO_MEMORY_HELPERS_H_ +#include <cstddef> +#include <cstdint> + #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/lite/micro/micro_allocator.cc b/tensorflow/lite/micro/micro_allocator.cc index 54ce3383a08..b67e158980d 100644 --- a/tensorflow/lite/micro/micro_allocator.cc +++ b/tensorflow/lite/micro/micro_allocator.cc @@ -18,6 +18,7 @@ limitations under the License. #include <cstddef> #include <cstdint> +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/flatbuffer_conversions.h" @@ -26,6 +27,7 @@ limitations under the License. #include "tensorflow/lite/micro/compatibility.h" #include "tensorflow/lite/micro/memory_helpers.h" #include "tensorflow/lite/micro/memory_planner/greedy_memory_planner.h" +#include "tensorflow/lite/micro/memory_planner/memory_planner.h" #include "tensorflow/lite/micro/simple_memory_allocator.h" namespace tflite { @@ -256,7 +258,7 @@ TfLiteStatus CommitPlan(ErrorReporter* error_reporter, MemoryPlanner* planner, namespace internal { -TfLiteStatus InitializeRuntimeTensor( +TfLiteStatus InitializeTfLiteTensorFromFlatbuffer( SimpleMemoryAllocator* allocator, const tflite::Tensor& flatbuffer_tensor, const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers, ErrorReporter* error_reporter, TfLiteTensor* result) { @@ -378,58 +380,9 @@ TfLiteStatus InitializeRuntimeTensor( } return kTfLiteOk; } + } // namespace internal -TfLiteStatus MicroAllocator::Init() { - auto* subgraphs = model_->subgraphs(); - if (subgraphs->size() != 1) { - TF_LITE_REPORT_ERROR(error_reporter_, - "Only 1 subgraph is currently supported.\n"); - return kTfLiteError; - } - subgraph_ = (*subgraphs)[0]; - - context_->tensors_size = subgraph_->tensors()->size(); - context_->tensors = - reinterpret_cast<TfLiteTensor*>(memory_allocator_->AllocateFromTail( - sizeof(TfLiteTensor) * context_->tensors_size, - alignof(TfLiteTensor))); - if (context_->tensors == nullptr) { - TF_LITE_REPORT_ERROR( - error_reporter_, - "Failed to allocate memory for context->tensors, %d bytes required", - sizeof(TfLiteTensor) * context_->tensors_size); - return kTfLiteError; - } - - // Initialize runtime tensors in context_ using the flatbuffer. - for (size_t i = 0; i < subgraph_->tensors()->size(); ++i) { - TfLiteStatus status = internal::InitializeRuntimeTensor( - memory_allocator_, *subgraph_->tensors()->Get(i), model_->buffers(), - error_reporter_, &context_->tensors[i]); - if (status != kTfLiteOk) { - TF_LITE_REPORT_ERROR(error_reporter_, "Failed to initialize tensor %d", - i); - return kTfLiteError; - } - } - - return kTfLiteOk; -} - -size_t MicroAllocator::used_bytes() const { - if (active_) { - return 0; - } - TF_LITE_REPORT_ERROR(error_reporter_, "Total buffer usage: %d bytes", - memory_allocator_->GetUsedBytes()); - TF_LITE_REPORT_ERROR(error_reporter_, "Head usage: %d bytes", - memory_allocator_->GetHeadUsedBytes()); - TF_LITE_REPORT_ERROR(error_reporter_, "Tail usage: %d bytes", - memory_allocator_->GetTailUsedBytes()); - return memory_allocator_->GetUsedBytes(); -} - MicroAllocator::MicroAllocator(TfLiteContext* context, const Model* model, uint8_t* tensor_arena, size_t arena_size, ErrorReporter* error_reporter) @@ -448,7 +401,8 @@ MicroAllocator::MicroAllocator(TfLiteContext* context, const Model* model, // destructed as it's the root allocator. memory_allocator_ = CreateInPlaceSimpleMemoryAllocator( error_reporter, aligned_arena, aligned_arena_size); - TfLiteStatus status = Init(); + + TfLiteStatus status = InitGraphAndContextTensorData(); // TODO(b/147871299): Consider improving this code. A better way of handling // failures in the constructor is to have a static function that returns a // pointer to the class. If allocation failed, a nullptr will be returned. @@ -461,88 +415,15 @@ MicroAllocator::MicroAllocator(TfLiteContext* context, const Model* model, } } -TfLiteStatus MicroAllocator::AllocateNodeAndRegistrations( +TfLiteStatus MicroAllocator::InitializeFromFlatbuffer( const OpResolver& op_resolver, NodeAndRegistration** node_and_registrations) { if (!active_) { return kTfLiteError; } - - auto* output = reinterpret_cast<NodeAndRegistration*>( - memory_allocator_->AllocateFromTail( - sizeof(NodeAndRegistration) * subgraph_->operators()->size(), - alignof(NodeAndRegistration))); - if (output == nullptr) { - TF_LITE_REPORT_ERROR( - error_reporter_, - "Failed to allocate memory for node_and_registrations."); - return kTfLiteError; - } - TfLiteStatus status = kTfLiteOk; - auto* opcodes = model_->operator_codes(); - MicroBuiltinDataAllocator builtin_data_allocator(memory_allocator_); - for (size_t i = 0; i < subgraph_->operators()->size(); ++i) { - const auto* op = subgraph_->operators()->Get(i); - size_t index = op->opcode_index(); - if (index >= opcodes->size()) { - TF_LITE_REPORT_ERROR(error_reporter_, - "Missing registration for opcode_index %d\n", index); - return kTfLiteError; - } - auto* opcode = (*opcodes)[index]; - status = GetRegistrationFromOpCode(opcode, op_resolver, error_reporter_, - &(output[i].registration)); - if (status != kTfLiteOk) { - TF_LITE_REPORT_ERROR(error_reporter_, - "Failed to get registration from op code %s\n ", - EnumNameBuiltinOperator(opcode->builtin_code())); - return status; - } - const auto* registration = output[i].registration; - if (registration == nullptr) { - TF_LITE_REPORT_ERROR(error_reporter_, "Skipping op for opcode_index %d\n", - index); - return kTfLiteError; - } - BuiltinOperator op_type = - static_cast<BuiltinOperator>(registration->builtin_code); - - if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) { - TF_LITE_REPORT_ERROR( - error_reporter_, - "Unsupported behavior: found builtin operator %s with custom " - "options.\n", - EnumNameBuiltinOperator(op_type)); - return kTfLiteError; - } - - const char* custom_data = nullptr; - size_t custom_data_size = 0; - unsigned char* builtin_data = nullptr; - if (op->custom_options()) { - custom_data = reinterpret_cast<const char*>(op->custom_options()->data()); - custom_data_size = op->custom_options()->size(); - } else { - TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_, - &builtin_data_allocator, - (void**)(&builtin_data))); - } - - // Disregard const qualifier to workaround with existing API. - TfLiteIntArray* inputs_array = const_cast<TfLiteIntArray*>( - reinterpret_cast<const TfLiteIntArray*>(op->inputs())); - TfLiteIntArray* outputs_array = const_cast<TfLiteIntArray*>( - reinterpret_cast<const TfLiteIntArray*>(op->outputs())); - - TfLiteNode* node = &(output[i].node); - *node = {}; - node->inputs = inputs_array; - node->outputs = outputs_array; - node->builtin_data = reinterpret_cast<void*>(builtin_data); - node->custom_initial_data = custom_data; - node->custom_initial_data_size = custom_data_size; - } - *node_and_registrations = output; + TF_LITE_ENSURE_STATUS(AllocateNodeAndRegistrations(node_and_registrations)); + TF_LITE_ENSURE_STATUS(PrepareNodeAndRegistrationDataFromFlatbuffer( + op_resolver, *node_and_registrations)); return kTfLiteOk; } @@ -677,4 +558,151 @@ void* MicroAllocator::GetScratchBuffer(int buffer_idx) const { return scratch_buffer_handles_[scratch_buffer_count_ - buffer_idx - 1].data; } +size_t MicroAllocator::used_bytes() const { + if (active_) { + return 0; + } + TF_LITE_REPORT_ERROR(error_reporter_, "Total buffer usage: %d bytes", + memory_allocator_->GetUsedBytes()); + TF_LITE_REPORT_ERROR(error_reporter_, "Head usage: %d bytes", + memory_allocator_->GetHeadUsedBytes()); + TF_LITE_REPORT_ERROR(error_reporter_, "Tail usage: %d bytes", + memory_allocator_->GetTailUsedBytes()); + return memory_allocator_->GetUsedBytes(); +} + +TfLiteStatus MicroAllocator::InitGraphAndContextTensorData() { + auto* subgraphs = model_->subgraphs(); + if (subgraphs->size() != 1) { + TF_LITE_REPORT_ERROR(error_reporter_, + "Only 1 subgraph is currently supported.\n"); + return kTfLiteError; + } + subgraph_ = (*subgraphs)[0]; + + TF_LITE_ENSURE_STATUS(AllocateTfLiteTensorArray()); + TF_LITE_ENSURE_STATUS(PopulateTfLiteTensorArrayFromFlatbuffer()); + + return kTfLiteOk; +} + +TfLiteStatus MicroAllocator::AllocateTfLiteTensorArray() { + context_->tensors_size = subgraph_->tensors()->size(); + context_->tensors = + reinterpret_cast<TfLiteTensor*>(memory_allocator_->AllocateFromTail( + sizeof(TfLiteTensor) * context_->tensors_size, + alignof(TfLiteTensor))); + if (context_->tensors == nullptr) { + TF_LITE_REPORT_ERROR( + error_reporter_, + "Failed to allocate memory for context->tensors, %d bytes required", + sizeof(TfLiteTensor) * context_->tensors_size); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus MicroAllocator::PopulateTfLiteTensorArrayFromFlatbuffer() { + // Initialize tensors in context_ using the flatbuffer for quantization data. + for (size_t i = 0; i < subgraph_->tensors()->size(); ++i) { + TfLiteStatus status = internal::InitializeTfLiteTensorFromFlatbuffer( + memory_allocator_, *subgraph_->tensors()->Get(i), model_->buffers(), + error_reporter_, &context_->tensors[i]); + if (status != kTfLiteOk) { + TF_LITE_REPORT_ERROR(error_reporter_, "Failed to initialize tensor %d", + i); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +TfLiteStatus MicroAllocator::AllocateNodeAndRegistrations( + NodeAndRegistration** node_and_registrations) { + NodeAndRegistration* output = reinterpret_cast<NodeAndRegistration*>( + memory_allocator_->AllocateFromTail( + sizeof(NodeAndRegistration) * subgraph_->operators()->size(), + alignof(NodeAndRegistration))); + if (output == nullptr) { + TF_LITE_REPORT_ERROR( + error_reporter_, + "Failed to allocate memory for node_and_registrations."); + return kTfLiteError; + } + *node_and_registrations = output; + return kTfLiteOk; +} + +TfLiteStatus MicroAllocator::PrepareNodeAndRegistrationDataFromFlatbuffer( + const OpResolver& op_resolver, + NodeAndRegistration* node_and_registrations) { + TfLiteStatus status = kTfLiteOk; + auto* opcodes = model_->operator_codes(); + MicroBuiltinDataAllocator builtin_data_allocator(memory_allocator_); + for (size_t i = 0; i < subgraph_->operators()->size(); ++i) { + const auto* op = subgraph_->operators()->Get(i); + const size_t index = op->opcode_index(); + if (index >= opcodes->size()) { + TF_LITE_REPORT_ERROR(error_reporter_, + "Missing registration for opcode_index %d\n", index); + return kTfLiteError; + } + auto* opcode = (*opcodes)[index]; + status = + GetRegistrationFromOpCode(opcode, op_resolver, error_reporter_, + &(node_and_registrations[i].registration)); + if (status != kTfLiteOk) { + TF_LITE_REPORT_ERROR(error_reporter_, + "Failed to get registration from op code %s\n ", + EnumNameBuiltinOperator(opcode->builtin_code())); + return status; + } + const auto* registration = node_and_registrations[i].registration; + if (registration == nullptr) { + TF_LITE_REPORT_ERROR(error_reporter_, "Skipping op for opcode_index %d\n", + index); + return kTfLiteError; + } + BuiltinOperator op_type = + static_cast<BuiltinOperator>(registration->builtin_code); + + if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) { + TF_LITE_REPORT_ERROR( + error_reporter_, + "Unsupported behavior: found builtin operator %s with custom " + "options.\n", + EnumNameBuiltinOperator(op_type)); + return kTfLiteError; + } + + const char* custom_data = nullptr; + size_t custom_data_size = 0; + unsigned char* builtin_data = nullptr; + if (op->custom_options()) { + custom_data = reinterpret_cast<const char*>(op->custom_options()->data()); + custom_data_size = op->custom_options()->size(); + } else { + TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_, + &builtin_data_allocator, + (void**)(&builtin_data))); + } + + // Disregard const qualifier to workaround with existing API. + TfLiteIntArray* inputs_array = const_cast<TfLiteIntArray*>( + reinterpret_cast<const TfLiteIntArray*>(op->inputs())); + TfLiteIntArray* outputs_array = const_cast<TfLiteIntArray*>( + reinterpret_cast<const TfLiteIntArray*>(op->outputs())); + + TfLiteNode* node = &(node_and_registrations[i].node); + *node = {}; + node->inputs = inputs_array; + node->outputs = outputs_array; + node->builtin_data = reinterpret_cast<void*>(builtin_data); + node->custom_initial_data = custom_data; + node->custom_initial_data_size = custom_data_size; + } + + return kTfLiteOk; +} + } // namespace tflite diff --git a/tensorflow/lite/micro/micro_allocator.h b/tensorflow/lite/micro/micro_allocator.h index 6a6e1e03e53..1dd90c36a4d 100644 --- a/tensorflow/lite/micro/micro_allocator.h +++ b/tensorflow/lite/micro/micro_allocator.h @@ -15,9 +15,13 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MICRO_MICRO_ALLOCATOR_H_ #define TENSORFLOW_LITE_MICRO_MICRO_ALLOCATOR_H_ +#include <cstddef> +#include <cstdint> + +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/micro/simple_memory_allocator.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -26,9 +30,9 @@ namespace tflite { // Namespace used for unittests. namespace internal { -// Sets up all of the data structure members for a runtime tensor -// based on the contents of a serialized tensor. -TfLiteStatus InitializeRuntimeTensor( +// Sets up all of the data structure members for a TfLiteTensor based on the +// contents of a serialized tensor in the flatbuffer. +TfLiteStatus InitializeTfLiteTensorFromFlatbuffer( SimpleMemoryAllocator* allocator, const tflite::Tensor& flatbuffer_tensor, const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers, ErrorReporter* error_reporter, TfLiteTensor* result); @@ -82,6 +86,15 @@ class MicroAllocator { uint8_t* tensor_arena, size_t arena_size, ErrorReporter* error_reporter); + // Run through the model flatbuffer data (loaded from the TfLiteModel + // instance) to allocate nodes and registrations. We need to keep them for the + // entire life time of the model to allow persistent tensors. This method + // needs to be called before FinishTensorAllocation method. This method also + // allocates any internal Op data that is required from the flatbuffer. + TfLiteStatus InitializeFromFlatbuffer( + const OpResolver& op_resolver, + NodeAndRegistration** node_and_registrations); + // Runs through the model and allocates all necessary input, output and // intermediate tensors. // WARNING: doing any allocation after calling this method has the risk of @@ -89,17 +102,6 @@ class MicroAllocator { // called in this class. TfLiteStatus FinishTensorAllocation(); - // Returns the arena usage in bytes, only available after - // `FinishTensorAllocation`. Otherwise, it will return 0. - size_t used_bytes() const; - - // Run through the model to allocate nodes and registrations. We need to keep - // them for the entire life time of the model to allow persistent tensors. - // This method needs to be called before FinishTensorAllocation method. - TfLiteStatus AllocateNodeAndRegistrations( - const OpResolver& op_resolver, - NodeAndRegistration** node_and_registrations); - // Allocates persistent buffer which has the same life time as the allocator. // The memory is immediately available and is allocated from the tail of the // arena. @@ -116,8 +118,38 @@ class MicroAllocator { // Returns the pointer to the planned scratch buffer. void* GetScratchBuffer(int buffer_idx) const; + // Returns the arena usage in bytes, only available after + // `FinishTensorAllocation`. Otherwise, it will return 0. + size_t used_bytes() const; + + protected: + // Allocates an array in the arena to hold pointers to the tensors required + // to initialize and prepare a model. These allocations are stored and + // populated on the context. + TfLiteStatus AllocateTfLiteTensorArray(); + + // Populates content on the list of tensor pointers required to initialize and + // prepare a model from data in the flatbuffer (loaded from the TfLiteModel + // instance). Persistent data (e.g. quantization params) is allocated from the + // arena. + TfLiteStatus PopulateTfLiteTensorArrayFromFlatbuffer(); + + // Allocates an array in the arena to hold pointers to the node and + // registration pointers required to represent the inference graph of the + // model. + TfLiteStatus AllocateNodeAndRegistrations( + NodeAndRegistration** node_and_registrations); + + // Populates node and registration pointers representing the inference graph + // of the model from values inside the flatbuffer (loaded from the TfLiteModel + // instance). Persistent data (e.g. operator data) is allocated from the + // arena. + TfLiteStatus PrepareNodeAndRegistrationDataFromFlatbuffer( + const OpResolver& op_resolver, + NodeAndRegistration* node_and_registrations); + private: - TfLiteStatus Init(); + TfLiteStatus InitGraphAndContextTensorData(); const Model* model_; // A simple memory allocator that always allocate from the arena tail. diff --git a/tensorflow/lite/micro/micro_allocator_test.cc b/tensorflow/lite/micro/micro_allocator_test.cc index 78419edbbf9..b34b2dc2866 100644 --- a/tensorflow/lite/micro/micro_allocator_test.cc +++ b/tensorflow/lite/micro/micro_allocator_test.cc @@ -77,7 +77,7 @@ TF_LITE_MICRO_TEST(TestInitializeRuntimeTensor) { TfLiteTensor allocated_tensor; TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, tflite::internal::InitializeRuntimeTensor( + kTfLiteOk, tflite::internal::InitializeTfLiteTensorFromFlatbuffer( &simple_allocator, *tensor, buffers, micro_test::reporter, &allocated_tensor)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, allocated_tensor.type); @@ -103,7 +103,7 @@ TF_LITE_MICRO_TEST(TestInitializeQuantizedTensor) { TfLiteTensor allocated_tensor; TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, tflite::internal::InitializeRuntimeTensor( + kTfLiteOk, tflite::internal::InitializeTfLiteTensorFromFlatbuffer( &simple_allocator, *tensor, buffers, micro_test::reporter, &allocated_tensor)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, allocated_tensor.type); @@ -129,7 +129,7 @@ TF_LITE_MICRO_TEST(TestMissingQuantization) { TfLiteTensor allocated_tensor; TF_LITE_MICRO_EXPECT_EQ( - kTfLiteOk, tflite::internal::InitializeRuntimeTensor( + kTfLiteOk, tflite::internal::InitializeTfLiteTensorFromFlatbuffer( &simple_allocator, *tensor, buffers, micro_test::reporter, &allocated_tensor)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, allocated_tensor.type); diff --git a/tensorflow/lite/micro/micro_error_reporter.cc b/tensorflow/lite/micro/micro_error_reporter.cc index bea3dc8db4c..6d8361cd25a 100644 --- a/tensorflow/lite/micro/micro_error_reporter.cc +++ b/tensorflow/lite/micro/micro_error_reporter.cc @@ -15,7 +15,10 @@ limitations under the License. #include "tensorflow/lite/micro/micro_error_reporter.h" +#include <cstdarg> + #ifndef TF_LITE_STRIP_ERROR_STRINGS +#include "tensorflow/lite/micro/debug_log.h" #include "tensorflow/lite/micro/micro_string.h" #endif diff --git a/tensorflow/lite/micro/micro_error_reporter.h b/tensorflow/lite/micro/micro_error_reporter.h index b18c47f4ecb..e2c073a465d 100644 --- a/tensorflow/lite/micro/micro_error_reporter.h +++ b/tensorflow/lite/micro/micro_error_reporter.h @@ -15,9 +15,10 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MICRO_MICRO_ERROR_REPORTER_H_ #define TENSORFLOW_LITE_MICRO_MICRO_ERROR_REPORTER_H_ +#include <cstdarg> + #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/micro/compatibility.h" -#include "tensorflow/lite/micro/debug_log.h" namespace tflite { diff --git a/tensorflow/lite/micro/micro_interpreter.cc b/tensorflow/lite/micro/micro_interpreter.cc index 2d774d0a139..6b78966020e 100644 --- a/tensorflow/lite/micro/micro_interpreter.cc +++ b/tensorflow/lite/micro/micro_interpreter.cc @@ -14,12 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/micro/micro_interpreter.h" +#include <cstdarg> +#include <cstddef> +#include <cstdint> + +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/core/api/flatbuffer_conversions.h" +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/tensor_utils.h" -#include "tensorflow/lite/micro/compatibility.h" #include "tensorflow/lite/micro/micro_allocator.h" -#include "tensorflow/lite/micro/micro_optional_debug_tools.h" namespace tflite { namespace { @@ -161,7 +165,7 @@ void MicroInterpreter::CorrectTensorDataEndianness(T* data, int32_t size) { } TfLiteStatus MicroInterpreter::AllocateTensors() { - TF_LITE_ENSURE_OK(&context_, allocator_.AllocateNodeAndRegistrations( + TF_LITE_ENSURE_OK(&context_, allocator_.InitializeFromFlatbuffer( op_resolver_, &node_and_registrations_)); // Only allow AllocatePersistentBuffer in Init stage. diff --git a/tensorflow/lite/micro/micro_interpreter.h b/tensorflow/lite/micro/micro_interpreter.h index 15f53b681a6..180a557668e 100644 --- a/tensorflow/lite/micro/micro_interpreter.h +++ b/tensorflow/lite/micro/micro_interpreter.h @@ -15,6 +15,10 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_ #define TENSORFLOW_LITE_MICRO_MICRO_INTERPRETER_H_ +#include <cstddef> +#include <cstdint> + +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/op_resolver.h" diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index ead9be490a3..6c3e9a3331e 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -15,7 +15,10 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ #define TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ +#include <cstring> + #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/micro/compatibility.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/lite/micro/micro_optional_debug_tools.cc b/tensorflow/lite/micro/micro_optional_debug_tools.cc index 42c42aea9f8..daa5d007cdf 100644 --- a/tensorflow/lite/micro/micro_optional_debug_tools.cc +++ b/tensorflow/lite/micro/micro_optional_debug_tools.cc @@ -20,8 +20,17 @@ limitations under the License. #endif #include <cinttypes> +#include <cstddef> +#include <cstdint> +#include <cstdio> +#include <vector> +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/micro_allocator.h" +#include "tensorflow/lite/micro/micro_interpreter.h" #include "tensorflow/lite/schema/schema_generated.h" + namespace tflite { namespace { diff --git a/tensorflow/lite/micro/simple_memory_allocator.cc b/tensorflow/lite/micro/simple_memory_allocator.cc index be7c469529e..d55e7e87640 100644 --- a/tensorflow/lite/micro/simple_memory_allocator.cc +++ b/tensorflow/lite/micro/simple_memory_allocator.cc @@ -18,10 +18,25 @@ limitations under the License. #include <cstddef> #include <cstdint> +#include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/micro/memory_helpers.h" namespace tflite { +SimpleMemoryAllocator::SimpleMemoryAllocator(ErrorReporter* error_reporter, + uint8_t* buffer_head, + uint8_t* buffer_tail) + : error_reporter_(error_reporter), + buffer_head_(buffer_head), + buffer_tail_(buffer_tail), + head_(buffer_head), + tail_(buffer_tail) {} + +SimpleMemoryAllocator::SimpleMemoryAllocator(ErrorReporter* error_reporter, + uint8_t* buffer, + size_t buffer_size) + : SimpleMemoryAllocator(error_reporter, buffer, buffer + buffer_size) {} + SimpleMemoryAllocator* CreateInPlaceSimpleMemoryAllocator( ErrorReporter* error_reporter, uint8_t* buffer, size_t buffer_size) { SimpleMemoryAllocator tmp = @@ -63,4 +78,28 @@ uint8_t* SimpleMemoryAllocator::AllocateFromTail(size_t size, return aligned_result; } +uint8_t* SimpleMemoryAllocator::GetHead() const { return head_; } + +uint8_t* SimpleMemoryAllocator::GetTail() const { return tail_; } + +size_t SimpleMemoryAllocator::GetHeadUsedBytes() const { + return head_ - buffer_head_; +} + +size_t SimpleMemoryAllocator::GetTailUsedBytes() const { + return buffer_tail_ - tail_; +} + +size_t SimpleMemoryAllocator::GetAvailableMemory() const { + return tail_ - head_; +} + +size_t SimpleMemoryAllocator::GetUsedBytes() const { + return GetBufferSize() - GetAvailableMemory(); +} + +size_t SimpleMemoryAllocator::GetBufferSize() const { + return buffer_tail_ - buffer_head_; +} + } // namespace tflite diff --git a/tensorflow/lite/micro/simple_memory_allocator.h b/tensorflow/lite/micro/simple_memory_allocator.h index ed73104a2c6..5be260f9ed2 100644 --- a/tensorflow/lite/micro/simple_memory_allocator.h +++ b/tensorflow/lite/micro/simple_memory_allocator.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_ #define TENSORFLOW_LITE_MICRO_SIMPLE_MEMORY_ALLOCATOR_H_ +#include <cstddef> #include <cstdint> -#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/error_reporter.h" namespace tflite { @@ -29,15 +29,9 @@ namespace tflite { class SimpleMemoryAllocator { public: SimpleMemoryAllocator(ErrorReporter* error_reporter, uint8_t* buffer_head, - uint8_t* buffer_tail) - : error_reporter_(error_reporter), - buffer_head_(buffer_head), - buffer_tail_(buffer_tail), - head_(buffer_head), - tail_(buffer_tail) {} + uint8_t* buffer_tail); SimpleMemoryAllocator(ErrorReporter* error_reporter, uint8_t* buffer, - size_t buffer_size) - : SimpleMemoryAllocator(error_reporter, buffer, buffer + buffer_size) {} + size_t buffer_size); // Allocates memory starting at the head of the arena (lowest address and // moving upwards). @@ -46,16 +40,17 @@ class SimpleMemoryAllocator { // moving downwards). uint8_t* AllocateFromTail(size_t size, size_t alignment); - uint8_t* GetHead() const { return head_; } - uint8_t* GetTail() const { return tail_; } - size_t GetAvailableMemory() const { return tail_ - head_; } - size_t GetUsedBytes() const { return GetBufferSize() - GetAvailableMemory(); } + uint8_t* GetHead() const; + uint8_t* GetTail() const; - size_t GetHeadUsedBytes() const { return head_ - buffer_head_; } - size_t GetTailUsedBytes() const { return buffer_tail_ - tail_; } + size_t GetHeadUsedBytes() const; + size_t GetTailUsedBytes() const; + + size_t GetAvailableMemory() const; + size_t GetUsedBytes() const; private: - size_t GetBufferSize() const { return buffer_tail_ - buffer_head_; } + size_t GetBufferSize() const; ErrorReporter* error_reporter_; uint8_t* buffer_head_; diff --git a/tensorflow/lite/micro/test_helpers.cc b/tensorflow/lite/micro/test_helpers.cc index 77a1cc82f3b..c2607cd32c6 100644 --- a/tensorflow/lite/micro/test_helpers.cc +++ b/tensorflow/lite/micro/test_helpers.cc @@ -15,10 +15,15 @@ limitations under the License. #include "tensorflow/lite/micro/test_helpers.h" +#include <cstdarg> +#include <cstddef> +#include <cstdint> #include <initializer_list> +#include <new> +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/core/api/tensor_utils.h" +#include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/lite/micro/test_helpers.h b/tensorflow/lite/micro/test_helpers.h index 010e1f9e336..2d1d2895db0 100644 --- a/tensorflow/lite/micro/test_helpers.h +++ b/tensorflow/lite/micro/test_helpers.h @@ -18,8 +18,10 @@ limitations under the License. // Useful functions for writing tests. +#include <cstdint> + +#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/lite/micro/testing/BUILD b/tensorflow/lite/micro/testing/BUILD index 245e919bb05..8db93c6eeac 100644 --- a/tensorflow/lite/micro/testing/BUILD +++ b/tensorflow/lite/micro/testing/BUILD @@ -22,6 +22,7 @@ cc_library( deps = [ "//tensorflow/lite/c:common", "//tensorflow/lite/core/api", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_framework", "//tensorflow/lite/micro:micro_utils", ], @@ -43,8 +44,7 @@ cc_library( "micro_benchmark.h", ], deps = [ - "//tensorflow/lite/c:common", - "//tensorflow/lite/micro:micro_framework", + "//tensorflow/lite/micro:micro_error_reporter", "//tensorflow/lite/micro:micro_time", ], ) diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index 1331163a410..13761cca28b 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -94,8 +94,10 @@ endif # runtime that can be linked in to other programs. MICROLITE_LIB_NAME := libtensorflow-microlite.a +# These two must be defined before we include the target specific Makefile.inc +# because we filter out the examples that are not supported for those targets. +# See targets/xtensa_xpg_makefile.inc for an example. MICRO_LITE_EXAMPLE_TESTS := $(shell find tensorflow/lite/micro/examples/ -name Makefile.inc) - MICRO_LITE_BENCHMARKS := $(wildcard tensorflow/lite/micro/benchmarks/Makefile.inc) MICROLITE_TEST_SRCS := \ @@ -137,6 +139,7 @@ tensorflow/lite/c/common.h \ tensorflow/lite/core/api/error_reporter.h \ tensorflow/lite/core/api/flatbuffer_conversions.h \ tensorflow/lite/core/api/op_resolver.h \ +tensorflow/lite/core/api/profiler.h \ tensorflow/lite/core/api/tensor_utils.h \ tensorflow/lite/kernels/internal/common.h \ tensorflow/lite/kernels/internal/compatibility.h \ @@ -237,6 +240,7 @@ include $(MAKEFILE_DIR)/third_party_downloads.inc THIRD_PARTY_DOWNLOADS := $(eval $(call add_third_party_download,$(GEMMLOWP_URL),$(GEMMLOWP_MD5),gemmlowp,)) $(eval $(call add_third_party_download,$(FLATBUFFERS_URL),$(FLATBUFFERS_MD5),flatbuffers,)) +$(eval $(call add_third_party_download,$(RUY_URL),$(RUY_MD5),ruy,)) # These target-specific makefiles should modify or replace options like # CXXFLAGS or LIBS to work for a specific targeted architecture. All logic diff --git a/tensorflow/lite/micro/tools/make/download_and_extract.sh b/tensorflow/lite/micro/tools/make/download_and_extract.sh index da9a474b004..2f602ce9d4c 100755 --- a/tensorflow/lite/micro/tools/make/download_and_extract.sh +++ b/tensorflow/lite/micro/tools/make/download_and_extract.sh @@ -90,7 +90,7 @@ patch_cifar10_dataset() { } build_embarc_mli() { - gmake -j 4 -C ${1}/lib/make TCF_FILE=${2} + make -j 4 -C ${1}/lib/make TCF_FILE=${2} } # Main function handling the download, verify, extract, and patch process. @@ -173,7 +173,12 @@ download_and_extract() { elif [[ ${action} == "patch_cifar10_dataset" ]]; then patch_cifar10_dataset ${dir} elif [[ ${action} == "build_embarc_mli" ]]; then - build_embarc_mli ${dir} ${action_param1} + if [[ "${action_param1}" == *.tcf ]]; then + cp ${action_param1} ${dir}/hw/arc.tcf + build_embarc_mli ${dir} ../../hw/arc.tcf + else + build_embarc_mli ${dir} ${action_param1} + fi elif [[ ${action} ]]; then echo "Unknown action '${action}'" exit 1 diff --git a/tensorflow/lite/micro/tools/make/ext_libs/arc_mli.inc b/tensorflow/lite/micro/tools/make/ext_libs/arc_mli.inc new file mode 100644 index 00000000000..5dbb91dd368 --- /dev/null +++ b/tensorflow/lite/micro/tools/make/ext_libs/arc_mli.inc @@ -0,0 +1,104 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Settings for embARC MLI library for ARC platform. + +ifeq ($(TARGET_ARCH), arc) + +# MLI Library is used by default for ARC platform whenever it is possible. +# To use TFLM reference implementation MLI should be intentionally turned off +# by passing 'no_arc_mli' tag (make -f <tflm_main_makefile> TAGS=no_arc_mli ...) +ifeq ($(filter no_arc_mli,$(ALL_TAGS)),) + +ALL_TAGS += arc_mli + +ifeq ($(BUILD_ARC_MLI),true) + MLI_LIB_DIR ?= arc_mli_$(basename $(TCF_FILE_NAME)) + + $(eval $(call add_third_party_download,$(EMBARC_MLI_URL),$(EMBARC_MLI_MD5),$(MLI_LIB_DIR),build_embarc_mli,$(TCF_FILE))) + + MLI_INCLUDE_FOLDER = $(MLI_LIB_DIR)/include + MLI_LIB = third_party/$(MLI_LIB_DIR)/bin/libmli.a + MICROLITE_LIBS += $(MAKEFILE_DIR)/downloads/$(MLI_LIB_DIR)/bin/libmli.a + + THIRD_PARTY_CC_HDRS += \ + third_party/$(MLI_LIB_DIR)/LICENSE +else +ifneq ($(ARC_MLI_PRE_COMPILED_TARGET),) + MLI_LIB_DIR ?= arc_mli_package + $(eval $(call add_third_party_download,$(EMBARC_MLI_PRE_COMPILED_URL),$(EMBARC_MLI_PRE_COMPILED_MD5),$(MLI_LIB_DIR),)) + + MLI_INCLUDE_FOLDER = $(MLI_LIB_DIR)/include + MLI_LIB = third_party/$(MLI_LIB_DIR)/bin/$(ARC_MLI_PRE_COMPILED_TARGET)/release/libmli.a + MICROLITE_LIBS += $(MAKEFILE_DIR)/downloads/$(MLI_LIB_DIR)/bin/$(ARC_MLI_PRE_COMPILED_TARGET)/release/libmli.a + + THIRD_PARTY_CC_HDRS += \ + third_party/$(MLI_LIB_DIR)/LICENSE +else +$(error Target for pre compiled ARC MLI library is not defined) +endif +endif + + THIRD_PARTY_CC_HDRS += $(MLI_LIB) + GENERATED_PROJECT_LIBS += $(MLI_LIB) + + INCLUDES += \ + -I$(MAKEFILE_DIR)/downloads/$(MLI_INCLUDE_FOLDER) \ + -I$(MAKEFILE_DIR)/downloads/$(MLI_INCLUDE_FOLDER)/api + + GENERATED_PROJECT_INCLUDES += \ + -I. \ + -I./third_party/$(MLI_INCLUDE_FOLDER) \ + -I./third_party/$(MLI_INCLUDE_FOLDER)/api + + + THIRD_PARTY_CC_HDRS += \ + third_party/$(MLI_INCLUDE_FOLDER)/mli_api.h \ + third_party/$(MLI_INCLUDE_FOLDER)/mli_config.h \ + third_party/$(MLI_INCLUDE_FOLDER)/mli_types.h \ + third_party/$(MLI_INCLUDE_FOLDER)/api/mli_helpers_api.h \ + third_party/$(MLI_INCLUDE_FOLDER)/api/mli_kernels_api.h \ + third_party/$(MLI_INCLUDE_FOLDER)/api/mli_krn_avepool_spec_api.h \ + third_party/$(MLI_INCLUDE_FOLDER)/api/mli_krn_conv2d_spec_api.h \ + third_party/$(MLI_INCLUDE_FOLDER)/api/mli_krn_depthwise_conv2d_spec_api.h \ + third_party/$(MLI_INCLUDE_FOLDER)/api/mli_krn_maxpool_spec_api.h \ + third_party/$(MLI_INCLUDE_FOLDER)/api/mli_mov_api.h + + MICROLITE_CC_HDRS += tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.h + MICROLITE_CC_SRCS += tensorflow/lite/micro/kernels/arc_mli/scratch_buffers.cc + MICROLITE_CC_HDRS += tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.h + MICROLITE_CC_SRCS += tensorflow/lite/micro/kernels/arc_mli/scratch_buf_mgr.cc + MICROLITE_CC_HDRS += tensorflow/lite/micro/kernels/arc_mli/mli_slicers.h + MICROLITE_CC_SRCS += tensorflow/lite/micro/kernels/arc_mli/mli_slicers.cc + MICROLITE_CC_HDRS += tensorflow/lite/micro/kernels/arc_mli/mli_tf_utils.h + + + MICROLITE_TEST_SRCS += $(wildcard tensorflow/lite/micro/kernels/arc_mli/*test.cc) + + ARC_MLI_TESTS := conv depthwise_conv pooling fully_connected + ARC_MLI_TESTS += $(foreach TEST,$(ARC_MLI_TESTS), $(TEST)_slicing) + +generate_arc_mli_test_projects: $(foreach TEST,$(ARC_MLI_TESTS), generate_kernel_$(TEST)_test_make_project) + + ARC_EXTRA_APP_SETTINGS += \ + \nMLI_ONLY ?= false\n\ + \nifeq \($(DLR)\(MLI_ONLY\), true\)\ + \nCCFLAGS += -DTF_LITE_STRIP_REFERENCE_IMPL\ + \nCXXFLAGS += -DTF_LITE_STRIP_REFERENCE_IMPL\ + \nendif\n + + + +endif # no_embarc_mli +endif # TARGET_ARCH diff --git a/tensorflow/lite/micro/tools/make/helper_functions.inc b/tensorflow/lite/micro/tools/make/helper_functions.inc index 09771419843..1cf9afa8794 100644 --- a/tensorflow/lite/micro/tools/make/helper_functions.inc +++ b/tensorflow/lite/micro/tools/make/helper_functions.inc @@ -130,24 +130,37 @@ endef define generate_arc_project ifeq ($(TARGET_ARCH), arc) -$(PRJDIR)$(3)/$(1)/Makefile: tensorflow/lite/micro/tools/make/templates/Makefile.tpl + +$(PRJDIR)$(3)/$(1)/Makefile: tensorflow/lite/micro/tools/make/templates/arc/arc_app_makefile.tpl @mkdir -p $$(dir $$@) @sed -E 's#\%\{SRCS\}\%#$(4)#g' $$< | \ - sed -E '1 i\CC = ccac\nCXX = ccac\nLD = ccac\n' | \ + sed -E 's#\%\{CC\}\%#$(CC_TOOL)#g' | \ + sed -E 's#\%\{CXX\}\%#$(CXX_TOOL)#g' | \ + sed -E 's#\%\{LD\}\%#$(LD_TOOL)#g' | \ sed -E 's#\%\{EXECUTABLE\}\%#$(3).elf#g' | \ sed -E 's#\%\{LINKER_FLAGS\}\%#$(6)#g' | \ sed -E 's#\%\{CXX_FLAGS\}\%#$(7)#g' | \ - sed -E 's#\%\{CC_FLAGS\}\%#$(8)#g' > $$@ + sed -E 's#\%\{CC_FLAGS\}\%#$(8)#g' | \ + sed -E 's#\%\{EXTRA_APP_SETTINGS\}\%#$(ARC_EXTRA_APP_SETTINGS)#g' | \ + sed -E 's#\%\{EXTRA_APP_RULES\}\%#$(ARC_EXTRA_APP_RULES)#g' | \ + sed -E 's#\%\{BIN_DEPEND\}\%#$(ARC_BIN_DEPEND)#g' | \ + sed -E 's#\%\{BIN_RULE\}\%#$(ARC_BIN_RULE)#g' | \ + sed -E 's#\%\{EXTRA_RM_TARGETS\}\%#$(ARC_EXTRA_RM_TARGETS)#g' | \ + sed -E 's#\%\{APP_RUN_CMD\}\%#$(ARC_APP_RUN_CMD)#g' | \ + sed -E 's#\%\{APP_DEBUG_CMD\}\%#$(ARC_APP_DEBUG_CMD)#g' | \ + sed -E 's#\%\{EXTRA_EXECUTE_RULES\}\%#$(ARC_EXTRA_EXECUTE_RULES)#g' > $$@ - -# Special rule to copy TCF in case the local filesystem file name has been defined -ifneq ($(TCF_FILE_NAME), ) -$(PRJDIR)$(3)/$(1)/$(TCF_FILE_NAME): $(TCF_FILE) +$(PRJDIR)$(3)/$(1)/%: tensorflow/lite/micro/tools/make/templates/arc/%.tpl @cp $$< $$@ -endif + +$(foreach var,$(ARC_TARGET_FILES_DIRS),$(eval $(call path_changing_copy_file,$(PRJDIR)$(3)/$(1),$(var)))) + endif endef + + + # Creates a set of rules to build a standalone Arduino project for an # executable, including all of the source and header files required in a # separate folder and a simple makefile. diff --git a/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc b/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc index 8671df5864f..5214b06a36f 100644 --- a/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/apollo3evb_makefile.inc @@ -40,7 +40,6 @@ $(MAKEFILE_DIR)/downloads/$(AM_SDK_DEST)/$(SF_BSPS_DEST): $(MAKEFILE_DIR)/downlo -fmessage-length=0 \ -fno-exceptions \ -fno-unwind-tables \ - -fno-builtin \ -ffunction-sections \ -fdata-sections \ -funsigned-char \ diff --git a/tensorflow/lite/micro/tools/make/targets/arc/README.md b/tensorflow/lite/micro/tools/make/targets/arc/README.md new file mode 100644 index 00000000000..366aede5db4 --- /dev/null +++ b/tensorflow/lite/micro/tools/make/targets/arc/README.md @@ -0,0 +1,315 @@ +# Building TensorFlow Lite for Microcontrollers for Synopsys DesignWare ARC EM/HS Processors + +This document contains the general information on building and running +TensorFlow Lite Micro for targets based on the Synopsys ARC EM/HS Processors. + +## Table of Contents + +- [Install the Synopsys DesignWare ARC MetaWare Development Toolkit](#install-the-synopsys-designware-arc-metaWare-development-toolkit) +- [ARC EM Software Development Platform (ARC EM SDP)](#ARC-EM-Software-Development-Platform-ARC-EM-SDP) +- [Custom ARC EM or HS Platform](#Custom-ARC-EMHS-Platform) + +## Install the Synopsys DesignWare ARC MetaWare Development Toolkit + +The Synopsys DesignWare ARC MetaWare Development Toolkit (MWDT) is required to +build and run Tensorflow Lite Micro applications for all ARC EM/HS targets. + +To license MWDT, please see further details +[here](https://www.synopsys.com/dw/ipdir.php?ds=sw_metaware) + +To request an evaluation version of MWDT, please use the +[Synopsys Eval Portal](https://eval.synopsys.com/) and follow the link for the +MetaWare Development Toolkit (Important: Do not confuse this with MetaWare EV +Development Toolkit or MetaWare Lite options also available on this page) + +Run the downloaded installer and follow the instructions to set up the toolchain +on your platform. + +TensorFlow Lite for Microcontrollers builds are divided into two phases: +Application Project Generation and Application Project Building/Running. The +former phase requires \*nix environment while the latter does not. + +For basic project generation targeting +[ARC EM Software Development Platform](#ARC-EM-Software-Development-Platform-ARC-EM-SDP), +MetaWare is NOT required for the Project Generation Phase. However, it is +required in case the following: - For project generation for custom (not EM SDP) +targets - To build microlib target library with all required TFLM objects for +external use + +Please consider the above when choosing whether to install Windows or Linux or +both versions of MWDT + +## ARC EM Software Development Platform (ARC EM SDP) + +This section describes how to deploy on an +[ARC EM SDP board](https://www.synopsys.com/dw/ipdir.php?ds=arc-em-software-development-platform) + +### Initial Setup + +To use the EM SDP, you need the following hardware and software: + +#### ARC EM SDP + +More information on the platform, including ordering information, can be found +[here](https://www.synopsys.com/dw/ipdir.php?ds=arc-em-software-development-platform). + +#### MetaWare Development Toolkit + +See +[Install the Synopsys DesignWare ARC MetaWare Development Toolkit](#install-the-synopsys-designware-arc-metaWare-development-toolkit) +section for instructions on toolchain installation. + +#### Digilent Adept 2 System Software Package + +If you wish to use the MetaWare Debugger to debug your code, you need to also +install the Digilent Adept 2 software, which includes the necessary drivers for +connecting to the targets. This is available from oficial +[Digilent site](https://reference.digilentinc.com/reference/software/adept/start?redirect=1#software_downloads). +You should install the “System” component, and Runtime. Utilities and SDK are +NOT required. + +Digilent installation is NOT required if you plan to deploy to EM SDP via the SD +card instead of using the debugger. + +#### Make Tool + +A `'make'` tool is required for both phases of deploying Tensorflow Lite Micro +applications on ARC EM SDP: 1. Application project generation 2. Working with +generated application (build and run) + +For the first phase you need an environment and make tool compatible with +Tensorflow Lite for Micro build system. At the moment of this writing, this +requires make >=3.82 and a *nix-like environment which supports shell and native +commands for file manipulations. MWDT toolkit is not required for this phase. + +For the second phase, requirements are less strict. The gmake version delivered +with MetaWare Development Toolkit is sufficient. There are no shell and *nix +command dependencies, so Windows can be used + +#### Serial Terminal Emulation Application + +The Debug UART port of the EM SDP is used to print application output. The USB +connection provides both the debug channel and RS232 transport. You can use any +terminal emulation program (like [PuTTY](https://www.putty.org/)) to view UART +output from the EM SDP. + +#### microSD Card + +If you want to self-boot your application (start it independently from a +debugger connection), you also need a microSD card with a minimum size of 512 MB +and a way to write to the card from your development host + +### Connect the Board + +1. Make sure Boot switches of the board (S3) are configured in the next way: + +Switch # | Switch position +:------: | :-------------: +1 | Low (0) +2 | Low (0) +3 | High (1) +4 | Low (0) + +1. Connect the power supply included in the product package to the ARC EM SDP. +2. Connect the USB cable to connector J10 on the ARC EM SDP (near the RST and + CFG buttons) and to an available USB port on your development host. +3. Determine the COM port assigned to the USB Serial Port (on Windows, using + Device Manager is an easy way to do this) +4. Execute the serial terminal application you installed in the previous step + and open the serial connection with the early defined COM port (speed 115200 + baud; 8 bits; 1 stop bit; no parity). +5. Push the CFG button on the board. After a few seconds you should see the + boot log in the terminal which begins as follows: + +``` +U-Boot <Versioning info> + +CPU: ARC EM11D v5.0 at 40 MHz +Subsys:ARC Data Fusion IP Subsystem +Model: snps,emsdp +Board: ARC EM Software Development Platform v1.0 +… +``` + +### Generate Application Project for ARC EM SDP + +Before building an example or test application, you need to generate a TFLM +project for this application from TensorFlow sources and external dependencies. +To generate it for ARC EM SDP board you need to set `TARGET=arc_emsdp` on the +make command line. For instance, to build the Person Detect test application, +use a shell to execute the following command from the root directory of the +TensorFlow repo: + +``` +make -f tensorflow/lite/micro/tools/make/Makefile generate_person_detection_test_int8_make_project TARGET=arc_emsdp +``` + +The application project will be generated into +*tensorflow/lite/micro/tools/make/gen/arc_emsdp_arc/prj/person_detection_test_int8/make* + +Info on generating and building example applications for EM SDP +(*tensorflow/lite/micro/examples*) can be found in the appropriate readme file +placed in the same directory with the examples. In general, it’s the same +process which described in this Readme. + +The +[embARC MLI Library](https://github.com/foss-for-synopsys-dwc-arc-processors/embarc_mli) +is used by default to speed up execution of some kernels for asymmetrically +quantized layers. Kernels which use MLI-based implementations are kept in the +*tensorflow/lite/micro/kernels/arc_mli* folder. For applications which may not +benefit from MLI library, the project can be generated without these +implementations by adding `TAGS=no_arc_mli` in the command line. This can reduce +code size when the optimized kernels are not required. + +For more options on embARC MLI usage see +[kernels/arc_mli/README.md](/tensorflow/lite/micro/kernels/arc_mli/README.md). + +### Build the Application + +You may need to adjust the following commands in order to use the appropriate +make tool available in your environment (ie: `make` or `gmake`) + +1. Open command shell and change the working directory to the location which + contains the generated project, as described in the previous section + +2. Clean previous build artifacts (optional) + + make clean + +3. Build application + + make app + +### Run the Application on the Board Using MetaWare Debugger + +In case you do not have access to the MetaWare Debugger or have chosen not to +install the Digilent drivers, you can skip to the next section. + +To run the application from the console, use the following command: + +``` + make run +``` + +If application runs in an infinite loop, type `Ctrl+C` several times to exit the +debugger. + +To run the application in the GUI debugger, use the following command: + +``` + make debug +``` + +In both cases you will see the application output in the serial terminal. + +### Run the Application on the Board from the microSD Card + +1. Use the following command in the same command shell you used for building + the application, as described in the previous step + + make flash + +2. Copy the content of the created *./bin* folder into the root of microSD + card. Note that the card must be formatted as FAT32 with default cluster + size (but less than 32 Kbytes) + +3. Plug in the microSD card into the J11 connector. + +4. Push the RST button. If a red LED is lit beside RST button, push the CFG + button. + +You will see the application output in the serial terminal. + +## Custom ARC EM/HS Platform + +This section describes how to deploy on a Custom ARC EM/HS platform defined only +by a TCF (Tool Configuration File, created at CPU configuration time) and +optional LCF (Linker Command File). In this case, the real hardware is unknown, +and applications can be run only in the nSIM simulator included with the +MetaWare toolkit + +### Initial Setup + +To with custom ARC EM/HS platform, you need the following : * Synopsys MetaWare +Development Toolkit version 2019.12 or higher * Make tool (make or gmake) + +See +[Install the Synopsys DesignWare ARC MetaWare Development Toolkit](#install-the-synopsys-designware-arc-metaWare-development-toolkit) +section for instructions on toolchain installation. See +[MetaWare Development Toolkit](#MetaWare-Development-Toolkit) and +[Make Tool](#Make-Tool) sections for instructions on toolchain installation and +comments about make versions. + +### Generate Application Project + +Before building the application itself, you need to generate the project for +this application from TensorFlow sources and external dependencies. To generate +it for a custom TCF you need to set the following variables in the make command +line: * TARGET_ARCH=arc * TCF_FILE=<path to TCF file> * (optional) +LCF_FILE=<path to LCF file> + +If you don’t supply an external LCF, the one embedded in the TCF will be used +instead + +For instance, to build **Person Detection** test application, use the following +command from the root directory of the TensorFlow repo: + +``` +make -f tensorflow/lite/micro/tools/make/Makefile generate_person_detection_test_int8_make_project TARGET_ARCH=arc TCF_FILE=<path_to_tcf_file> LCF_FILE=<path_to_lcf_file> +``` + +The application project will be generated into +*tensorflow/lite/micro/tools/make/gen/<tcf_file_basename>_arc/prj/person_detection_test_int8/make* + +The +[embARC MLI Library](https://github.com/foss-for-synopsys-dwc-arc-processors/embarc_mli) +is used by default to speed up execution of some kernels for asymmetrically +quantized layers. Kernels which use MLI-based implementations are kept in the +*tensorflow/lite/micro/kernels/arc_mli* folder. For applications which may not +benefit from MLI library, the project can be generated without these +implementations by adding `TAGS=no_arc_mli` in the command line. This can reduce +code size when the optimized kernels are not required. + +For more options on embARC MLI usage see +[kernels/arc_mli/README.md](/tensorflow/lite/micro/kernels/arc_mli/README.md). + +### Build the Application + +You may need to adjust the following commands in order to use the appropriate +make tool available in your environment (ie: `make` or `gmake`) + +1. Open command shell and change the working directory to the location which + contains the generated project, as described in the previous section + +2. Clean previous build artifacts (optional) + + make clean + +3. Build application + + make app + +### Run the Application with MetaWare Debugger on the nSim Simulator. + +To run application from the console, use the following command: + +``` + make run +``` + +If application runs in an infinite loop, type `Ctrl+C` several times to exit the +debugger. + +To run the application in the GUI debugger, use the following command: + +``` + make debug +``` + +You will see the application output in the same console where you ran it. + +## License + +TensorFlow's code is covered by the Apache2 License included in the repository, +and third-party dependencies are covered by their respective licenses, in the +third_party folder of this package. diff --git a/tensorflow/lite/micro/tools/make/targets/arc/arc_common.inc b/tensorflow/lite/micro/tools/make/targets/arc/arc_common.inc new file mode 100644 index 00000000000..596f219d3d1 --- /dev/null +++ b/tensorflow/lite/micro/tools/make/targets/arc/arc_common.inc @@ -0,0 +1,138 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Common Settings for ARC platform and its projects. +# Might be reused across different targets + +ifeq ($(TARGET_ARCH), arc) + + DLR := $$$$ + + # List of folders to search project files for copy with path changing + # For instance, TCF and LCF files are copied into the root of generated project + ARC_TARGET_FILES_DIRS ?= + + # For the following variables see arc_app_makefile.tpl for usage + + # Additional text into application settings section of arc makefile project + ARC_EXTRA_APP_SETTINGS ?= + + # Additional text into application general rules of arc makefile project + ARC_EXTRA_APP_RULES ?= + + # additional arguments for RM command of "clean" target rule ("make clean" command) + ARC_EXTRA_RM_TARGETS ?= + + # Dependencies of "flash" target rule ("make flash" command) + ARC_BIN_DEPEND ?= + + # Commands in "flash" target rule ("make flash" command) + ARC_BIN_RULE ?= \t$(DLR)\(error Flash rule isnt defined for this ARC target\) + + # Command to run app on "make run" command of generated project + ARC_APP_RUN_CMD ?= + + # Command to run app on "make debug" command of generated project + ARC_APP_DEBUG_CMD ?= + + # Additional text into application execution rules of arc makefile project + ARC_EXTRA_EXECUTE_RULES ?= + +# We overwrite project generator to exclude everything not relevant to ARC platform. +# ARC targets cannot work with non-ARC development tools. +# Basic make project is updated to be applicable for general ARC platform +define generate_microlite_projects +$(call generate_project,make,$(MAKE_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(LDFLAGS) $(MICROLITE_LIBS),$(CXXFLAGS) $(GENERATED_PROJECT_INCLUDES), $(CCFLAGS) $(GENERATED_PROJECT_INCLUDES),$(TARGET_TOOLCHAIN_ROOT),$(TARGET_TOOLCHAIN_PREFIX)) +$(call generate_arc_project,make,$(MAKE_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(LDFLAGS) $(GENERATED_PROJECT_LIBS),$(CXXFLAGS) $(GENERATED_PROJECT_INCLUDES), $(CCFLAGS) $(GENERATED_PROJECT_INCLUDES)) +endef + +# Copy rule generator to do file copies with changing paths in generated project +# Arguments are: +# 1 - Path files in generated project. +# 2 - Path files in the source repo +# Used in helper_functions.inc for arc projects to copy files +define path_changing_copy_file +$(1)/%: $(2)/% + @mkdir -p $$(dir $$@) + @cp $$< $$@ +endef + +# These are microcontroller-specific rules for converting the ELF output +# of the linker into a binary image that can be loaded directly. +# Not applicable for ARC, leaving it empty. +$(BINDIR)%.bin: + + +ifeq ($(ARC_TOOLCHAIN), mwdt) + CC_TOOL := ccac + AR_TOOL := arac + CXX_TOOL := ccac + LD_TOOL := ccac + + ARC_APP_RUN_CMD = mdb -run -jit -tcf=$(TCF_FILE_NAME) $(DLR)\(DBG_ARGS\) + ARC_APP_DEBUG_CMD = mdb -OK -jit -tcf=$(TCF_FILE_NAME) $(DLR)\(DBG_ARGS\) + + # The variable TCF_FILE stores path to Tool Configuration File (*.tcf). + # This file is used by MWDT toolchain to properly compile/run code + TCF_FILE ?= + + LCF_FILE ?= + + BUILD_ARC_MLI ?= true + +# The variable TCF_FILE_NAME stores the TCF file name (including .tcf extension), +# this variable is used later to add the option to the linker/compiler flags. +# This condition also handles the case when the user/makefile specifies +# the configuration bundled with MWDT (usually without .tcf extension) and that doesn't require copying. +ifneq (,$(findstring .tcf,$(TCF_FILE))) + TCF_FILE_NAME = $(notdir $(TCF_FILE)) + ARC_TARGET_FILES_DIRS = $(dir $(TCF_FILE)) + MAKE_PROJECT_FILES += $(TCF_FILE_NAME) +else + TCF_FILE_NAME = $(TCF_FILE) +endif + + PLATFORM_FLAGS = -tcf=$(TCF_FILE_NAME) -tcf_core_config + + PLATFORM_FLAGS += -Hnocopyr -Hpurge -Hdense_prologue -Hon=Long_enums -fslp-vectorize-aggressive -ffunction-sections -fdata-sections + + # Use compact CRT. It requires pre-defined heap size + PLATFORM_FLAGS += -Hcl -Hcrt_fast_memcpy -Hcrt_fast_memset + + PLATFORM_LDFLAGS = -tcf=$(TCF_FILE_NAME) + + PLATFORM_LDFLAGS += -Hnocopyr -m -Hldopt=-Coutput=memory.map -Hheap=2K + +ifneq ($(LCF_FILE), ) + PLATFORM_LDFLAGS += $(notdir $(LCF_FILE)) + MAKE_PROJECT_FILES += $(notdir $(LCF_FILE)) +ifeq ($(filter $(ARC_TARGET_FILES_DIRS), $(dir $(LCF_FILE))),) + ARC_TARGET_FILES_DIRS += $(dir $(LCF_FILE)) +endif +endif + + CXXFLAGS := $(filter-out -std=c++11,$(CXXFLAGS)) + CCFLAGS := $(filter-out -std=c11,$(CCFLAGS)) + MICROLITE_LIBS := $(filter-out -lm,$(MICROLITE_LIBS)) + + CXXFLAGS += $(PLATFORM_FLAGS) + CCFLAGS += $(PLATFORM_FLAGS) + LDFLAGS += $(PLATFORM_LDFLAGS) + + + + +endif # ARC_TOOLCHAIN +endif # TARGET_ARCH + diff --git a/tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp.lcf b/tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp.lcf new file mode 100644 index 00000000000..780fd7b9750 --- /dev/null +++ b/tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp.lcf @@ -0,0 +1,85 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Common EMSDP LCF File for applications +# +# external SRAM memory is used for code, because some TFLM applications includes the whole +# set of supported kernels which doesn't fit to ICCM0. +# It could slow performance a bit. Smaller applications can use ICCM0 instead. +# +# External PSRAM is used for potentially big sections. In particular: +# - rodata_in data which typically includes protobuf with model. +# - other .data which typically includes tensor arena. +# +# stack and heap are kept in DCCM which is the closest memory to the core + +# CCMWRAP memory regions indicate unusable portions of the address space +# due to CCM memory wrapping into upper addresses beyond its size + +MEMORY { + PSRAM : ORIGIN = 0x10000400, LENGTH = (0x01000000 >> 1) - 0x400 + SRAM : ORIGIN = 0x20000000, LENGTH = 0x00040000 + IVT : ORIGIN = 0x60000000, LENGTH = 0x400 + ICCM0 : ORIGIN = 0x60000400, LENGTH = (0x00020000 - 0x400) +# CCMWRAP0: ORIGIN = 0x60020000, LENGTH = 0x0ffe0000 + DCCM : ORIGIN = 0x80000000, LENGTH = 0x00020000 +# CCMWRAP1: ORIGIN = 0x80020000, LENGTH = 0x0ffe0000 + XCCM : ORIGIN = 0x90000000, LENGTH = 0x00004000 +# CCMWRAP2: ORIGIN = 0x90004000, LENGTH = 0x0fffc000 + YCCM : ORIGIN = 0xa0000000, LENGTH = 0x00004000 +# CCMWRAP3: ORIGIN = 0xa0004000, LENGTH = 0x0fffc000 + } + +SECTIONS { + + GROUP BLOCK(4) : { + .vectors (TEXT) SIZE(DEFINED _IVTSIZE?_IVTSIZE:756): {} = FILL(0xa5a5a5a5,4) + } > IVT + + GROUP BLOCK(4): { + .text? : { *('.text$crt*') } + * (TEXT): {} + * (LIT): {} + } > SRAM + + GROUP BLOCK(4): { + .Zdata? : {} + .stack ALIGN(4) SIZE(DEFINED _STACKSIZE?_STACKSIZE:32K): {} + .heap? ALIGN(4) SIZE(DEFINED _HEAPSIZE?_HEAPSIZE:8K): {} + } > DCCM + + GROUP BLOCK(4): { + .Xdata? : {} + } > XCCM + + GROUP BLOCK(4): { + .Ydata? : {} + } > YCCM + + GROUP BLOCK(4): { + /* _SDA_BASE_ computed implicitly */ + .sdata?: {} + .sbss?: {} + * (DATA): {} + * (BSS): {} + } > PSRAM + + GROUP BLOCK(4): { + .rodata_in_data? : {} + } > PSRAM + + GROUP BLOCK(4): { + .debug_log? : {} + } > SRAM +} + + diff --git a/tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp_v2.lcf b/tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp_v2.lcf new file mode 100644 index 00000000000..63ef48667db --- /dev/null +++ b/tensorflow/lite/micro/tools/make/targets/arc/emsdp/emsdp_v2.lcf @@ -0,0 +1,74 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Difference with common EMSDP LCF file (to reduce data access time): +# - move data from external PSRAM to DCCM +# - move text from SRAM to ICCM +# +# CCMWRAP memory regions indicate unusable portions of the address space +# due to CCM memory wrapping into upper addresses beyond its size + +MEMORY { + PSRAM : ORIGIN = 0x10000400, LENGTH = (0x01000000 >> 1) - 0x400 + SRAM : ORIGIN = 0x20000000, LENGTH = 0x00040000 + IVT : ORIGIN = 0x60000000, LENGTH = 0x400 + ICCM0 : ORIGIN = 0x60000400, LENGTH = (0x00020000 - 0x400) +# CCMWRAP0: ORIGIN = 0x60020000, LENGTH = 0x0ffe0000 + DCCM : ORIGIN = 0x80000000, LENGTH = 0x00020000 +# CCMWRAP1: ORIGIN = 0x80020000, LENGTH = 0x0ffe0000 + XCCM : ORIGIN = 0x90000000, LENGTH = 0x00004000 +# CCMWRAP2: ORIGIN = 0x90004000, LENGTH = 0x0fffc000 + YCCM : ORIGIN = 0xa0000000, LENGTH = 0x00004000 +# CCMWRAP3: ORIGIN = 0xa0004000, LENGTH = 0x0fffc000 + } + +SECTIONS { + + GROUP BLOCK(4) : { + .vectors (TEXT) SIZE(DEFINED _IVTSIZE?_IVTSIZE:756): {} = FILL(0xa5a5a5a5,4) + } > IVT + + GROUP BLOCK(4): { + .text? : { *('.text$crt*') } + * (TEXT): {} + * (LIT): {} + } > ICCM0 + + GROUP BLOCK(4): { + .rodata_in_data? : {} + } > PSRAM + + GROUP BLOCK(4): { + .debug_log? : {} + } > SRAM + + GROUP BLOCK(4): { + /* _SDA_BASE_ computed implicitly */ + .sdata?: {} + .sbss?: {} + * (DATA): {} + * (BSS): {} + .Zdata? : {} + .stack ALIGN(4) SIZE(DEFINED _STACKSIZE?_STACKSIZE:8K): {} + .heap? ALIGN(4) SIZE(DEFINED _HEAPSIZE?_HEAPSIZE:8K): {} + } > DCCM + + GROUP BLOCK(4): { + .Xdata? : {} + } > XCCM + + GROUP BLOCK(4): { + .Ydata? : {} + } > YCCM +} + + diff --git a/tensorflow/lite/micro/tools/make/targets/arc_emsdp_makefile.inc b/tensorflow/lite/micro/tools/make/targets/arc_emsdp_makefile.inc new file mode 100644 index 00000000000..405b9698cca --- /dev/null +++ b/tensorflow/lite/micro/tools/make/targets/arc_emsdp_makefile.inc @@ -0,0 +1,73 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Settings for EMSDP target (ARC processor) +ifeq ($(TARGET), arc_emsdp) + + TARGET_ARCH := arc + ARC_TOOLCHAIN := mwdt + + + BUILD_ARC_MLI := false + ARC_MLI_PRE_COMPILED_TARGET := emsdp_em11d_em9d_dfss + +ifneq ($(filter no_arc_mli,$(ALL_TAGS)),) + MLI_LIB_DIR = arc_mli_package + $(eval $(call add_third_party_download,$(EMBARC_MLI_PRE_COMPILED_URL),$(EMBARC_MLI_PRE_COMPILED_MD5),$(MLI_LIB_DIR),)) +else ifeq ($(BUILD_ARC_MLI), true) + MLI_LIB_DIR = arc_mli_$(ARC_MLI_PRE_COMPILED_TARGET) +endif + + TCF_FILE = $(PWD)/$(MAKEFILE_DIR)/downloads/$(MLI_LIB_DIR)/hw/emsdp_em11d_em9d_dfss.tcf + LCF_FILE = $(PWD)/$(MAKEFILE_DIR)/targets/arc/emsdp/emsdp.lcf + UBOOT_FILE := $(PWD)/$(MAKEFILE_DIR)/targets/arc/emsdp/uboot.env + UBOOT_FILE_NAME := $(notdir $(UBOOT_FILE)) + + +include $(MAKEFILE_DIR)/targets/arc/arc_common.inc + + ARC_EXTRA_APP_SETTINGS = \ + BIN_DIR = .$(DLR)\(PS\)bin\n\ + BIN_FILE = $(DLR)\(BIN_DIR\)$(DLR)\(PS\)app.elf\n + + ARC_EXTRA_APP_RULES = \ + $(DLR)\(BIN_FILE\): $(DLR)\(BIN_DIR\) $(DLR)\(OUT_NAME\)\ + \n\t\@$(DLR)\(CP\) $(DLR)\(OUT_NAME\) $(DLR)\(BIN_FILE\)\ + \n\t\@$(DLR)\(CP\) $(UBOOT_FILE_NAME) $(DLR)\(BIN_DIR\)$(DLR)\(PS\)$(UBOOT_FILE_NAME)\ + \n \ + \n$(DLR)\(BIN_DIR\):\ + \n\t\@$(DLR)\(MKDIR\) $(DLR)\(BIN_DIR\)\ + + ARC_EXTRA_RM_TARGETS = $(DLR)\(BIN_DIR\) + + ARC_BIN_DEPEND = $(DLR)\(BIN_DIR\) $(DLR)\(BIN_FILE\) + ARC_BIN_RULE = \t@echo Copy content of $(DLR)\(BIN_DIR\) into the root of SD card and follow instructions + + ARC_APP_RUN_CMD = mdb -run -digilent -nooptions $(DLR)\(DBG_ARGS\) + ARC_APP_DEBUG_CMD = mdb -OK -digilent -nooptions $(DLR)\(DBG_ARGS\) + ARC_EXTRA_EXECUTE_RULES = + + MAKE_PROJECT_FILES += $(UBOOT_FILE_NAME) +ifeq ($(filter $(ARC_TARGET_FILES_DIRS), $(dir $(UBOOT_FILE))),) + ARC_TARGET_FILES_DIRS += $(dir $(UBOOT_FILE)) +endif + + MAKE_PROJECT_FILES := $(filter-out README_MAKE.md, $(MAKE_PROJECT_FILES)) README_ARC_EMSDP.md + + # for default EMSDP configuration we can use em9d_va rt libs + # for better performance runtime should be built for emsdp configuration + # No hostlink library for smaller codesize purpose + PLATFORM_LDFLAGS += -Hlib=em9d_voice_audio -Hhostlib= + +endif diff --git a/tensorflow/lite/micro/tools/make/targets/arc_makefile.inc b/tensorflow/lite/micro/tools/make/targets/arc_makefile.inc index 0f56e5f4641..9f5442b4c6c 100644 --- a/tensorflow/lite/micro/tools/make/targets/arc_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/arc_makefile.inc @@ -1,86 +1,40 @@ -# Settings for arc processors +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Settings for not pre-defined ARC processors. +# User need to specify ARC target with Tool Configuration File (*.tcf). +# Path to this file must be passed through TCF_FILE variable. +# Otherwise, default em7d_voice_audio configuration is used ifeq ($(TARGET_ARCH), arc) - CC_TOOL = ccac - AR_TOOL = arac - CXX_TOOL = ccac +# Known target are specified with their own make configurations. +ifeq ($(filter $(TARGET), arc_emsdp),) + +ARC_TOOLCHAIN := mwdt ifneq ($(TCF_FILE), ) TARGET = $(basename $(notdir $(TCF_FILE))) else + $(warning TCF_FILE variable is not specified. Use default em7d_voice_audio configuration) TARGET = em7d_voice_audio TCF_FILE = em7d_voice_audio endif -# The variable TCF_FILE_NAME stores the TCF file name (including .tcf extension), this variable is used later to add the option to the linker/compiler flags. -# This condition also handles the case when the user/makefile specifies the configuration bundled with MWDT (usually without .tcf extension) and that doesn't require copying. -ifneq (,$(findstring .tcf,$(TCF_FILE))) - TCF_FILE_NAME = $(notdir $(TCF_FILE)) - THIRD_PARTY_CC_HDRS += $(TCF_FILE_NAME) -else - TCF_FILE_NAME = $(TCF_FILE) -endif +include $(MAKEFILE_DIR)/targets/arc/arc_common.inc - PLATFORM_FLAGS = -tcf=$(TCF_FILE_NAME) -Hnocopyr -O3 -Hpurge -Hcl -fslp-vectorize-aggressive -ffunction-sections -fdata-sections - PLATFORM_LDFLAGS = -tcf=$(TCF_FILE_NAME) -Hnocopyr -m -Hldopt=-Coutput=memory.map +MAKE_PROJECT_FILES := $(filter-out README_MAKE.md, $(MAKE_PROJECT_FILES)) README_ARC.md - CXXFLAGS += $(PLATFORM_FLAGS) - CXXFLAGS:=$(filter-out -std=c++11,$(CXXFLAGS)) - CCFLAGS += $(PLATFORM_FLAGS) - LDFLAGS += $(PLATFORM_LDFLAGS) +endif # $(TARGET) +endif # $(TARGET_ARCH)... - MICROLITE_LIBS := $(filter-out -lm,$(MICROLITE_LIBS)) - - USE_EMBARC_MLI ?= true - -ifeq ($(USE_EMBARC_MLI), true) - ALL_TAGS += arc - -ifeq ($(PRE_COMPILED_MLI),true) - $(eval $(call add_third_party_download,$(EMBARC_OSP_URL),$(EMBARC_OSP_MD5),embarc_osp,)) - - MLI_INCLUDE_FOLDER = embarc_osp/library/embarc_mli/include - MLI_LIB = third_party/embarc_osp/library/embarc_mli/lib/arcem9d/libmli_iotdk.a - - THIRD_PARTY_CC_HDRS += \ - third_party/embarc_osp/LICENSE -else - MLI_LIB_DIR = embarc_mli_$(basename $(TCF_FILE_NAME)) - - $(eval $(call add_third_party_download,$(EMBARC_MLI_URL),$(EMBARC_MLI_MD5),$(MLI_LIB_DIR),build_embarc_mli,$(TCF_FILE))) - - MLI_INCLUDE_FOLDER = $(MLI_LIB_DIR)/include - MLI_LIB = third_party/$(MLI_LIB_DIR)/bin/libmli.a - MICROLITE_LIBS += $(MAKEFILE_DIR)/downloads/$(MLI_LIB_DIR)/bin/libmli.a - - THIRD_PARTY_CC_HDRS += \ - third_party/$(MLI_LIB_DIR)/LICENSE -endif - - THIRD_PARTY_CC_HDRS += $(MLI_LIB) - GENERATED_PROJECT_LIBS += $(MLI_LIB) - - INCLUDES += \ - -I$(MAKEFILE_DIR)/downloads/$(MLI_INCLUDE_FOLDER) \ - -I$(MAKEFILE_DIR)/downloads/$(MLI_INCLUDE_FOLDER)/api - - GENERATED_PROJECT_INCLUDES += \ - -I. \ - -I./third_party/$(MLI_INCLUDE_FOLDER) \ - -I./third_party/$(MLI_INCLUDE_FOLDER)/api - - - THIRD_PARTY_CC_HDRS += \ - third_party/$(MLI_INCLUDE_FOLDER)/mli_api.h \ - third_party/$(MLI_INCLUDE_FOLDER)/mli_config.h \ - third_party/$(MLI_INCLUDE_FOLDER)/mli_types.h \ - third_party/$(MLI_INCLUDE_FOLDER)/api/mli_helpers_api.h \ - third_party/$(MLI_INCLUDE_FOLDER)/api/mli_kernels_api.h \ - third_party/$(MLI_INCLUDE_FOLDER)/api/mli_krn_avepool_spec_api.h \ - third_party/$(MLI_INCLUDE_FOLDER)/api/mli_krn_conv2d_spec_api.h \ - third_party/$(MLI_INCLUDE_FOLDER)/api/mli_krn_depthwise_conv2d_spec_api.h \ - third_party/$(MLI_INCLUDE_FOLDER)/api/mli_krn_maxpool_spec_api.h \ - -endif # USE_EMBARC_MLI - -endif diff --git a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc index 2bd84fa6e29..5223f00a74f 100644 --- a/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/bluepill_makefile.inc @@ -19,7 +19,6 @@ ifeq ($(TARGET), bluepill) -fmessage-length=0 \ -fno-exceptions \ -fno-unwind-tables \ - -fno-builtin \ -ffunction-sections \ -fdata-sections \ -funsigned-char \ diff --git a/tensorflow/lite/micro/tools/make/targets/ecm3531_makefile.inc b/tensorflow/lite/micro/tools/make/targets/ecm3531_makefile.inc index 756915f946a..709f060c4ac 100644 --- a/tensorflow/lite/micro/tools/make/targets/ecm3531_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/ecm3531_makefile.inc @@ -27,7 +27,6 @@ ifeq ($(TARGET), ecm3531) -fmessage-length=0 \ -fno-exceptions \ -fno-unwind-tables \ - -fno-builtin \ -ffunction-sections \ -fdata-sections \ -funsigned-char \ diff --git a/tensorflow/lite/micro/tools/make/targets/hexagon_makefile.inc b/tensorflow/lite/micro/tools/make/targets/hexagon_makefile.inc index 155fff99dcd..c3cbf206c8a 100644 --- a/tensorflow/lite/micro/tools/make/targets/hexagon_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/hexagon_makefile.inc @@ -40,7 +40,6 @@ ifeq ($(TARGET), hexagon) -fdata-sections \ -ffunction-sections \ -fmessage-length=0 \ - -fno-builtin \ -fno-delete-null-pointer-checks \ -fno-exceptions \ -fno-register-global-dtors-with-atexit \ diff --git a/tensorflow/lite/micro/tools/make/targets/mcu_riscv_makefile.inc b/tensorflow/lite/micro/tools/make/targets/mcu_riscv_makefile.inc index 079c3c14f1e..b91b0a516f2 100644 --- a/tensorflow/lite/micro/tools/make/targets/mcu_riscv_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/mcu_riscv_makefile.inc @@ -16,7 +16,6 @@ ifeq ($(TARGET), riscv32_mcu) -DTF_LITE_MCU_DEBUG_LOG \ -DTF_LITE_USE_GLOBAL_CMATH_FUNCTIONS \ -fno-unwind-tables \ - -fno-builtin \ -ffunction-sections \ -fdata-sections \ -funsigned-char \ diff --git a/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc b/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc index 7abd3cc7e38..5114aa4620c 100644 --- a/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/stm32f4_makefile.inc @@ -16,7 +16,6 @@ ifeq ($(TARGET), stm32f4) -fmessage-length=0 \ -fno-exceptions \ -fno-unwind-tables \ - -fno-builtin \ -ffunction-sections \ -fdata-sections \ -funsigned-char \ diff --git a/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc b/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc index 5ed601f8dd1..dba98b45cd9 100644 --- a/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc +++ b/tensorflow/lite/micro/tools/make/targets/xtensa_xpg_makefile.inc @@ -30,4 +30,16 @@ ifeq ($(TARGET), xtensa-xpg) LDFLAGS += -Wl,-gc-sections TEST_SCRIPT := tensorflow/lite/micro/testing/test_xtensa_xpg_binary.sh + + # TODO(b/156962140): This manually maintained list of excluded examples is + # quite error prone. + EXCLUDED_EXAMPLE_TESTS := \ + tensorflow/lite/micro/examples/image_recognition_experimental/Makefile.inc \ + tensorflow/lite/micro/examples/magic_wand/Makefile.inc \ + tensorflow/lite/micro/examples/micro_speech/Makefile.inc \ + tensorflow/lite/micro/examples/network_tester/Makefile.inc \ + tensorflow/lite/micro/examples/person_detection/Makefile.inc \ + tensorflow/lite/micro/examples/person_detection_experimental/Makefile.inc + MICRO_LITE_EXAMPLE_TESTS := $(filter-out $(EXCLUDED_EXAMPLE_TESTS), $(MICRO_LITE_EXAMPLE_TESTS)) + endif diff --git a/tensorflow/lite/micro/tools/make/templates/arc/README_ARC.md.tpl b/tensorflow/lite/micro/tools/make/templates/arc/README_ARC.md.tpl new file mode 100644 index 00000000000..0ddaf3e0a81 --- /dev/null +++ b/tensorflow/lite/micro/tools/make/templates/arc/README_ARC.md.tpl @@ -0,0 +1,45 @@ +# TensorFlow Lite Micro ARC Make Project + +This folder has been autogenerated by TensorFlow, and contains sources, headers, and project files needed to build a single TensorFlow Lite Micro application using make tool and a Synopsys DesignWare ARC processor compatible toolchain, specifically the ARC MetaWare Development Toolkit (MWDT). + +This project has been generated for a target defined by TCF file only (Tool Configuration File). The real target board is unspecified, and applications can be run only in the nSIM simulator included with MWDT. + +See +[tensorflow/lite/micro](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro) +for details on how projects like this can be generated from the main source tree. + +## Usage + +See [Custom ARC EM/HS Platform](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/tools/make/targets/arc/README.md#Custom-ARC-EMHS-Platform) section for more detailed information on requirements and usage of this project. + +The Makefile contains all the information on building and running the project. One can modify it to satisfy specific needs. Next actions are available out of the box. You may need to adjust the following commands in order to use the appropriate make tool available in your environment, ie: `make` or `gmake` + +1. Build the application. + + make app + +2. Build the application passing additional flags to compiler. + + make app EXT_CFLAGS=[additional compiler flags] + +3. Build the application and stripout TFLM reference kernel fallback implementations in order to reduce code size. This only has an effect in case the project was generated with MLI support. See more info in [EmbARC MLI Library Based Optimizations](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/kernels/arc_mli/README.md). `false` is the default value. + + make app MLI_ONLY=[true|false] + +4. Delete all artifacts created during build. + + make clean + +5. Run the application with the nSIM simulator in console mode. + + make run + +6. Run the application with the nSIM simulator, but using the MetaWare Debugger GUI for further execution/debugging capabilities. + + make debug + + + +## License + +TensorFlow's code is covered by the Apache2 License included in the repository, and third party dependencies are covered by their respective licenses, in the third_party folder of this package. diff --git a/tensorflow/lite/micro/tools/make/templates/arc/README_ARC_EMSDP.md.tpl b/tensorflow/lite/micro/tools/make/templates/arc/README_ARC_EMSDP.md.tpl new file mode 100644 index 00000000000..9d2801ed6b7 --- /dev/null +++ b/tensorflow/lite/micro/tools/make/templates/arc/README_ARC_EMSDP.md.tpl @@ -0,0 +1,48 @@ +# TensorFlow Lite Micro ARC Make Project for EM SDP Board. + +This folder has been autogenerated by TensorFlow, and contains source, header, and project files needed to build a single TensorFlow Lite Micro target using make tool and and a Synopsys DesignWare ARC processor compatible toolchain, specifically the ARC MetaWare Development Toolkit (MWDT). + +This project has been generated for the ARC EM Software Development Platform (EM SDP). The built application can be run only on this platform. + +See +[tensorflow/lite/micro](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro) +for details on how projects like this can be generated from the main source tree. + +## Usage + +See [ARC EM Software Development Platform](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/tools/make/targets/arc/README.md#ARC-EM-Software-Development-Platform-ARC-EM-SDP) section for more detailed information on requirements and usage of this project. + +The Makefile contains all the information on building and running the project. One can modify it to satisfy specific needs. Next actions are available out of the box. You may need to adjust the following commands in order to use the appropriate make tool available in your environment, ie: `make` or `gmake`: + +1. Build the application. + + make app + +2. Build the application passing additional flags to compiler. + + make app EXT_CFLAGS=[additional compiler flags] + +3. Build the application and stripout TFLM reference kernel fallback implementations in order to reduce code size. This only has an effect in case the project was generated with MLI support. See more info in [EmbARC MLI Library Based Optimizations](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/kernels/arc_mli/README.md). `false` is the default value. + + make app MLI_ONLY=[true|false] + +4. Delete all artifacts created during build. + + make clean + +5. Run the application with the nSIM simulator in console mode. + + make run + +6. Load the application and open MetaWare Debugger GUI for further execution/debugging. + + make debug + +7. Generate necessary artefacts for self-booting execution from flash. See [reference to Run the application on the board from the micro SD card](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/micro/tools/make/targets/arc/README.md#Run-the-Application-on-the-Board-from-the-microSD-Card). + + make flash + + +## License + +TensorFlow's code is covered by the Apache2 License included in the repository, and third party dependencies are covered by their respective licenses, in the third_party folder of this package. diff --git a/tensorflow/lite/micro/tools/make/templates/arc/arc_app_makefile.tpl b/tensorflow/lite/micro/tools/make/templates/arc/arc_app_makefile.tpl new file mode 100644 index 00000000000..a1a3ab71028 --- /dev/null +++ b/tensorflow/lite/micro/tools/make/templates/arc/arc_app_makefile.tpl @@ -0,0 +1,114 @@ +#============================================================= +# OS-specific definitions +#============================================================= +COMMA=, +OPEN_PAREN=( +CLOSE_PAREN=) +BACKSLASH=\$(nullstring) +ifneq ($(ComSpec)$(COMSPEC),) + O_SYS=Windows + RM=del /F /Q + MKDIR=mkdir + CP=copy /Y + TYPE=type + PS=$(BACKSLASH) + Q= + coQ=\$(nullstring) + fix_platform_path = $(subst /,$(PS), $(1)) + DEV_NULL = nul +else + O_SYS=Unix + RM=rm -rf + MKDIR=mkdir -p + CP=cp + TYPE=cat + PS=/ + Q=$(BACKSLASH) + coQ= + fix_platform_path=$(1) + DEV_NULL=/dev/null +endif + +#============================================================= +# Toolchain definitions +#============================================================= +CC = %{CC}% +CXX = %{CXX}% +LD = %{LD}% + + +#============================================================= +# Applications settings +#============================================================= +OUT_NAME = %{EXECUTABLE}% + +DBG_ARGS ?= + +RUN_ARGS ?= + +EXT_CFLAGS ?= + +CXXFLAGS += %{CXX_FLAGS}% + +CCFLAGS += %{CC_FLAGS}% + +LDFLAGS += %{LINKER_FLAGS}% + +%{EXTRA_APP_SETTINGS}% + + +#============================================================= +# Files and directories +#============================================================= +SRCS := \ +%{SRCS}% + +OBJS := \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(SRCS))) + + +#============================================================= +# Common rules +#============================================================= +.PHONY: all app flash clean run debug + +%.o: %.cc + $(CXX) $(CXXFLAGS) $(EXT_CFLAGS) $(INCLUDES) -c $< -o $@ + +%.o: %.c + $(CC) $(CCFLAGS) $(EXT_CFLAGS) $(INCLUDES) -c $< -o $@ + +$(OUT_NAME): $(OBJS) + $(LD) $(CXXFLAGS) -o $@ -Ccrossref $(OBJS) $(LDFLAGS) + +%{EXTRA_APP_RULES}% + + +#================================================================= +# Global rules +#================================================================= +all: $(OUT_NAME) + +app: $(OUT_NAME) + +flash: %{BIN_DEPEND}% +%{BIN_RULE}% + +clean: + -@$(RM) $(call fix_platform_path,$(OBJS)) + -@$(RM) $(OUT_NAME) %{EXTRA_RM_TARGETS}% + +#================================================================= +# Execution rules +#================================================================= + +APP_RUN := %{APP_RUN_CMD}% +APP_DEBUG := %{APP_DEBUG_CMD}% + +run: $(OUT_NAME) + $(APP_RUN) $(OUT_NAME) $(RUN_ARGS) + +debug: $(OUT_NAME) + $(APP_DEBUG) $(OUT_NAME) $(RUN_ARGS) + +%{EXTRA_EXECUTE_RULES}% diff --git a/tensorflow/lite/micro/tools/make/third_party_downloads.inc b/tensorflow/lite/micro/tools/make/third_party_downloads.inc index 9251e4c161e..806501a004a 100644 --- a/tensorflow/lite/micro/tools/make/third_party_downloads.inc +++ b/tensorflow/lite/micro/tools/make/third_party_downloads.inc @@ -28,8 +28,8 @@ LEON_BCC2_MD5 := "cdf78082be4882da2a92c9baa82fe765" TSIM_URL := "https://www.gaisler.com/anonftp/tsim/tsim-eval-2.0.63.tar.gz" TSIM_MD5 := "afa0095d3ed989a949e1467f94e41d2f" -CMSIS_URL := "https://github.com/ARM-software/CMSIS_5/archive/8a4db53f69da06e97565fe2f2e8926d193a5759d.zip" -CMSIS_MD5 := "e9864fb71b65adc4f7d92a9dea6e1aab" +CMSIS_URL := "https://github.com/ARM-software/CMSIS_5/archive/1150e71e07c79b538efd842aba5b210a31827ae5.zip" +CMSIS_MD5 := "e05f4222ef58825193910b41a0871dcb" AM_SDK_URL := "http://s3.asia.ambiqmicro.com/downloads/AmbiqSuite-Rel2.2.0.zip" AM_SDK_MD5 := "7605fa2d4d97e6bb7a1190c92b66b597" @@ -71,11 +71,11 @@ PERSON_MODEL_MD5 := "fe2934bd0788f1dcc7af3f0a954542ab" PERSON_MODEL_INT8_URL := "https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data_int8_grayscale_2020_01_13.zip" PERSON_MODEL_INT8_MD5 := "8a7d2c70325f53136faea6dde517b8cc" -EMBARC_OSP_URL := "https://github.com/foss-for-synopsys-dwc-arc-processors/embarc_osp/archive/embarc_mli.zip" -EMBARC_OSP_MD5 := "9eaf7b3a1ed05872a03da9796672a776" +EMBARC_MLI_URL := "https://github.com/foss-for-synopsys-dwc-arc-processors/embarc_mli/archive/58284867ca52d1f43b25045e8601999d7359d986.zip" +EMBARC_MLI_MD5 := "2bf4982a327fdaa9d475803ce014d1ef" -EMBARC_MLI_URL := "https://github.com/foss-for-synopsys-dwc-arc-processors/embarc_mli/archive/6316034d421cbbb59756239908d7c9a99075a3bb.zip" -EMBARC_MLI_MD5 := "db0910cf0e07e43f74ae7a31de485d56" +EMBARC_MLI_PRE_COMPILED_URL := "https://github.com/foss-for-synopsys-dwc-arc-processors/embarc_mli/releases/download/Release_1.1_RC2/embARC_MLI_package.zip" +EMBARC_MLI_PRE_COMPILED_MD5 := "a95ff9e0370434484f14e7e4114327f6" XTENSA_HIFI4_URL :="https://github.com/foss-xtensa/nnlib-hifi4/raw/master/archive/xa_nnlib_04_07.zip" XTENSA_HIFI4_MD5 :="f234764928f9a42901df33a27e118c8b" diff --git a/tensorflow/lite/nnapi/NeuralNetworksTypes.h b/tensorflow/lite/nnapi/NeuralNetworksTypes.h index 851c1718e0a..6739838e4d1 100644 --- a/tensorflow/lite/nnapi/NeuralNetworksTypes.h +++ b/tensorflow/lite/nnapi/NeuralNetworksTypes.h @@ -136,6 +136,13 @@ enum { ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_LSTM = 92, ANEURALNETWORKS_UNIDIRECTIONAL_SEQUENCE_RNN = 93, ANEURALNETWORKS_RESIZE_NEAREST_NEIGHBOR = 94, + ANEURALNETWORKS_QUANTIZED_LSTM = 95, + ANEURALNETWORKS_IF = 96, + ANEURALNETWORKS_WHILE = 97, + ANEURALNETWORKS_ELU = 98, + ANEURALNETWORKS_HARD_SWISH = 99, + ANEURALNETWORKS_FILL = 100, + ANEURALNETWORKS_RANK = 101, }; /** @@ -208,6 +215,18 @@ enum { ANEURALNETWORKS_DEVICE_ACCELERATOR = 4, }; +/** + * Relative execution priority. + * + * Available since API level 30. + */ +enum { + ANEURALNETWORKS_PRIORITY_LOW = 90, + ANEURALNETWORKS_PRIORITY_MEDIUM = 100, + ANEURALNETWORKS_PRIORITY_HIGH = 110, + ANEURALNETWORKS_PRIORITY_DEFAULT = ANEURALNETWORKS_PRIORITY_MEDIUM, +}; + /** * ANeuralNetworksMemory is an opaque type that represents memory. * @@ -521,9 +540,21 @@ typedef int (*ANeuralNetworksCompilation_setCaching_fn)( ANeuralNetworksCompilation* compilation, const char* cacheDir, const uint8_t* token); +typedef int (*ANeuralNetworksCompilation_setTimeout_fn)( + ANeuralNetworksCompilation* compilation, uint64_t duration); + +typedef int (*ANeuralNetworksCompilation_setPriority_fn)( + ANeuralNetworksCompilation* compilation, int priority); + typedef int (*ANeuralNetworksExecution_compute_fn)( ANeuralNetworksExecution* execution); +typedef int (*ANeuralNetworksExecution_setTimeout_fn)( + ANeuralNetworksExecution* execution, uint64_t duration); + +typedef int (*ANeuralNetworksExecution_setLoopTimeout_fn)( + ANeuralNetworksExecution* execution, uint64_t duration); + typedef int (*ANeuralNetworksExecution_getOutputOperandRank_fn)( ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank); diff --git a/tensorflow/lite/nnapi/nnapi_implementation.cc b/tensorflow/lite/nnapi/nnapi_implementation.cc index 71a4de53e9a..ad5869fec04 100644 --- a/tensorflow/lite/nnapi/nnapi_implementation.cc +++ b/tensorflow/lite/nnapi/nnapi_implementation.cc @@ -45,19 +45,6 @@ int32_t GetAndroidSdkVersion() { } result = result * 10 + digit; } - // TODO(levp): remove once SDK gets updated to 29th level - // Upgrade SDK version for pre-release Q to be able to test functionality - // available from SDK level 29. - if (result == 28) { - char versionCodename[PROP_VALUE_MAX]; - const char* versionCodenameProp = "ro.build.version.codename"; - length = __system_property_get(versionCodenameProp, versionCodename); - if (length != 0) { - if (versionCodename[0] == 'Q') { - return 29; - } - } - } return result; } return 0; @@ -228,6 +215,17 @@ const NnApi LoadNnApi() { ANeuralNetworksModel_getExtensionOperationType); LOAD_FUNCTION_OPTIONAL(libneuralnetworks, ANeuralNetworksModel_setOperandExtensionData); + + // API 30 (NNAPI 1.3) methods. + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, + ANeuralNetworksCompilation_setTimeout); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, + ANeuralNetworksCompilation_setPriority); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, + ANeuralNetworksExecution_setTimeout); + LOAD_FUNCTION_OPTIONAL(libneuralnetworks, + ANeuralNetworksExecution_setLoopTimeout); + return nnapi; } diff --git a/tensorflow/lite/nnapi/nnapi_implementation.h b/tensorflow/lite/nnapi/nnapi_implementation.h index a27f5ba661a..abee0fbdef3 100644 --- a/tensorflow/lite/nnapi/nnapi_implementation.h +++ b/tensorflow/lite/nnapi/nnapi_implementation.h @@ -789,6 +789,76 @@ struct NnApi { ANeuralNetworksCompilation* compilation, const char* cacheDir, const uint8_t* token); + /** + * Set the maximum expected duration for compiling the model. + * + * If the device is not able to complete the compilation within the specified + * duration, the compilation may be aborted. The timeout duration begins at + * the call to {@link ANeuralNetworksCompilation_finish}. + * + * This timeout duration acts as a hint to drivers, and can be used to both + * free up compute resources within the driver and return control back to the + * application quicker than is possible without the hint. It enables drivers + * that are able to estimate how long a compilation will take to abort the + * compilation before it has even started if the driver believes the + * compilation cannot be completed within the timeout duration. Similarly, it + * enables drivers to abort an ongoing compilation if it is taking too long. + * However, this call does not guarantee that the compilation will complete or + * abort within the timeout duration. + * + * By default (i.e., unless ANeuralNetworksCompilation_setTimeout is called), + * the timeout duration for compiling the model is considered infinite. + * + * The {@link ANeuralNetworksCompilation} must have been created with + * {@link ANeuralNetworksCompilation_createForDevices} with numDevices = 1, + * otherwise this function will fail with ANEURALNETWORKS_BAD_DATA. If the + * device has a feature level reported by + * {@link ANeuralNetworksDevice_getFeatureLevel} that is lower than 30, then + * the timeout duration hint will be ignored. + * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * @param compilation The compilation to be modified. + * @param duration The maximum amount of time in nanoseconds that is expected + * to be spent finishing a compilation. If this duration is exceeded, the + * compilation may be aborted. If set to 0, the timeout duration is + * considered infinite. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 30. + */ + int (*ANeuralNetworksCompilation_setTimeout)( + ANeuralNetworksCompilation* compilation, uint64_t duration); + + /** + * Set the execution priority. + * + * Execution priorities are relative to other executions created by the same + * application (specifically same uid) for the same device. Specifically, + * priorities of executions from one application will not affect executions + * from another application. Similarly, priorities of executions on one device + * will not affect executions on another device. + * + * Higher priority executions may use more compute resources than lower + * priority executions, and may preempt or starve lower priority executions. + * + * See {@link ANeuralNetworksCompilation} for information on multithreaded + * usage. + * + * Available since API level 30. + * + * @param compilation The compilation to be modified. + * @param priority The relative priority of the execution compared to other + * executions created by the application. Must be one of + * ANEURALNETWORKS_PRIORITY_*. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + */ + int (*ANeuralNetworksCompilation_setPriority)( + ANeuralNetworksCompilation* compilation, int priority); + /** * Schedule synchronous evaluation of the execution. * @@ -813,6 +883,84 @@ struct NnApi { */ int (*ANeuralNetworksExecution_compute)(ANeuralNetworksExecution* execution); + /** + * Set the maximum expected duration of the specified execution. + * + * If the device is not able to complete the execution within the specified + * duration, the execution may be aborted. The timeout duration begins at a + * call to one of: + * - {@link ANeuralNetworksExecution_burstCompute} + * - {@link ANeuralNetworksExecution_compute} + * - {@link ANeuralNetworksExecution_startCompute} + * - {@link ANeuralNetworksExecution_startComputeWithDependencies} + * + * This timeout duration acts as a hint to drivers, and can be used to both + * free up compute resources within the driver and return control back to the + * application quicker than is possible without the hint. It enables drivers + * that are able to estimate how long an execution will take to abort the + * execution before it has even started if the driver believes the execution + * cannot be completed within the timeout duration. Similarly, it enables + * drivers to abort an ongoing execution if it is taking too long. However, + * this call does not guarantee that the execution will complete or abort + * within the timeout duration. + * + * By default (i.e., unless ANeuralNetworksExecution_setTimeout is called), + * the timeout duration for execution is considered infinite. + * + * The {@link ANeuralNetworksExecution} must have been created from an + * {@link ANeuralNetworksCompilation} which in turn was created from + * {@link ANeuralNetworksCompilation_createForDevices} with numDevices = 1, + * otherwise this function will fail with ANEURALNETWORKS_BAD_DATA. If the + * device has a feature level reported by + * {@link ANeuralNetworksDevice_getFeatureLevel} that is lower than 30, then + * the timeout duration hint will be ignored. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded + * usage. + * + * @param execution The execution to be modified. + * @param duration The maximum amount of time in nanoseconds that is expected + * to be spent executing a model. If this duration is exceeded, the execution + * may be aborted. If set to 0, the timeout duration is considered + * infinite. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * + * Available since API level 30. + */ + int (*ANeuralNetworksExecution_setTimeout)( + ANeuralNetworksExecution* execution, uint64_t duration); + + /** + * Set the maximum duration of WHILE loops in the specified execution. + * + * This is a fuzzy per-loop timeout intended to prevent infinite loops. + * + * If a WHILE loop condition model does not output false within the specified + * duration, the execution will be aborted. + * + * See {@link ANeuralNetworks_getDefaultLoopTimeout} and + * {@link ANeuralNetworks_getMaximumLoopTimeout} for the default + * and maximum timeout values. + * + * See {@link ANeuralNetworksExecution} for information on multithreaded + * usage. + * + * @param execution The execution to be modified. + * @param duration The maximum amount of time in nanoseconds that can be spent + * executing a WHILE loop. If the specified duration value exceeds the + * value produced by {@link ANeuralNetworks_getMaximumLoopTimeout}, it will be + * overridden by that value. + * + * @return ANEURALNETWORKS_NO_ERROR if successful. + * ANEURALNETWORKS_BAD_STATE if execution has started. + * ANEURALNETWORKS_UNEXPECTED_NULL if execution is NULL. + * + * Available since API level 30. + */ + int (*ANeuralNetworksExecution_setLoopTimeout)( + ANeuralNetworksExecution* execution, uint64_t duration); + /** * Get the dimensional information of the specified output operand of the * model of the diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 6b7a32f1bcc..a5fbb88132e 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -169,9 +169,10 @@ def toco_convert_protos(model_flags_str, RuntimeError: When conversion fails, an exception is raised with the error message embedded. """ - # TODO(aselle): When toco does not use fatal errors for failure, we can - # switch this on. - if not _toco_from_proto_bin: + # Historically, TOCO conversion failures would trigger a crash, so we would + # attempt to run the converter out-of-process. The MLIR conversion pipeline + # surfaces errors instead, and can be safely run in-process. + if enable_mlir_converter or not _toco_from_proto_bin: try: model_str = wrap_toco.wrapped_toco_convert(model_flags_str, toco_flags_str, input_data_str, diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py index ccbba9014c8..04863b12853 100644 --- a/tensorflow/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -27,20 +27,8 @@ import numpy as np # pylint: disable=g-import-not-at-top if not __file__.endswith('tflite_runtime/interpreter.py'): # This file is part of tensorflow package. - from tensorflow.python.util.lazy_loader import LazyLoader + from tensorflow.lite.python.interpreter_wrapper import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper from tensorflow.python.util.tf_export import tf_export as _tf_export - - # Lazy load since some of the performance benchmark skylark rules - # break dependencies. Must use double quotes to match code internal rewrite - # rule. - # pylint: disable=g-inconsistent-quotes - _interpreter_wrapper = LazyLoader( - "_interpreter_wrapper", globals(), - "tensorflow.lite.python.interpreter_wrapper." - '_pywrap_tensorflow_interpreter_wrapper') - # pylint: enable=g-inconsistent-quotes - - del LazyLoader else: # This file is part of tflite_runtime package. from tflite_runtime import _pywrap_tensorflow_interpreter_wrapper as _interpreter_wrapper diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 99be58f4376..ce59c56a1d0 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -386,13 +386,8 @@ class TFLiteConverterBase(object): return True return False - def _parse_saved_model_args(self, always_enable_saved_model_import=False): - """Parses SavedModel arguments from the given Keras/RNN SavedModel. - - Args: - always_enable_saved_model_import: Bool. When the value is true, it enables - MLIR saved model import path regardless of checking the conditions. - """ + def _parse_saved_model_args(self): + """Parses SavedModel arguments from the given Keras/RNN SavedModel.""" if not self.experimental_new_converter: self.saved_model_dir = None return @@ -405,17 +400,16 @@ class TFLiteConverterBase(object): # frozen graph def path. self.saved_model_dir = None return - if (not always_enable_saved_model_import and - not self._contains_function_with_implements_attr(saved_model_proto)): + if not self._contains_function_with_implements_attr(saved_model_proto): self.saved_model_dir = None - return - - if not self._saved_model_exported_names: - self._saved_model_exported_names = [] - self._saved_model_version = saved_model_proto.saved_model_schema_version - if self._saved_model_version not in [1, 2]: - raise ValueError("SavedModel file format({0}) is not supported".format( - self._saved_model_version)) + else: + if not self._saved_model_exported_names: + self._saved_model_exported_names = [] + self._saved_model_version = saved_model_proto.saved_model_schema_version + if self._saved_model_version not in [1, 2]: + raise ValueError( + "SavedModel file format({0}) is not supported".format( + self._saved_model_version)) class TFLiteConverterBaseV2(TFLiteConverterBase): @@ -548,7 +542,7 @@ class TFLiteSavedModelConverterV2(TFLiteConverterBaseV2): self._saved_model_tags = saved_model_tags self._saved_model_exported_names = saved_model_exported_names self._trackable_obj = trackable_obj - self._parse_saved_model_args(always_enable_saved_model_import=True) + self._parse_saved_model_args() def convert(self): """Converts a TensorFlow GraphDef based on instance variables. diff --git a/tensorflow/lite/python/tflite_convert.py b/tensorflow/lite/python/tflite_convert.py index d0dd7313df3..c7504a3a638 100644 --- a/tensorflow/lite/python/tflite_convert.py +++ b/tensorflow/lite/python/tflite_convert.py @@ -65,6 +65,8 @@ def _parse_inference_type(value, flag): return lite_constants.FLOAT if value == "QUANTIZED_UINT8": return lite_constants.QUANTIZED_UINT8 + if value == "INT8": + return lite_constants.INT8 raise ValueError("Unsupported value for --{0}. Only FLOAT and " "QUANTIZED_UINT8 are supported.".format(flag)) @@ -352,12 +354,12 @@ def _get_tf1_flags(parser): parser.add_argument( "--inference_type", type=str.upper, - choices=["FLOAT", "QUANTIZED_UINT8"], + choices=["FLOAT", "QUANTIZED_UINT8", "INT8"], help="Target data type of real-number arrays in the output file.") parser.add_argument( "--inference_input_type", type=str.upper, - choices=["FLOAT", "QUANTIZED_UINT8"], + choices=["FLOAT", "QUANTIZED_UINT8", "INT8"], help=("Target data type of real-number input arrays. Allows for a " "different type for input arrays in the case of quantization.")) diff --git a/tensorflow/lite/python/tflite_convert_test.py b/tensorflow/lite/python/tflite_convert_test.py index 1e80907edbd..d6a35ba9248 100644 --- a/tensorflow/lite/python/tflite_convert_test.py +++ b/tensorflow/lite/python/tflite_convert_test.py @@ -98,8 +98,8 @@ class TfLiteConvertV1Test(TestModels): sess.close() flags_str = ('--graph_def_file={0} --input_arrays={1} ' - '--output_arrays={2}'.format(graph_def_file, - 'Placeholder', 'add')) + '--output_arrays={2}'.format(graph_def_file, 'Placeholder', + 'add')) self._run(flags_str, should_succeed=True) os.remove(graph_def_file) @@ -137,8 +137,31 @@ class TfLiteConvertV1Test(TestModels): sess.close() flags_str = ('--graph_def_file={0} --input_arrays={1} ' - '--output_arrays={2}'.format(graph_def_file, - 'random', 'add')) + '--output_arrays={2}'.format(graph_def_file, 'random', 'add')) + self._run(flags_str, should_succeed=True) + os.remove(graph_def_file) + + def testQATFrozenGraphDefInt8(self): + with ops.Graph().as_default(): + in_tensor_1 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') + in_tensor_2 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') + _ = array_ops.fake_quant_with_min_max_args( + in_tensor_1 + in_tensor_2, min=0., max=1., name='output', + num_bits=16) # INT8 inference type works for 16 bits fake quant. + sess = session.Session() + + # Write graph to file. + graph_def_file = self._getFilepath('model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + sess.close() + + flags_str = ('--inference_type=INT8 --std_dev_values=128,128 ' + '--mean_values=128,128 ' + '--graph_def_file={0} --input_arrays={1},{2} ' + '--output_arrays={3}'.format(graph_def_file, 'inputA', + 'inputB', 'output')) self._run(flags_str, should_succeed=True) os.remove(graph_def_file) @@ -166,8 +189,8 @@ class TfLiteConvertV1Test(TestModels): def testKerasFileMLIR(self): keras_file = self._getKerasModelFile() - flags_str = ('--keras_model_file={} --experimental_new_converter' - .format(keras_file)) + flags_str = ( + '--keras_model_file={} --experimental_new_converter'.format(keras_file)) self._run(flags_str, should_succeed=True) os.remove(keras_file) @@ -299,8 +322,8 @@ class TfLiteConvertV2Test(TestModels): def testKerasFileMLIR(self): keras_file = self._getKerasModelFile() - flags_str = ('--keras_model_file={} --experimental_new_converter' - .format(keras_file)) + flags_str = ( + '--keras_model_file={} --experimental_new_converter'.format(keras_file)) self._run(flags_str, should_succeed=True) os.remove(keras_file) diff --git a/tensorflow/lite/python/util_test.py b/tensorflow/lite/python/util_test.py index 51a0c57260a..f3c287dd7fc 100644 --- a/tensorflow/lite/python/util_test.py +++ b/tensorflow/lite/python/util_test.py @@ -174,7 +174,6 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase): str(error.exception)) self.assertEqual([None, 3, 5], tensor.shape.as_list()) - @test_util.run_deprecated_v1 def testSetTensorShapeDimensionInvalid(self): # Tests set_tensor_shape where the shape passed in is incompatible. with ops.Graph().as_default(): diff --git a/tensorflow/lite/simple_memory_arena.cc b/tensorflow/lite/simple_memory_arena.cc index a4d6d19656b..4aa0a1eb2ef 100644 --- a/tensorflow/lite/simple_memory_arena.cc +++ b/tensorflow/lite/simple_memory_arena.cc @@ -136,6 +136,8 @@ TfLiteStatus SimpleMemoryArena::ResolveAlloc( char** output_ptr) { TF_LITE_ENSURE(context, committed_); TF_LITE_ENSURE(context, output_ptr != nullptr); + TF_LITE_ENSURE(context, + underlying_buffer_size_ >= (alloc.offset + alloc.size)); if (alloc.size == 0) { *output_ptr = nullptr; } else { diff --git a/tensorflow/lite/simple_memory_arena_test.cc b/tensorflow/lite/simple_memory_arena_test.cc index fe337562b0a..0196421cc9c 100644 --- a/tensorflow/lite/simple_memory_arena_test.cc +++ b/tensorflow/lite/simple_memory_arena_test.cc @@ -197,6 +197,9 @@ TEST_P(BufferAndPlanClearingTest, TestClearBufferAndClearPlan) { EXPECT_NE(resolved_ptr, nullptr); } +INSTANTIATE_TEST_SUITE_P(BufferAndPlanClearingTest, BufferAndPlanClearingTest, + ::testing::Values(true, false)); + } // namespace } // namespace tflite diff --git a/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py b/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py index 9236181f840..03a0004b2fc 100644 --- a/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py +++ b/tensorflow/lite/testing/model_coverage/model_coverage_lib_test.py @@ -31,6 +31,7 @@ from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -178,18 +179,21 @@ class EvaluateKerasModel(test.TestCase): os.close(fd) return keras_file + @test_util.run_v1_only('Keras test fails under v2, see b/157266669') def testFloat(self): model = self._getSingleInputKerasModel() keras_file = self._saveKerasModel(model) model_coverage.test_keras_model(keras_file) + @test_util.run_v1_only('Keras test fails under v2, see b/157266669') def testPostTrainingQuantize(self): model = self._getSingleInputKerasModel() keras_file = self._saveKerasModel(model) model_coverage.test_keras_model(keras_file, post_training_quantize=True) + @test_util.run_v1_only('Keras test fails under v2, see b/157266669') def testTargetOps(self): model = self._getSingleInputKerasModel() keras_file = self._saveKerasModel(model) diff --git a/tensorflow/lite/testing/op_tests/space_to_batch_nd.py b/tensorflow/lite/testing/op_tests/space_to_batch_nd.py index 81753539e8a..86b061c6885 100644 --- a/tensorflow/lite/testing/op_tests/space_to_batch_nd.py +++ b/tensorflow/lite/testing/op_tests/space_to_batch_nd.py @@ -105,6 +105,13 @@ def make_space_to_batch_nd_tests(options): values.append(np.array(parameters["paddings"])) return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + if options.use_experimental_converter: + # Remove unsupported dimension cases. Currently, kernel supports 3 and 4-D + # inputs. + test_parameters = [ + test_parameters[0], test_parameters[1], test_parameters[3] + ] + make_zip_of_tests( options, test_parameters, diff --git a/tensorflow/lite/testing/op_tests/transpose_conv.py b/tensorflow/lite/testing/op_tests/transpose_conv.py index 654856f0d88..09c1b5f4f14 100644 --- a/tensorflow/lite/testing/op_tests/transpose_conv.py +++ b/tensorflow/lite/testing/op_tests/transpose_conv.py @@ -38,6 +38,7 @@ def make_transpose_conv_tests(options): { "input_shape": [[1, 3, 4, 1], [1, 10, 10, 3], [3, 20, 20, 1]], "filter_size": [[1, 1], [1, 2], [3, 3]], + "has_bias": [False], "strides": [[1, 1, 1, 1], [1, 3, 3, 1]], "padding": ["SAME", "VALID"], "data_format": ["NHWC"], @@ -50,6 +51,7 @@ def make_transpose_conv_tests(options): { "input_shape": [[1, 3, 3, 1]], "filter_size": [[3, 3, 2, 1]], + "has_bias": [False], "strides": [[1, 1, 1, 1]], "padding": ["SAME"], "data_format": ["NHWC"], @@ -60,6 +62,7 @@ def make_transpose_conv_tests(options): { "input_shape": [[1, 3, 3, 1]], "filter_size": [[3, 3, 2, 1]], + "has_bias": [False], "strides": [[1, 2, 2, 1]], "padding": ["SAME"], "data_format": ["NHWC"], @@ -70,13 +73,25 @@ def make_transpose_conv_tests(options): { "input_shape": [[1, 4, 3, 1]], "filter_size": [[3, 3, 2, 1]], + "has_bias": [False], "strides": [[1, 2, 2, 1]], "padding": ["SAME"], "data_format": ["NHWC"], "channel_multiplier": [1], "output_shape": [[1, 8, 6, 2]], "fully_quantize": [True] - } + }, + { + "input_shape": [[1, 3, 3, 1]], + "filter_size": [[3, 3, 2, 1]], + "has_bias": [True], + "strides": [[1, 1, 1, 1]], + "padding": ["SAME"], + "data_format": ["NHWC"], + "channel_multiplier": [1], + "output_shape": [[1, 3, 3, 2]], + "fully_quantize": [True] + }, ] def get_tensor_shapes(parameters): @@ -124,6 +139,13 @@ def make_transpose_conv_tests(options): strides=parameters["strides"], padding=parameters["padding"], data_format=parameters["data_format"]) + if parameters["has_bias"]: + bias_input = create_tensor_data( + np.float32, (parameters["output_shape"][-1],), + min_value=-1, + max_value=1) + out = tf.nn.bias_add( + out, bias_input, data_format=parameters["data_format"]) return input_tensors, [out] diff --git a/tensorflow/lite/tflite_with_xnnpack_optional.cc b/tensorflow/lite/tflite_with_xnnpack_optional.cc new file mode 100644 index 00000000000..31d4ff50f28 --- /dev/null +++ b/tensorflow/lite/tflite_with_xnnpack_optional.cc @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/tflite_with_xnnpack_optional.h" + +#include "tensorflow/lite/core/macros.h" + +#ifdef TFLITE_BUILD_WITH_XNNPACK_DELEGATE +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#endif + +namespace tflite { + +using TfLiteDelegatePtr = + std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>; + +#ifndef TFLITE_BUILD_WITH_XNNPACK_DELEGATE +// Using weak symbols to create a delegate allows automatic injection of the +// delegate simply by adding it as a dependency. See the strong override in +// lite/tflite_with_xnnpack.cc, +TFLITE_ATTRIBUTE_WEAK TfLiteDelegatePtr +AcquireXNNPACKDelegate(int num_threads) { + return TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {}); +} +#endif + +#ifdef TFLITE_BUILD_WITH_XNNPACK_DELEGATE +TfLiteDelegatePtr MaybeCreateXNNPACKDelegate(int num_threads) { + auto opts = TfLiteXNNPackDelegateOptionsDefault(); + // Note that we don't want to use the thread pool for num_threads == 1. + opts.num_threads = num_threads > 1 ? num_threads : 0; + return TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&opts), + TfLiteXNNPackDelegateDelete); +} +#else +TfLiteDelegatePtr MaybeCreateXNNPACKDelegate(int num_threads) { + return AcquireXNNPACKDelegate(num_threads); +} +#endif + +} // namespace tflite diff --git a/tensorflow/lite/tflite_with_xnnpack_optional.h b/tensorflow/lite/tflite_with_xnnpack_optional.h new file mode 100644 index 00000000000..afbdbd17356 --- /dev/null +++ b/tensorflow/lite/tflite_with_xnnpack_optional.h @@ -0,0 +1,26 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_TFLITE_WITH_XNNPACK_OPTIONAL_H_ +#define TENSORFLOW_LITE_TFLITE_WITH_XNNPACK_OPTIONAL_H_ +#include <memory> + +#include "tensorflow/lite/c/common.h" + +namespace tflite { +std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)> +MaybeCreateXNNPACKDelegate(int num_threads); +} // namespace tflite + +#endif // TENSORFLOW_LITE_TFLITE_WITH_XNNPACK_OPTIONAL_H_ diff --git a/tensorflow/lite/toco/model_cmdline_flags.cc b/tensorflow/lite/toco/model_cmdline_flags.cc index 2434481272f..86a1cedd612 100644 --- a/tensorflow/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/lite/toco/model_cmdline_flags.cc @@ -204,7 +204,7 @@ void ReadModelFlagsFromCommandLineFlags( } #ifdef PLATFORM_GOOGLE - CHECK(!((base::SpecifiedOnCommandLine("batch") && + CHECK(!((base::WasPresentOnCommandLine("batch") && parsed_model_flags.variable_batch.specified()))) << "The --batch and --variable_batch flags are mutually exclusive."; #endif diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index a96c1c3ede3..6ae5c1dda18 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -17,7 +17,10 @@ py_binary( srcs = ["visualize.py"], python_version = "PY3", srcs_version = "PY2AND3", - deps = ["//tensorflow/lite/python:schema_py"], + deps = [ + "//tensorflow/lite/python:schema_py", + "//third_party/py/numpy", + ], ) py_test( diff --git a/tensorflow/lite/tools/delegates/gpu_delegate_provider.cc b/tensorflow/lite/tools/delegates/gpu_delegate_provider.cc index db1f32b2282..62805b2644b 100644 --- a/tensorflow/lite/tools/delegates/gpu_delegate_provider.cc +++ b/tensorflow/lite/tools/delegates/gpu_delegate_provider.cc @@ -154,8 +154,8 @@ TfLiteDelegatePtr GpuDelegateProvider::CreateTfLiteDelegate( delegate = TfLiteDelegatePtr(TFLGpuDelegateCreate(&gpu_opts), &TFLGpuDelegateDelete); #else - TFLITE_LOG(WARN) << "The GPU delegate compile options are only supported on" - "Android or iOS platforms."; + TFLITE_LOG(WARN) << "The GPU delegate compile options are only supported " + "on Android or iOS platforms."; delegate = evaluation::CreateGPUDelegate(); #endif diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile index 41f87fb033d..3635ac95167 100644 --- a/tensorflow/lite/tools/make/Makefile +++ b/tensorflow/lite/tools/make/Makefile @@ -339,11 +339,18 @@ $(BENCHMARK_LIB) : $(LIB_PATH) $(BENCHMARK_LIB_OBJS) benchmark_lib: $(BENCHMARK_LIB) +BENCHMARK_LINKOPTS := +ifeq ($(HOST_OS),osx) + BENCHMARK_LINKOPTS += $(LIBFLAGS) -Wl,-force_load $(BENCHMARK_LIB) $(LIBS) $(LDFLAGS) -framework CoreFoundation +else + BENCHMARK_LINKOPTS += $(LIBFLAGS) -Wl,--whole-archive $(BENCHMARK_LIB) -Wl,--no-whole-archive $(LDFLAGS) $(LIBS) +endif + $(BENCHMARK_BINARY) : $(BENCHMARK_MAIN_OBJ) $(BENCHMARK_LIB) @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) \ -o $(BENCHMARK_BINARY) $(BENCHMARK_MAIN_OBJ) \ - $(LIBFLAGS) -Wl,--whole-archive $(BENCHMARK_LIB) -Wl,--no-whole-archive $(LDFLAGS) $(LIBS) + $(LIBFLAGS) $(BENCHMARK_LINKOPTS) $(BENCHMARK_PERF_OPTIONS_BINARY) : $(BENCHMARK_PERF_OPTIONS_OBJ) $(BENCHMARK_LIB) @mkdir -p $(dir $@) diff --git a/tensorflow/lite/tools/make/targets/rpi_makefile.inc b/tensorflow/lite/tools/make/targets/rpi_makefile.inc index 2225848ae64..71046d08131 100644 --- a/tensorflow/lite/tools/make/targets/rpi_makefile.inc +++ b/tensorflow/lite/tools/make/targets/rpi_makefile.inc @@ -32,7 +32,7 @@ ifeq ($(TARGET),rpi) # TODO(petewarden) In the future, we'll want to use OpenBLAS as a faster # alternative to Eigen on non-NEON ARM hardware like armv6. ifeq ($(TARGET_ARCH), armv6) - TARGET_TOOLCHAIN_PREFIX := arm-linux-gnueabi- + TARGET_TOOLCHAIN_PREFIX := arm-linux-gnueabihf- CXXFLAGS += \ -march=armv6 \ -mfpu=vfp \ diff --git a/tensorflow/lite/tools/pip_package/README.md b/tensorflow/lite/tools/pip_package/README.md index dac8ce02ca1..8a2be59b980 100644 --- a/tensorflow/lite/tools/pip_package/README.md +++ b/tensorflow/lite/tools/pip_package/README.md @@ -49,6 +49,52 @@ BUILD_DEB=y to the make command (only for python3): make BASE_IMAGE=debian:buster PYTHON=python3 TENSORFLOW_TARGET=rpi BUILD_DEB=y docker-build ``` +## Alternative build with Bazel (experimental) + +There is another build steps to build a binary wheel which uses Bazel instead of +Makefile. You don't need to install additional dependencies. +This approach can leverage TF's ci_build.sh for ARM cross builds. + +### Native build for your workstation + +```sh +tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh +``` + +### Cross build for armhf Python 3.5 + +```sh +CI_DOCKER_EXTRA_PARAMS="-e CI_BUILD_PYTHON=python3 -e CROSSTOOL_PYTHON_INCLUDE_PATH=/usr/include/python3.5" \ + tensorflow/tools/ci_build/ci_build.sh PI-PYTHON3 \ + tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh armhf +``` + +### Cross build for armhf Python 3.7 + +```sh +CI_DOCKER_EXTRA_PARAMS="-e CI_BUILD_PYTHON=python3 -e CROSSTOOL_PYTHON_INCLUDE_PATH=/usr/include/python3.7" \ + tensorflow/tools/ci_build/ci_build.sh PI-PYTHON37 \ + tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh armhf +``` + +### Cross build for aarch64 Python 3.5 + +```sh + CI_DOCKER_EXTRA_PARAMS="-e CI_BUILD_PYTHON=python3 -e CROSSTOOL_PYTHON_INCLUDE_PATH=/usr/include/python3.5" \ + tensorflow/tools/ci_build/ci_build.sh PI-PYTHON3 \ + tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh aarch64 +``` + +### Cross build for aarch64 Python 3.7 + +```sh +CI_DOCKER_EXTRA_PARAMS="-e CI_BUILD_PYTHON=python3 -e CROSSTOOL_PYTHON_INCLUDE_PATH=/usr/include/python3.7" \ + tensorflow/tools/ci_build/ci_build.sh PI-PYTHON37 \ + tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh aarch64 +``` + +## Usage + Note, unlike tensorflow this will be installed to a tflite_runtime namespace. You can then use the Tensorflow Lite interpreter as. diff --git a/tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh b/tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh new file mode 100755 index 00000000000..69afb2f6b80 --- /dev/null +++ b/tensorflow/lite/tools/pip_package/build_pip_package_with_bazel.sh @@ -0,0 +1,126 @@ +#!/usr/bin/env bash +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -ex + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PYTHON="${PYTHON:-python3}" +VERSION_SUFFIX=${VERSION_SUFFIX:-} +export TENSORFLOW_DIR="${SCRIPT_DIR}/../../../.." +TENSORFLOW_LITE_DIR="${TENSORFLOW_DIR}/tensorflow/lite" +TENSORFLOW_VERSION=$(grep "_VERSION = " "${TENSORFLOW_DIR}/tensorflow/tools/pip_package/setup.py" | cut -d= -f2 | sed "s/[ '-]//g") +export PACKAGE_VERSION="${TENSORFLOW_VERSION}${VERSION_SUFFIX}" +BUILD_DIR="${SCRIPT_DIR}/gen/tflite_pip/${PYTHON}" +TENSORFLOW_TARGET=$1 + +# Build source tree. +rm -rf "${BUILD_DIR}" && mkdir -p "${BUILD_DIR}/tflite_runtime" +cp -r "${TENSORFLOW_LITE_DIR}/tools/pip_package/debian" \ + "${TENSORFLOW_LITE_DIR}/tools/pip_package/setup_with_bazel.py" \ + "${TENSORFLOW_LITE_DIR}/tools/pip_package/MANIFEST.in" \ + "${TENSORFLOW_LITE_DIR}/python/interpreter_wrapper" \ + "${BUILD_DIR}" +cp "${TENSORFLOW_LITE_DIR}/python/interpreter.py" \ + "${BUILD_DIR}/tflite_runtime" +echo "__version__ = '${PACKAGE_VERSION}'" >> "${BUILD_DIR}/tflite_runtime/__init__.py" +echo "__git_version__ = '$(git -C "${TENSORFLOW_DIR}" describe)'" >> "${BUILD_DIR}/tflite_runtime/__init__.py" + +# Build python interpreter_wrapper. +cd "${BUILD_DIR}" +case "${TENSORFLOW_TARGET}" in + rpi|armhf) + BAZEL_FLAGS="--config=elinux_armhf + --copt=-march=armv7-a --copt=-mfpu=neon-vfpv4 + --copt=-O3 --copt=-fno-tree-pre --copt=-fpermissive + --define=raspberry_pi_with_neon=true" + ;; + aarch64) + BAZEL_FLAGS="--config=elinux_aarch64 + --copt=-O3" + ;; + *) + ;; +esac + +# We need to pass down the environment variable with a possible alternate Python +# include path for Python 3.x builds to work. +export CROSSTOOL_PYTHON_INCLUDE_PATH + +bazel build -c opt -s --config=monolithic ${BAZEL_FLAGS} //tensorflow/lite/python/interpreter_wrapper:_pywrap_tensorflow_interpreter_wrapper +cp "${TENSORFLOW_DIR}/bazel-bin/tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.so" \ + "${BUILD_DIR}/tflite_runtime" + +# Build python wheel. +cd "${BUILD_DIR}" +case "${TENSORFLOW_TARGET}" in + rpi|armhf) + ${PYTHON} setup_with_bazel.py bdist --plat-name=linux-armv7l \ + bdist_wheel --plat-name=linux-armv7l + ;; + aarch64) + ${PYTHON} setup_with_bazel.py bdist --plat-name=linux-aarch64 \ + bdist_wheel --plat-name=linux-aarch64 + ;; + *) + if [[ -n "${TENSORFLOW_TARGET}" ]] && [[ -n "${TENSORFLOW_TARGET_ARCH}" ]]; then + ${PYTHON} setup_with_bazel.py bdist --plat-name=${TENSORFLOW_TARGET}-${TENSORFLOW_TARGET_ARCH} \ + bdist_wheel --plat-name=${TENSORFLOW_TARGET}-${TENSORFLOW_TARGET_ARCH} + else + ${PYTHON} setup_with_bazel.py bdist bdist_wheel + fi + ;; +esac + +echo "Output can be found here:" +find "${BUILD_DIR}" + +# Build debian package. +if [[ "${BUILD_DEB}" != "y" ]]; then + exit 0 +fi + +PYTHON_VERSION=$(${PYTHON} -c "import sys;print(sys.version_info.major)") +if [[ ${PYTHON_VERSION} != 3 ]]; then + echo "Debian package can only be generated for python3." >&2 + exit 1 +fi + +DEB_VERSION=$(dpkg-parsechangelog --show-field Version | cut -d- -f1) +if [[ "${DEB_VERSION}" != "${PACKAGE_VERSION}" ]]; then + cat << EOF > "${BUILD_DIR}/debian/changelog" +tflite-runtime (${PACKAGE_VERSION}-1) unstable; urgency=low + + * Bump version to ${PACKAGE_VERSION}. + + -- TensorFlow team <packages@tensorflow.org> $(date -R) + +$(<"${BUILD_DIR}/debian/changelog") +EOF +fi + +case "${TENSORFLOW_TARGET}" in + rpi|armhf) + dpkg-buildpackage -b -rfakeroot -us -uc -tc -d -a armhf + ;; + aarch64) + dpkg-buildpackage -b -rfakeroot -us -uc -tc -d -a arm64 + ;; + *) + dpkg-buildpackage -b -rfakeroot -us -uc -tc -d + ;; +esac + +cat "${BUILD_DIR}/debian/changelog" + diff --git a/tensorflow/lite/tools/pip_package/setup_with_bazel.py b/tensorflow/lite/tools/pip_package/setup_with_bazel.py new file mode 100644 index 00000000000..e3e9a35a62e --- /dev/null +++ b/tensorflow/lite/tools/pip_package/setup_with_bazel.py @@ -0,0 +1,70 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow Lite is for mobile and embedded devices. + +TensorFlow Lite is the official solution for running machine learning models on +mobile and embedded devices. It enables on-device machine learning inference +with low latency and a small binary size on Android, iOS, and other operating +systems. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from setuptools import find_packages +from setuptools import setup +PACKAGE_NAME = 'tflite_runtime' +PACKAGE_VERSION = os.environ['PACKAGE_VERSION'] +DOCLINES = __doc__.split('\n') + +setup( + name=PACKAGE_NAME.replace('_', '-'), + version=PACKAGE_VERSION, + description=DOCLINES[0], + long_description='\n'.join(DOCLINES[2:]), + url='https://www.tensorflow.org/lite/', + author='Google, LLC', + author_email='packages@tensorflow.org', + license='Apache 2.0', + include_package_data=True, + keywords='tflite tensorflow tensor machine learning', + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'Intended Audience :: Developers', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Mathematics', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + packages=find_packages(exclude=[]), + package_dir={'': '.'}, + package_data={'': ['*.so']}, + install_requires=[ + 'numpy >= 1.16.0', + 'pybind11 >= 2.4.3', + ]) diff --git a/tensorflow/lite/tools/visualize.py b/tensorflow/lite/tools/visualize.py index 1f89f9c5448..3d22d1bb05b 100644 --- a/tensorflow/lite/tools/visualize.py +++ b/tensorflow/lite/tools/visualize.py @@ -28,6 +28,7 @@ import json import os import re import sys +import numpy as np from tensorflow.lite.python import schema_py_generated as schema_fb @@ -377,23 +378,34 @@ def CamelCaseToSnakeCase(camel_case_input): return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() -def FlatbufferToDict(fb): - """Converts a hierarchy of FB objects into a nested dict.""" - if hasattr(fb, "__dict__"): +def FlatbufferToDict(fb, preserve_as_numpy): + """Converts a hierarchy of FB objects into a nested dict. + + We avoid transforming big parts of the flat buffer into python arrays. This + speeds conversion from ten minutes to a few seconds on big graphs. + + Args: + fb: a flat buffer structure. (i.e. ModelT) + preserve_as_numpy: true if all downstream np.arrays should be preserved. + false if all downstream np.array should become python arrays + Returns: + A dictionary representing the flatbuffer rather than a flatbuffer object. + """ + if isinstance(fb, int) or isinstance(fb, float) or isinstance(fb, str): + return fb + elif hasattr(fb, "__dict__"): result = {} for attribute_name in dir(fb): attribute = fb.__getattribute__(attribute_name) if not callable(attribute) and attribute_name[0] != "_": snake_name = CamelCaseToSnakeCase(attribute_name) - result[snake_name] = FlatbufferToDict(attribute) + preserve = True if attribute_name == "buffers" else preserve_as_numpy + result[snake_name] = FlatbufferToDict(attribute, preserve) return result - elif isinstance(fb, str): - return fb + elif isinstance(fb, np.ndarray): + return fb if preserve_as_numpy else fb.tolist() elif hasattr(fb, "__len__"): - result = [] - for entry in fb: - result.append(FlatbufferToDict(entry)) - return result + return [FlatbufferToDict(entry, preserve_as_numpy) for entry in fb] else: return fb @@ -401,7 +413,7 @@ def FlatbufferToDict(fb): def CreateDictFromFlatbuffer(buffer_data): model_obj = schema_fb.Model.GetRootAsModel(buffer_data, 0) model = schema_fb.ModelT.InitFromObj(model_obj) - return FlatbufferToDict(model) + return FlatbufferToDict(model, preserve_as_numpy=False) def CreateHtmlFile(tflite_input, html_output): diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a49e4b74def..13c58c74583 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3,7 +3,7 @@ # ":platform" - Low-level and platform-specific Python code. load("//tensorflow:tensorflow.bzl", "py_strict_library") -load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") @@ -26,6 +26,9 @@ load("//tensorflow:tensorflow.bzl", "tf_external_workspace_visible") # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper") +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_py_test") + # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_py_build_info_genrule") load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_lib_deps", "tf_proto_library", "tf_proto_library_py", "tf_protos_grappler") # @unused @@ -53,6 +56,7 @@ visibility = [ "//third_party/py/tf_slim:__subpackages__", # TODO(aselle): to pass open source test. "//bazel_pip/tensorflow/lite/toco/python:__pkg__", + "//third_party/py/tensorflow_docs:__subpackages__", ] package( @@ -134,7 +138,7 @@ py_library( ":_pywrap_utils", ":array_ops", ":audio_ops_gen", - ":bincount", + ":bincount_ops", ":bitwise_ops", ":boosted_trees_ops", ":check_ops", @@ -2071,6 +2075,7 @@ tf_py_test( srcs = ["framework/constant_op_test.py"], main = "framework/constant_op_test.py", python_version = "PY3", + tfrt_enabled = True, deps = [ ":constant_op", ], @@ -3472,23 +3477,24 @@ py_library( ) py_library( - name = "bincount", - srcs = ["ops/bincount.py"], + name = "bincount_ops", + srcs = ["ops/bincount_ops.py"], srcs_version = "PY2AND3", deps = [ ":count_ops_gen", ":framework", ":framework_for_generated_wrappers", + "//tensorflow/python/compat", ], ) tf_py_test( - name = "bincount_test", + name = "bincount_ops_test", size = "small", - srcs = ["ops/bincount_test.py"], + srcs = ["ops/bincount_ops_test.py"], python_version = "PY3", deps = [ - ":bincount", + ":bincount_ops", ":platform_test", ], ) @@ -4175,6 +4181,7 @@ py_library( ":random_ops", ":tensor_shape", ":tensor_util", + ":variables", "//third_party/py/numpy", ], ) diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 8939c9b3143..781ef33f744 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -85,7 +85,7 @@ from tensorflow.python import keras from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.layers import layers from tensorflow.python.module import module -from tensorflow.python.ops import bincount +from tensorflow.python.ops import bincount_ops from tensorflow.python.ops import bitwise_ops as bitwise from tensorflow.python.ops import gradient_checker_v2 from tensorflow.python.ops import image_ops as image diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD index ec780a7c0a1..9cf3bba8dd5 100644 --- a/tensorflow/python/autograph/converters/BUILD +++ b/tensorflow/python/autograph/converters/BUILD @@ -118,7 +118,13 @@ py_test( name = "control_flow_test", srcs = ["control_flow_test.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", + tags = [ + "no_oss_py2", + "no_pip", + "no_windows", + "nopip", + ], deps = [ ":converters", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/autograph/converters/conditional_expressions.py b/tensorflow/python/autograph/converters/conditional_expressions.py index 44ab6dee926..65fb6765fcf 100644 --- a/tensorflow/python/autograph/converters/conditional_expressions.py +++ b/tensorflow/python/autograph/converters/conditional_expressions.py @@ -18,7 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gast + from tensorflow.python.autograph.core import converter +from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import templates @@ -26,19 +29,20 @@ class ConditionalExpressionTransformer(converter.Base): """Converts conditional expressions to functional form.""" def visit_IfExp(self, node): - return templates.replace_as_expression( - '''ag__.if_stmt( + template = ''' + ag__.if_exp( test, lambda: true_expr, lambda: false_expr, - lambda: (), - lambda _: None, - ('<internal expr>',), - ()) - ''', + expr_repr) + ''' + expr_repr = parser.unparse(node.test, include_encoding_marker=False).strip() + return templates.replace_as_expression( + template, test=node.test, true_expr=node.body, - false_expr=node.orelse) + false_expr=node.orelse, + expr_repr=gast.Constant(expr_repr, kind=None)) def transform(node, ctx): diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py index a903c43bcfc..673781e47dd 100644 --- a/tensorflow/python/autograph/converters/control_flow.py +++ b/tensorflow/python/autograph/converters/control_flow.py @@ -23,7 +23,6 @@ import gast from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.lang import directives from tensorflow.python.autograph.pyct import anno -from tensorflow.python.autograph.pyct import ast_util from tensorflow.python.autograph.pyct import cfg from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import qual_names @@ -57,114 +56,16 @@ class ControlFlowTransformer(converter.Base): fn.scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) return self.generic_visit(node) - def _create_cond_branch(self, body_name, aliased_orig_names, - aliased_new_names, body, returns): - if len(returns) == 1: - template = """ - return retval - """ - return_stmt = templates.replace(template, retval=returns[0]) - else: - template = """ - return (retvals,) - """ - return_stmt = templates.replace(template, retvals=returns) - - if aliased_orig_names: - alias_declarations = [] - for new_name, old_name in zip(aliased_new_names, aliased_orig_names): - template = """ - try: - aliased_new_name = aliased_orig_name - except NameError: - aliased_new_name = ag__.Undefined(symbol_name) - """ - - alias_declarations.extend( - templates.replace( - template, - aliased_new_name=new_name, - aliased_orig_name=old_name, - symbol_name=gast.Constant(str(old_name), kind=None))) - - template = """ - def body_name(): - alias_declarations - body - return_stmt - """ - return templates.replace( - template, - alias_declarations=alias_declarations, - body_name=body_name, - body=body, - return_stmt=return_stmt) - else: - template = """ - def body_name(): - body - return_stmt - """ - return templates.replace( - template, body_name=body_name, body=body, return_stmt=return_stmt) - - def _create_cond_expr(self, results, test, body_name, orelse_name, - state_getter_name, state_setter_name, - basic_symbol_names, composite_symbol_names): - if results is not None: - template = """ - results = ag__.if_stmt(test, body_name, orelse_name, - state_getter_name, state_setter_name, - (basic_symbol_names,), - (composite_symbol_names,)) - """ - return templates.replace( - template, - test=test, - results=results, - body_name=body_name, - orelse_name=orelse_name, - state_getter_name=state_getter_name, - state_setter_name=state_setter_name, - basic_symbol_names=basic_symbol_names, - composite_symbol_names=composite_symbol_names) - else: - template = """ - ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name, - (basic_symbol_names,), (composite_symbol_names,)) - """ - return templates.replace( - template, - test=test, - body_name=body_name, - orelse_name=orelse_name, - getter_name=state_getter_name, - setter_name=state_setter_name, - basic_symbol_names=basic_symbol_names, - composite_symbol_names=composite_symbol_names) - - def _fmt_symbols(self, symbol_set): - if not symbol_set: - return 'no variables' - return ', '.join(map(str, symbol_set)) - - def _determine_aliased_symbols(self, scope, node_defined_in): - modified_live = scope.modified & node_defined_in - # Composite symbols are handled elsewhere, see _create_state_functions - return { - s for s in modified_live - if not s.is_composite() and s not in self.state[_Function].scope.globals - } - - def _create_nonlocal_declarations(self, loop_vars): + def _create_nonlocal_declarations(self, vars_): + vars_ = set(vars_) results = [] global_vars = self.state[_Function].scope.globals if global_vars: - results.append(gast.Global([str(v) for v in global_vars])) + results.append(gast.Global([str(v) for v in vars_])) nonlocal_vars = [ - v for v in loop_vars if not v.is_composite() and v not in global_vars] + v for v in vars_ if not v.is_composite() and v not in global_vars] if nonlocal_vars: results.append(gast.Nonlocal([str(v) for v in nonlocal_vars])) @@ -176,9 +77,9 @@ class ControlFlowTransformer(converter.Base): template = """ def getter_name(): return state_vars, - def setter_name(loop_vars): + def setter_name(vars_): nonlocal_declarations - state_vars, = loop_vars + state_vars, = vars_ """ return templates.replace( template, @@ -222,166 +123,34 @@ class ControlFlowTransformer(converter.Base): symbol_name=gast.Constant(s.ssf(), kind=None)) return assignments - def visit_If(self, node): - body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) - orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE) - defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) - live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) - - # Note: this information needs to be extracted before the body conversion - # that happens in the call to generic_visit below, because the conversion - # generates nodes that lack static analysis annotations. - need_alias_in_body = self._determine_aliased_symbols( - body_scope, defined_in) - need_alias_in_orelse = self._determine_aliased_symbols( - orelse_scope, defined_in) - - node = self.generic_visit(node) - - modified_in_cond = body_scope.modified | orelse_scope.modified - returned_from_cond = set() - composites = set() - for s in modified_in_cond: - if s in live_out and not s.is_composite(): - returned_from_cond.add(s) - if s.is_composite(): - # Special treatment for compound objects, always return them. - # This allows special handling within the if_stmt itself. - # For example, in TensorFlow we need to restore the state of composite - # symbols to ensure that only effects from the executed branch are seen. - composites.add(s) - - created_in_body = body_scope.modified & returned_from_cond - defined_in - created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in - - basic_created_in_body = tuple( - s for s in created_in_body if not s.is_composite()) - basic_created_in_orelse = tuple( - s for s in created_in_orelse if not s.is_composite()) - - # These variables are defined only in a single branch. This is fine in - # Python so we pass them through. Another backend, e.g. Tensorflow, may need - # to handle these cases specially or throw an Error. - possibly_undefined = (set(basic_created_in_body) ^ - set(basic_created_in_orelse)) - - # Alias the closure variables inside the conditional functions, to allow - # the functions access to the respective variables. - # We will alias variables independently for body and orelse scope, - # because different branches might write different variables. - aliased_body_orig_names = tuple(need_alias_in_body) - aliased_orelse_orig_names = tuple(need_alias_in_orelse) - aliased_body_new_names = tuple( - self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced) - for s in aliased_body_orig_names) - aliased_orelse_new_names = tuple( - self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced) - for s in aliased_orelse_orig_names) - - alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names)) - alias_orelse_map = dict( - zip(aliased_orelse_orig_names, aliased_orelse_new_names)) - - node_body = ast_util.rename_symbols(node.body, alias_body_map) - node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map) - - cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced) - body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced) - orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced) - all_referenced = body_scope.referenced | orelse_scope.referenced - state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced) - state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced) - - returned_from_cond = tuple(returned_from_cond) - composites = tuple(composites) - - if returned_from_cond: - if len(returned_from_cond) == 1: - cond_results = returned_from_cond[0] - else: - cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None) - - returned_from_body = tuple( - alias_body_map[s] if s in need_alias_in_body else s - for s in returned_from_cond) - returned_from_orelse = tuple( - alias_orelse_map[s] if s in need_alias_in_orelse else s - for s in returned_from_cond) - - else: - # When the cond would return no value, we leave the cond called without - # results. That in turn should trigger the side effect guards. The - # branch functions will return a dummy value that ensures cond - # actually has some return value as well. - cond_results = None - # TODO(mdan): Replace with None once side_effect_guards is retired. - returned_from_body = (templates.replace_as_expression( - 'ag__.match_staging_level(1, cond_var_name)', - cond_var_name=cond_var_name),) - returned_from_orelse = (templates.replace_as_expression( - 'ag__.match_staging_level(1, cond_var_name)', - cond_var_name=cond_var_name),) - - cond_assign = self.create_assignment(cond_var_name, node.test) - body_def = self._create_cond_branch( - body_name, - aliased_orig_names=aliased_body_orig_names, - aliased_new_names=aliased_body_new_names, - body=node_body, - returns=returned_from_body) - orelse_def = self._create_cond_branch( - orelse_name, - aliased_orig_names=aliased_orelse_orig_names, - aliased_new_names=aliased_orelse_new_names, - body=node_orelse, - returns=returned_from_orelse) - undefined_assigns = self._create_undefined_assigns(possibly_undefined) - composite_defs = self._create_state_functions( - composites, [], state_getter_name, state_setter_name) - - basic_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in returned_from_cond) - composite_symbol_names = tuple( - gast.Constant(str(symbol), kind=None) for symbol in composites) - - cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name, - orelse_name, state_getter_name, - state_setter_name, basic_symbol_names, - composite_symbol_names) - - if_ast = ( - undefined_assigns + composite_defs + body_def + orelse_def + - cond_assign + cond_expr) - return if_ast - - def _get_basic_loop_vars(self, modified, live_in, live_out): - # The loop variables corresponding to simple symbols (e.g. `x`). - basic_loop_vars = [] + def _get_block_basic_vars(self, modified, live_in, live_out): + nonlocals = self.state[_Function].scope.nonlocals + basic_scope_vars = [] for s in modified: if s.is_composite(): - # TODO(mdan): Raise an error when this happens for a TF loop. + # TODO(mdan): Raise an error when this happens for a TF scope. continue - # Variables not live into or out of the loop are considered local to the - # loop. - if s not in live_in and s not in live_out: - continue - basic_loop_vars.append(s) - return frozenset(basic_loop_vars) + # Variables not live into or out of the scope are considered local to the + # scope. + if s in live_in or s in live_out or s in nonlocals: + basic_scope_vars.append(s) + continue + return frozenset(basic_scope_vars) - def _get_composite_loop_vars(self, modified, live_in): - # The loop variables corresponding to composite symbols (e.g. `self.x`). - composite_loop_vars = [] + def _get_block_composite_vars(self, modified, live_in): + # The scope variables corresponding to composite symbols (e.g. `self.x`). + composite_scope_vars = [] for s in modified: if not s.is_composite(): continue - # Mutations made to objects created inside the loop will appear as writes + # Mutations made to objects created inside the scope will appear as writes # to composite symbols. Because these mutations appear as modifications # made to composite symbols, we check whether the composite's parent is - # actually live into the loop. + # actually live into the scope. # Example: # while cond: # x = Foo() - # x.foo = 2 * x.foo # x.foo is live into the loop, but x is not. + # x.foo = 2 * x.foo # x.foo is live into the scope, but x is not. # # Note that some parents might not be symbols - for example, in x['foo'], # 'foo' is a parent, but it's a literal, not a symbol. We don't check the @@ -390,40 +159,106 @@ class ControlFlowTransformer(converter.Base): sss for sss in s.support_set if sss.is_symbol()) if not all(sss in live_in for sss in support_set_symbols): continue - composite_loop_vars.append(s) - return frozenset(composite_loop_vars) + composite_scope_vars.append(s) + return frozenset(composite_scope_vars) - def _get_loop_vars(self, node, modified): - body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + def _get_block_vars(self, node, modified): + """Determines the variables affected inside a control flow statement.""" defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN) live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN) live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) - reserved_symbols = body_scope.referenced - basic_loop_vars = self._get_basic_loop_vars(modified, live_in, live_out) - composite_loop_vars = self._get_composite_loop_vars(modified, live_in) - loop_vars = tuple(basic_loop_vars | composite_loop_vars) + basic_scope_vars = self._get_block_basic_vars( + modified, + live_in, + live_out) + composite_scope_vars = self._get_block_composite_vars(modified, live_in) + scope_vars = tuple(basic_scope_vars | composite_scope_vars) - # Variable that are used or defined inside the loop, but not defined - # before entering the loop. Only simple variables must be defined. The + # Variables that are modified inside the scope, but not defined + # before entering it. Only simple variables must be defined. The # composite ones will be implicitly checked at runtime. - undefined_lives = basic_loop_vars - defined_in + # This covers loop variables as well as variables that + undefined = tuple(v for v in modified - defined_in if not v.is_composite()) - return loop_vars, reserved_symbols, undefined_lives + # Variables that are modified inside the scope, and depend on values outside + # it. + input_only = basic_scope_vars & live_in - live_out + + # Place the outputs first. + scope_vars = sorted(scope_vars, key=lambda v: v in input_only) + nouts = len(scope_vars) - len(input_only) + + return scope_vars, undefined, nouts + + def visit_If(self, node): + node = self.generic_visit(node) + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE) + + cond_vars, undefined, nouts = self._get_block_vars( + node, body_scope.modified | orelse_scope.modified) + + undefined_assigns = self._create_undefined_assigns(undefined) + + nonlocal_declarations = self._create_nonlocal_declarations(cond_vars) + + reserved = body_scope.referenced | orelse_scope.referenced + state_getter_name = self.ctx.namer.new_symbol('get_state', reserved) + state_setter_name = self.ctx.namer.new_symbol('set_state', reserved) + state_functions = self._create_state_functions( + cond_vars, nonlocal_declarations, state_getter_name, state_setter_name) + + orelse_body = node.orelse + if not orelse_body: + orelse_body = [gast.Pass()] + + template = """ + state_functions + def body_name(): + nonlocal_declarations + body + def orelse_name(): + nonlocal_declarations + orelse + undefined_assigns + ag__.if_stmt( + test, + body_name, + orelse_name, + state_getter_name, + state_setter_name, + (symbol_names,), + nouts) + """ + return templates.replace( + template, + body=node.body, + body_name=self.ctx.namer.new_symbol('if_body', reserved), + orelse=orelse_body, + orelse_name=self.ctx.namer.new_symbol('else_body', reserved), + nonlocal_declarations=nonlocal_declarations, + nouts=gast.Constant(nouts, kind=None), + state_functions=state_functions, + state_getter_name=state_getter_name, + state_setter_name=state_setter_name, + symbol_names=tuple(gast.Constant(str(s), kind=None) for s in cond_vars), + test=node.test, + undefined_assigns=undefined_assigns) def visit_While(self, node): node = self.generic_visit(node) body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) - loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars( - node, body_scope.modified) + loop_vars, undefined, _ = self._get_block_vars(node, body_scope.modified) - undefined_assigns = self._create_undefined_assigns(possibly_undefs) + undefined_assigns = self._create_undefined_assigns(undefined) nonlocal_declarations = self._create_nonlocal_declarations(loop_vars) - state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols) - state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols) + reserved = body_scope.referenced + state_getter_name = self.ctx.namer.new_symbol('get_state', reserved) + state_setter_name = self.ctx.namer.new_symbol('set_state', reserved) state_functions = self._create_state_functions( loop_vars, nonlocal_declarations, state_getter_name, state_setter_name) @@ -448,7 +283,7 @@ class ControlFlowTransformer(converter.Base): return templates.replace( template, body=node.body, - body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols), + body_name=self.ctx.namer.new_symbol('loop_body', reserved), nonlocal_declarations=nonlocal_declarations, opts=opts, state_functions=state_functions, @@ -456,7 +291,7 @@ class ControlFlowTransformer(converter.Base): state_setter_name=state_setter_name, symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars), test=node.test, - test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols), + test_name=self.ctx.namer.new_symbol('loop_test', reserved), undefined_assigns=undefined_assigns) def visit_For(self, node): @@ -464,15 +299,16 @@ class ControlFlowTransformer(converter.Base): body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE) - loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars( + loop_vars, undefined, _ = self._get_block_vars( node, body_scope.modified | iter_scope.modified) - undefined_assigns = self._create_undefined_assigns(possibly_undefs) + undefined_assigns = self._create_undefined_assigns(undefined) nonlocal_declarations = self._create_nonlocal_declarations(loop_vars) - state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols) - state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols) + reserved = body_scope.referenced | iter_scope.referenced + state_getter_name = self.ctx.namer.new_symbol('get_state', reserved) + state_setter_name = self.ctx.namer.new_symbol('set_state', reserved) state_functions = self._create_state_functions( loop_vars, nonlocal_declarations, state_getter_name, state_setter_name) @@ -484,7 +320,7 @@ class ControlFlowTransformer(converter.Base): if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST): extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST) extra_test_name = self.ctx.namer.new_symbol( - 'extra_test', reserved_symbols) + 'extra_test', reserved) template = """ def extra_test_name(): nonlocal_declarations @@ -502,7 +338,7 @@ class ControlFlowTransformer(converter.Base): # iterate_arg_name holds a single arg with the iterates, which may be a # tuple. - iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved_symbols) + iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved) template = """ iterates = iterate_arg_name """ @@ -529,7 +365,7 @@ class ControlFlowTransformer(converter.Base): return templates.replace( template, body=node.body, - body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols), + body_name=self.ctx.namer.new_symbol('loop_body', reserved), extra_test_function=extra_test_function, extra_test_name=extra_test_name, iterate_arg_name=iterate_arg_name, diff --git a/tensorflow/python/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py index 32e86400da6..935e2cec4b8 100644 --- a/tensorflow/python/autograph/converters/control_flow_test.py +++ b/tensorflow/python/autograph/converters/control_flow_test.py @@ -1,3 +1,4 @@ +# Lint as: python3 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -453,6 +454,17 @@ class IfStatementTest(ControlFlowTestBase): self.assertTransformedResult(test_fn, constant_op.constant(1), 5) self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) + def test_local_remains_local(self): + + def test_fn(n): + if n > 0: + b = 4 + n = b + 1 + return n + + self.assertTransformedResult(test_fn, constant_op.constant(1), 5) + self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) + def test_no_outputs(self): def test_fn(n): @@ -465,6 +477,85 @@ class IfStatementTest(ControlFlowTestBase): self.assertTransformedResult(test_fn, constant_op.constant(1), 1) self.assertTransformedResult(test_fn, constant_op.constant(-1), -1) + def test_created_outputs(self): + + def test_fn(i): + if i == 0: + result = i - 1 + else: + result = i + 1 + return result + + self.assertTransformedResult(test_fn, 0, -1) + self.assertTransformedResult(test_fn, 1, 2) + + def test_created_loop_local_outputs(self): + + def test_fn(n, x): + for i in n: + if i == 0: + result = i - 1 + else: + result = i + 1 + if result > 0: + x += 1 + return x + + self.assertTransformedResult(test_fn, (range(5), 10), 14) + + def test_created_loop_variable(self): + + def test_fn(n, x): + for i in n: + if i == 0: + result = i - 1 + if i > 0: # Using the result from previous iteration. + if result < 0: + x += 1 + return x + + self.assertTransformedResult(test_fn, (range(5), 10), 14) + + def test_unaffected_global(self): + + def test_fn(i): + global g # pylint:disable=global-variable-undefined + if i == 0: + g = i - 1 + return g + + self.assertTransformedResult(test_fn, 1, 3, symbols={'g': 3}) + self.assertTransformedResult(test_fn, 0, -1, symbols={'g': 3}) + + def test_unaffected_nonlocal(self): + + def test_fn(i): + def inner_fn(): + nonlocal n + if i == 0: + n = i - 1 + + n = 3 + inner_fn() + return n + + self.assertTransformedResult(test_fn, 1, 3) + self.assertTransformedResult(test_fn, 0, -1) + + def test_output_defined_in_prior_except(self): + + def test_fn(i): + try: + raise ValueError() + except ValueError: + x = 1 + if i == 0: + x = i - 1 + return x + + self.assertTransformedResult(test_fn, 1, 1) + self.assertTransformedResult(test_fn, 0, -1) + def test_unbalanced_multiple_composites(self): class Foo(object): diff --git a/tensorflow/python/autograph/converters/variables.py b/tensorflow/python/autograph/converters/variables.py index 3028a65a69b..9784f50ed56 100644 --- a/tensorflow/python/autograph/converters/variables.py +++ b/tensorflow/python/autograph/converters/variables.py @@ -60,6 +60,31 @@ class VariableAccessTransformer(converter.Base): node = templates.replace_as_expression('ag__.ld(var_)', var_=node) return node + def visit_Delete(self, node): + node = self.generic_visit(node) + + rewrite_targets = [] + for tgt in node.targets: + # Don't rewrite composites like `del a[0]`. + if isinstance(tgt, gast.Name): + rewrite_targets.append(tgt) + + if not rewrite_targets: + return node + + results = [] + for tgt in rewrite_targets: + template = """ + var_ = ag__.Undefined(var_name) + """ + results.extend(templates.replace( + template, var_=tgt, var_name=gast.Constant(tgt.id, kind=None))) + remaining_targets = [n for n in node.targets if n not in rewrite_targets] + if remaining_targets: + results.append(gast.Delete(targets=remaining_targets)) + + return results + def visit_AugAssign(self, node): if isinstance(node.target, gast.Name): template = """ diff --git a/tensorflow/python/autograph/converters/variables_test.py b/tensorflow/python/autograph/converters/variables_test.py index 556dafbaa8a..93a31e63de3 100644 --- a/tensorflow/python/autograph/converters/variables_test.py +++ b/tensorflow/python/autograph/converters/variables_test.py @@ -51,6 +51,90 @@ class VariablesTest(converter_testing.TestCase): with self.apply_add_one_conversion(test_fn) as result: self.assertEqual(result.test_fn(1), (1 + 1) * 10 + 1) # two reads + def test_del(self): + + def test_fn(l): + del l + return l + + with self.converted(test_fn, variables, {}) as result: + with self.assertRaisesRegex( + NameError, "'l' is used before assignment"): + result.test_fn(1) + + def test_del_getitem_ignored(self): + + def basic_slice(l): + del l[0] + return l + + with self.converted(basic_slice, variables, {}) as result: + self.assertListEqual([2], result.basic_slice([1, 2])) + + def range_slice(l): + del l[0:2] + return l + + with self.converted(range_slice, variables, {}) as result: + self.assertListEqual([], result.range_slice([1, 2])) + + def test_del_getattr_ignored(self): + + def test_fn(l): + del l.a + return l + + class TestClass(object): + + def __init__(self): + self.a = 1 + self.b = 2 + + with self.converted(test_fn, variables, {}) as result: + self.assertFalse(hasattr(result.test_fn(TestClass()), 'a')) + self.assertEqual(result.test_fn(TestClass()).b, 2) + + def test_del_packing_ignored(self): + # Note: test for UnboundLocalError, not NameError because in this case we + # don't rewrite the del. + + def list_(a, b): + del [a, b] + return a + + with self.converted(list_, variables, {}) as result: + with self.assertRaises(UnboundLocalError): + result.list_(1, 2) + + def nested(a, b, c): + del [a, (b, c)] + return c + + with self.converted(nested, variables, {}) as result: + with self.assertRaises(UnboundLocalError): + result.nested(1, 2, 3) + + def test_del_item_multiple_mixed(self): + + def test_fn_failing(a, b, c): + del a, b, c[0] + a = 1 + return a, b, c + + with self.converted(test_fn_failing, variables, {}) as result: + with self.assertRaisesRegex( + NameError, "'b' is used before assignment"): + result.test_fn_failing(1, 2, [1, 2]) + + def test_fn_passing(a, b, c): + del a, b, c[0] + a = 1 + b = 2 + return c + + with self.converted(test_fn_passing, variables, {}) as result: + self.assertListEqual([2], result.test_fn_passing(1, 2, [1, 2])) + def test_attribute(self): class TestClass(object): diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 3ebb5824b7f..98e19fdde86 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -18,13 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import copy import functools import inspect import os -import pdb -import re import sys import textwrap import traceback @@ -344,6 +340,15 @@ def _call_unconverted(f, args, kwargs, options, update_cache=True): return f(*args) +def _is_of_known_loaded_module(f, module_name): + mod = sys.modules.get(module_name, None) + if mod is None: + return False + if any(v is not None for v in mod.__dict__.values() if f is v): + return True + return False + + def _is_known_loaded_type(f, module_name, entity_name): """Tests whether the function or method is an instance of a known type.""" if (module_name not in sys.modules or @@ -511,7 +516,8 @@ def converted_call(f, # Other built-in modules are permanently whitelisted. # TODO(mdan): Figure out how to do this consistently for all stdlib modules. if any( - f in m.__dict__.values() for m in (collections, pdb, copy, inspect, re)): + _is_of_known_loaded_module(f, m) + for m in ('collections', 'pdb', 'copy', 'inspect', 're')): logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f) return _call_unconverted(f, args, kwargs, options) diff --git a/tensorflow/python/autograph/impl/api_py3_test.py b/tensorflow/python/autograph/impl/api_py3_test.py index df6544928bf..c460e478008 100644 --- a/tensorflow/python/autograph/impl/api_py3_test.py +++ b/tensorflow/python/autograph/impl/api_py3_test.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import os from tensorflow.python.autograph.core import converter @@ -60,6 +61,23 @@ class ApiTest(test.TestCase): self.assertEqual(5, tc.no_arg(2)) + def test_converted_call_avoids_triggering_operators(self): + + test_self = self + + class Pair(collections.namedtuple('Pair', ['a', 'b'])): + + def __call__(self): + return self.a + self.b + + def __eq__(self, other): + test_self.fail('Triggered operator') + + p = Pair(constant_op.constant(1), constant_op.constant(2)) + + x = api.converted_call(p, (), {}, options=DEFAULT_RECURSIVE) + self.assertIsNotNone(self.evaluate(x), 3) + if __name__ == '__main__': os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1' diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD index 3851c7b44ba..5f644ea525d 100644 --- a/tensorflow/python/autograph/operators/BUILD +++ b/tensorflow/python/autograph/operators/BUILD @@ -22,6 +22,7 @@ py_library( name = "operators", srcs = [ "__init__.py", + "conditional_expressions.py", "control_flow.py", "control_flow_deprecated_py2.py", "data_structures.py", @@ -62,6 +63,20 @@ py_test( ], ) +py_test( + name = "conditional_expressions_test", + srcs = ["conditional_expressions_test.py"], + python_version = "PY3", + srcs_version = "PY3", + tags = [ + "no_oss_py2", + ], + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "control_flow_test", srcs = ["control_flow_test.py"], diff --git a/tensorflow/python/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py index f7f9078107c..8ac4e1d8bb3 100644 --- a/tensorflow/python/autograph/operators/__init__.py +++ b/tensorflow/python/autograph/operators/__init__.py @@ -37,6 +37,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.autograph.operators.conditional_expressions import if_exp from tensorflow.python.autograph.operators.control_flow import for_stmt from tensorflow.python.autograph.operators.control_flow import if_stmt from tensorflow.python.autograph.operators.control_flow import while_stmt diff --git a/tensorflow/python/autograph/operators/conditional_expressions.py b/tensorflow/python/autograph/operators/conditional_expressions.py new file mode 100644 index 00000000000..7ea2b249935 --- /dev/null +++ b/tensorflow/python/autograph/operators/conditional_expressions.py @@ -0,0 +1,56 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Conditional expressions (e.g. the ternary if statement).""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.python.autograph.operators import control_flow +from tensorflow.python.autograph.utils import tensors +from tensorflow.python.ops import control_flow_ops + + +def if_exp(cond, if_true, if_false, expr_repr): + if tensors.is_dense_tensor(cond): + return _tf_if_exp(cond, if_true, if_false, expr_repr) + else: + return _py_if_exp(cond, if_true, if_false) + + +def _tf_if_exp(cond, if_true, if_false, expr_repr): + """Overload of if_exp that stages a TF cond.""" + # TODO(mdan): Use nonlocal once we no longer need to support py2. + true_val = [] + false_val = [] + + def true_fn(): + true_val.append(if_true()) + if true_val and false_val: + control_flow.verify_single_cond_var(expr_repr, true_val[0], false_val[0]) + return true_val[0] + + def false_fn(): + false_val.append(if_false()) + if true_val and false_val: + control_flow.verify_single_cond_var(expr_repr, true_val[0], false_val[0]) + return false_val[0] + + return control_flow_ops.cond(cond, true_fn, false_fn) + + +def _py_if_exp(cond, if_true, if_false): + return if_true() if cond else if_false() diff --git a/tensorflow/python/autograph/operators/conditional_expressions_test.py b/tensorflow/python/autograph/operators/conditional_expressions_test.py new file mode 100644 index 00000000000..3f126116023 --- /dev/null +++ b/tensorflow/python/autograph/operators/conditional_expressions_test.py @@ -0,0 +1,66 @@ +# Lint as: python3 +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for conditional_expressions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.autograph.operators import conditional_expressions +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +def _basic_expr(cond): + return conditional_expressions.if_exp( + cond, + lambda: constant_op.constant(1), + lambda: constant_op.constant(2), + 'cond') + + +@test_util.run_all_in_graph_and_eager_modes +class IfExpTest(test.TestCase): + + def test_tensor(self): + self.assertEqual(self.evaluate(_basic_expr(constant_op.constant(True))), 1) + self.assertEqual(self.evaluate(_basic_expr(constant_op.constant(False))), 2) + + def test_tensor_mismatched_type(self): + # tf.function required because eager cond degenerates to Python if. + @def_function.function + def test_fn(): + conditional_expressions.if_exp( + constant_op.constant(True), lambda: 1.0, lambda: 2, 'expr_repr') + + with self.assertRaisesRegexp( + TypeError, + "'expr_repr' has dtype float32 in the main.*int32 in the else"): + test_fn() + + def test_python(self): + self.assertEqual(self.evaluate(_basic_expr(True)), 1) + self.assertEqual(self.evaluate(_basic_expr(False)), 2) + self.assertEqual( + conditional_expressions.if_exp(True, lambda: 1, lambda: 2, ''), 1) + self.assertEqual( + conditional_expressions.if_exp(False, lambda: 1, lambda: 2, ''), 2) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py index 592281b0ce2..77db7579ece 100644 --- a/tensorflow/python/autograph/operators/control_flow.py +++ b/tensorflow/python/autograph/operators/control_flow.py @@ -102,7 +102,7 @@ def _verify_loop_init_vars(values, symbol_names): """Ensures that all values in the state are defined when entering a loop.""" for name, value in zip(symbol_names, values): if value is None: - raise ValueError('"{}" may not be None before the loop.'.format(name)) + raise ValueError("'{}' may not be None before the loop.".format(name)) if isinstance(value, variables.UndefinedReturnValue): # Assumption: the loop will only capture the variable which tracks the # return value if the loop contained a return statement. @@ -110,7 +110,7 @@ def _verify_loop_init_vars(values, symbol_names): raise ValueError( 'return statements are not supported within a TensorFlow loop.') if isinstance(value, variables.Undefined): - raise ValueError('"{}" must be defined before the loop.'.format(name)) + raise ValueError("'{}' must be defined before the loop.".format(name)) def _is_subshape(left, right): @@ -133,9 +133,9 @@ def _is_subshape(left, right): def _verify_single_loop_var( name, check_shape, init, entry, exit_, shape_invariant): """Verifies whether the initial, entry and exit values are consistent.""" - assert entry is not None, 'no TF op should set "{}" to None?'.format(name) + assert entry is not None, "no TF op should set '{}' to None?".format(name) if exit_ is None: - raise ValueError('"{}" is None at the end of the iteration.'.format(name)) + raise ValueError("'{}' is None at the end of the iteration.".format(name)) if isinstance(init, (bool, int, float, str, np.ndarray)): init = ops.convert_to_tensor_v2(init) @@ -158,9 +158,8 @@ def _verify_single_loop_var( if entry.dtype != exit_.dtype: raise TypeError( - '"{}" has dtype {} before the loop, but dtype {} after one' - ' iteration. TensorFlow control flow requires it stays the' - ' same.'.format( + "'{}' has dtype {} before the loop, but dtype {} after one" + ' iteration'.format( name, entry.dtype.name, exit_.dtype.name, @@ -171,19 +170,19 @@ def _verify_single_loop_var( entry_shape = entry.shape if not _is_subshape(exit_shape, entry_shape): raise ValueError( - '"{}" has shape {} before the loop, but shape {} after one' + "'{}' has shape {} before the loop, but shape {} after one" ' iteration. Use tf.autograph.experimental.set_loop_options to set' ' shape invariants.'.format(name, entry_shape, exit_shape)) else: init_shape = init.shape if not _is_subshape(init_shape, shape_invariant): raise ValueError( - '"{}" has shape {} before the loop, which does not conform with' + "'{}' has shape {} before the loop, which does not conform with" ' the shape invariant {}.'.format(name, init_shape, shape_invariant)) if not _is_subshape(exit_shape, shape_invariant): raise ValueError( - '"{}" has shape {} after one iteration, which does not conform with' + "'{}' has shape {} after one iteration, which does not conform with" ' the shape invariant {}.'.format( name, exit_shape, shape_invariant)) @@ -216,13 +215,13 @@ def _verify_tf_loop_vars(init_vars, nest.assert_same_structure(init, entry, expand_composites=True) nest.assert_same_structure(entry, exit_, expand_composites=True) except (ValueError, TypeError) as e: - raise TypeError('"{}" does not have the same nested structure after one' + raise TypeError("'{}' does not have the same nested structure after one" ' iteration.\n\n{}'.format(name, e)) if invariant is not None: try: nest.assert_same_structure(init, invariant, expand_composites=False) except (ValueError, TypeError) as e: - raise TypeError('"{}" does not have the same nested structure as its' + raise TypeError("'{}' does not have the same nested structure as its" ' corresponding shape invariant.\n\n{}'.format(name, e)) nest.map_structure( @@ -230,13 +229,13 @@ def _verify_tf_loop_vars(init_vars, entry, exit_, invariant) -def _verify_single_cond_var(name, body_var, orelse_var): +def verify_single_cond_var(name, body_var, orelse_var): """Verifies whether body_var and orelse_var are consistent.""" if body_var is None: - raise ValueError('"{}" is None at the end of the TRUE branch.'.format(name)) + raise ValueError("'{}' is None at the end of the main branch.".format(name)) if orelse_var is None: raise ValueError( - '"{}" is None at the end of the FALSE branch.'.format(name)) + "'{}' is None at the end of the else branch.".format(name)) if isinstance(body_var, (bool, int, float, str, np.ndarray)): body_var = ops.convert_to_tensor_v2(body_var) @@ -255,41 +254,37 @@ def _verify_single_cond_var(name, body_var, orelse_var): if body_var.dtype != orelse_var.dtype: raise TypeError( - '"{}" has dtype {} in the TRUE branch, but dtype={} in the FALSE' - ' branch. TensorFlow control flow requires that they are the' - ' same.'.format(name, body_var.dtype.name, - orelse_var.dtype.name)) + "'{}' has dtype {} in the main branch, but dtype {} in the else" + ' branch'.format(name, body_var.dtype.name, + orelse_var.dtype.name)) + + +def _verify_tf_cond_branch_vars(vars_, symbol_names, branch_name): + """Verifies variables output by a conditional branch for consistency.""" + for name, var_ in zip(symbol_names, vars_): + if isinstance(var_, variables.Undefined): + raise ValueError( + "'{}' must also be initialized in the {} branch".format( + name, branch_name)) + if isinstance(var_, variables.UndefinedReturnValue): + raise ValueError( + 'the {} branch must also have a return statement.'.format( + branch_name)) def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names): """Verifies variables manipulated by a conditional for consistency.""" - basic_body_vars, composite_body_vars = body_vars - basic_orelse_vars, composite_orelse_vars = orelse_vars - assert isinstance(composite_body_vars, tuple) - assert isinstance(composite_orelse_vars, tuple) - - # TODO(kkb): Make this more consistent. - # The basic outputs should always be a tuple. - if not isinstance(basic_body_vars, tuple): - basic_body_vars = (basic_body_vars,) - if not isinstance(basic_orelse_vars, tuple): - basic_orelse_vars = (basic_orelse_vars,) - - body_vars = basic_body_vars + composite_body_vars - orelse_vars = basic_orelse_vars + composite_orelse_vars - named_vars = zip(symbol_names, body_vars, orelse_vars) + for name, body_var, orelse_var in named_vars: try: - nest.assert_same_structure( - body_var, orelse_var, expand_composites=True) + nest.assert_same_structure(body_var, orelse_var, expand_composites=True) except (ValueError, TypeError) as e: raise TypeError( - '"{}" does not have the same nested structure in the TRUE and FALSE' - ' branches.\n\n{}'.format(name, str(e))) - + "'{}' must have the same nested structure in the main and else" + ' branches:\n\n{}'.format(name, str(e))) nest.map_structure( - functools.partial(_verify_single_cond_var, name), body_var, orelse_var) + functools.partial(verify_single_cond_var, name), body_var, orelse_var) def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts): @@ -314,12 +309,16 @@ def for_stmt(iter_, extra_test, body, get_state, set_state, symbol_names, opts): `extra_test`, `body`, `get_state` and `set_state` functions must bind to the original `geo_mean` and `arith_mean` symbols, using `nonlocal`. + The inputs and outputs of the callables representing the loop blocks are not + explicit - instead, these functions must use nonlocal/global for side effects. + The inputs and outputs are instead controlled by the set_state/get_state + functions. + Args: iter_: The entity being iterated over. - extra_test: Callable with the state as arguments, and boolean return type. + extra_test: Callable with boolean return type. An additional loop condition. - body: Callable with the iterate and the state as arguments, and state as - return type. The actual loop body. + body: Callable representing the actual loop body. get_state: Additional callable which can capture additional state (such as the values of composite symbols). This is only useful when staging the loop. @@ -717,11 +716,14 @@ def while_stmt(test, body, get_state, set_state, symbol_names, opts): a tuple of entities that represent an actual state, or a list of arguments of the corresponding types. + The inputs and outputs of the callables representing the loop blocks are not + explicit - instead, these functions must use nonlocal/global for side effects. + The inputs and outputs are instead controlled by the set_state/get_state + functions. + Args: - test: Callable with the state as arguments, and boolean return type. The - loop condition. - body: Callable with the state as arguments, and state as return type. The - actual loop body. + test: Callable with boolean return type. The loop condition. + body: Callable representing the actual loop body. get_state: Additional callable which can capture additional state (such as the values of composite symbols). This is only useful when staging the loop. @@ -894,21 +896,32 @@ def _tf_while_stmt(test, body, get_state, set_state, symbol_names, opts): set_state(final_loop_vars) -def if_stmt(cond, - body, - orelse, - get_state, - set_state, - basic_symbol_names, - composite_symbol_names): +def if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts): """Functional form of an if statement. + The conditional operates on a state, which includes all symbols whose values + are a function of the branch taken. + + For example, given the code below that calculates the abs function: + + ``` + x = 1 + if x > 0: + x = -x + ``` + + The state is represented by the variable `x`. The `body, `orelse` and + `set_state` functions must bind to the original `x` symbol, using `nonlocal`. + + The inputs and outputs of the callables representing the loop blocks are not + explicit - instead, these functions must use nonlocal/global for side effects. + The inputs and outputs are instead controlled by the set_state/get_state + functions. + Args: cond: Boolean. - body: Callable with no arguments, and outputs of the positive (if) branch as - return type. - orelse: Callable with no arguments, and outputs of the negative (else) - branch as return type. + body: Callable representing the main block of the conditional. + orelse: Callable representing the else block of the conditional. get_state: Function that returns a tuple containing the values of all composite symbols modified within the conditional. This allows access to state that branches may mutate through side effects. This function is not @@ -920,123 +933,63 @@ def if_stmt(cond, restore checkpointed values. The single argument a tuple containing values for each composite symbol that may be modified in a branch of the conditional. The is usually the result of a call to get_state. - basic_symbol_names: Tuple containing basic loop var names. - composite_symbol_names: Tuple containing composite loop var names. - - Returns: - Tuple containing the statement outputs. + symbol_names: Tuple containing basic loop var names. + nouts: Number of variables output by the statement. Vars which are + not outputs will not be passed through staged control flow such as + tf.cond. This includes variables that are defined before the conditional, + but are not used after it. """ # Note: tf.cond doesn't support SparseTensor. if tensors.is_dense_tensor(cond): - return tf_if_stmt(cond, body, orelse, get_state, set_state, - basic_symbol_names, composite_symbol_names) + _tf_if_stmt(cond, body, orelse, get_state, set_state, symbol_names, nouts) else: - return _py_if_stmt(cond, body, orelse) + _py_if_stmt(cond, body, orelse) -def tf_if_stmt(cond, body, orelse, get_state, set_state, basic_symbol_names, - composite_symbol_names): +def _tf_if_stmt( + cond, body, orelse, get_state, set_state, symbol_names, nouts): """Overload of if_stmt that stages a TF cond.""" - body = _wrap_disallow_undefs_from_cond(body, branch_name='if') - orelse = _wrap_disallow_undefs_from_cond(orelse, branch_name='else') - body = _isolate_state(body, get_state, set_state) - orelse = _isolate_state(orelse, get_state, set_state) + if not nouts: + prev_get_state, prev_set_state = get_state, set_state + # Control flow V1 wants at least one output. + get_state = lambda: (0,) + prev_get_state() + set_state = lambda v: prev_set_state(v[1:]) + symbol_names += ('<unused dummy>',) + nouts = 1 - # `state` currently includes the values of any composite symbols (e.g. `a.b`) - # composites modified by the loop. `final_vars` includes the values of basic - # symbols (e.g. `a`) which cannot be passed by reference and must be returned. - # See _isolate_state. - # TODO(mdan): We should minimize calls to get/set_state. + init_vars = get_state() - body_branch = 0 - orelse_branch = 1 - result = [None, None] + # TODO(mdan): Use nonlocal once we no longer need to support py2. + new_body_vars_ = [None] + new_orelse_vars_ = [None] - def error_checking_body(): - result[body_branch] = body() - if result[orelse_branch] is not None: - _verify_tf_cond_vars(result[body_branch], result[orelse_branch], - basic_symbol_names + composite_symbol_names) - return result[body_branch] + def aug_body(): + set_state(init_vars) + body() + new_body_vars = get_state() + new_body_vars = new_body_vars[:nouts] + new_body_vars_[0] = new_body_vars + _verify_tf_cond_branch_vars(new_body_vars, symbol_names, 'main') + if new_orelse_vars_[0] is not None: + _verify_tf_cond_vars(new_body_vars, new_orelse_vars_[0], symbol_names) + return new_body_vars - def error_checking_orelse(): - result[orelse_branch] = orelse() - if result[body_branch] is not None: - _verify_tf_cond_vars(result[body_branch], result[orelse_branch], - basic_symbol_names + composite_symbol_names) - return result[orelse_branch] + def aug_orelse(): + set_state(init_vars) + orelse() + new_orelse_vars = get_state() + new_orelse_vars = new_orelse_vars[:nouts] + new_orelse_vars_[0] = new_orelse_vars + _verify_tf_cond_branch_vars(new_orelse_vars, symbol_names, 'else') + if new_body_vars_[0] is not None: + _verify_tf_cond_vars(new_body_vars_[0], new_orelse_vars, symbol_names) + return new_orelse_vars - final_vars, final_state = control_flow_ops.cond(cond, error_checking_body, - error_checking_orelse) + final_cond_vars = control_flow_ops.cond( + cond, aug_body, aug_orelse, strict=True) + final_cond_vars = final_cond_vars + init_vars[nouts:] - set_state(final_state) - - return final_vars - - -def _isolate_state(func, get_state, set_state): - """Wraps func to (best-effort) isolate state mutations that func may do. - - The simplest example of state mutation is mutation of variables (via e.g. - attributes), or modification of globals. - - This allows us to more safely execute this function without worrying about - side effects when the function wasn't normally expected to execute. For - example, staging requires that the function is executed ahead of time, and - we need to ensure its effects are not observed during normal execution. - - Args: - func: () -> Any - get_state: () -> Any, returns the current state - set_state: (Any) -> None, resets the state to the specified values. - Typically the result of an earlier call to `get_state`. - - Returns: - Tuple[Any, Any], where the first element is the return value of `func`, - and the second is the final state values. - """ - - def wrapper(): - init_state = get_state() - new_vars = func() - # TODO(mdan): These should be copies, lest set_state might affect them. - new_state = get_state() - set_state(init_state) - return new_vars, new_state - - return wrapper - - -def _wrap_disallow_undefs_from_cond(func, branch_name): - """Wraps conditional branch to disallow returning undefined symbols.""" - - def wrapper(): - """Calls function and raises an error if undefined symbols are returned.""" - results = func() - - if isinstance(results, tuple): - results_tuple = results - else: - results_tuple = results, - - for result in results_tuple: - if isinstance(result, variables.UndefinedReturnValue): - raise ValueError( - 'A value must also be returned from the {} branch. If a value is ' - 'returned from one branch of a conditional a value must be ' - 'returned from all branches.'.format(branch_name)) - - undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)] - if undefined: - raise ValueError( - 'The following symbols must also be initialized in the {} branch: {}.' - ' Alternatively, you may initialize them before the if' - ' statement.'.format(branch_name, - tuple(s.symbol_name for s in undefined))) - - return results - - return wrapper + set_state(final_cond_vars) def _py_if_stmt(cond, body, orelse): diff --git a/tensorflow/python/autograph/operators/control_flow_test.py b/tensorflow/python/autograph/operators/control_flow_test.py index 1c4407904b2..57288be9a9f 100644 --- a/tensorflow/python/autograph/operators/control_flow_test.py +++ b/tensorflow/python/autograph/operators/control_flow_test.py @@ -543,21 +543,21 @@ class ForLoopTest(test.TestCase): return s def test_tensor_illegal_input(self): - with self.assertRaisesRegex(ValueError, '"s" may not be None'): + with self.assertRaisesRegex(ValueError, '\'s\' may not be None'): self._basic_loop(None, lambda i, s: s) - with self.assertRaisesRegex(ValueError, '"s" must be defined'): + with self.assertRaisesRegex(ValueError, '\'s\' must be defined'): self._basic_loop(variable_operators.Undefined(''), lambda i, s: s) def test_tensor_none_output(self): - with self.assertRaisesRegex(ValueError, '"s" is None at the end'): + with self.assertRaisesRegex(ValueError, '\'s\' is None at the end'): self._basic_loop(0, lambda i, s: None) def test_tensor_dtype_change(self): - with self.assertRaisesRegex(TypeError, '"s".* dtype float32 after'): + with self.assertRaisesRegex(TypeError, '\'s\'.* dtype float32 after'): self._basic_loop(0, lambda i, s: 1.0) def test_tensor_shape_change(self): - with self.assertRaisesRegex(ValueError, r'"s".* shape \(1,\) after'): + with self.assertRaisesRegex(ValueError, r'\'s\'.* shape \(1,\) after'): self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32)) @@ -782,21 +782,21 @@ class WhileLoopTest(test.TestCase): return s def test_tensor_illegal_input(self): - with self.assertRaisesRegex(ValueError, '"s" may not be None'): + with self.assertRaisesRegex(ValueError, "'s' may not be None"): self._basic_loop(None, lambda i, s: s) - with self.assertRaisesRegex(ValueError, '"s" must be defined'): + with self.assertRaisesRegex(ValueError, "'s' must be defined"): self._basic_loop(variable_operators.Undefined(''), lambda i, s: s) def test_tensor_none_output(self): - with self.assertRaisesRegex(ValueError, '"s" is None at the end'): + with self.assertRaisesRegex(ValueError, "'s' is None at the end"): self._basic_loop(0, lambda i, s: None) def test_tensor_dtype_change(self): - with self.assertRaisesRegex(TypeError, '"s".* dtype float32 after'): + with self.assertRaisesRegex(TypeError, "'s'.* dtype float32 after"): self._basic_loop(0, lambda i, s: 1.0) def test_tensor_shape_change(self): - with self.assertRaisesRegex(ValueError, r'"s".* shape \(1,\) after'): + with self.assertRaisesRegex(ValueError, r"'s'.* shape \(1,\) after"): self._basic_loop(0, lambda i, s: np.array([1], dtype=np.int32)) @@ -806,29 +806,88 @@ class IfStmtTest(test.TestCase): def test_tensor(self): def test_fn(cond): - return control_flow.if_stmt( + def body(): + nonlocal i + i = constant_op.constant(1) + + def orelse(): + nonlocal i + i = constant_op.constant(-1) + + def set_state(cond_vars): + nonlocal i + i, = cond_vars + + i = None + control_flow.if_stmt( cond=cond, - body=lambda: constant_op.constant(1), - orelse=lambda: constant_op.constant(-1), - get_state=lambda: (), - set_state=lambda _: None, - basic_symbol_names=('_',), - composite_symbol_names=()) + body=body, + orelse=orelse, + get_state=lambda: (i,), + set_state=set_state, + symbol_names=('i',), + nouts=1) + return i self.assertEqual(1, self.evaluate(test_fn(constant_op.constant(True)))) self.assertEqual(-1, self.evaluate(test_fn(constant_op.constant(False)))) + def test_tensor_no_outputs(self): + + def test_fn(cond): + def body(): + nonlocal i + i = constant_op.constant(1) + + def orelse(): + nonlocal i + i = constant_op.constant(-1.0) + + def set_state(cond_vars): + nonlocal i + i, = cond_vars + + i = None + control_flow.if_stmt( + cond=cond, + body=body, + orelse=orelse, + get_state=lambda: (i,), + set_state=set_state, + symbol_names=('i',), + nouts=0) + return i + + self.assertEqual(None, test_fn(constant_op.constant(True))) + self.assertEqual(None, test_fn(constant_op.constant(False))) + def test_tensor_multiple_returns(self): def test_fn(cond): - return control_flow.if_stmt( + def body(): + nonlocal i, j + i = constant_op.constant(1) + j = constant_op.constant(2) + + def orelse(): + nonlocal i, j + i = constant_op.constant(-1) + j = constant_op.constant(-2) + + def set_state(cond_vars): + nonlocal i, j + i, j = cond_vars + + i, j = None, None + control_flow.if_stmt( cond=cond, - body=lambda: (constant_op.constant(1), constant_op.constant(2)), - orelse=lambda: (constant_op.constant(-1), constant_op.constant(-2)), - get_state=lambda: (), - set_state=lambda _: None, - basic_symbol_names=('_',), - composite_symbol_names=()) + body=body, + orelse=orelse, + get_state=lambda: (i, j), + set_state=set_state, + symbol_names=('i', 'j'), + nouts=2) + return i, j self.assertEqual((1, 2), self.evaluate(test_fn(constant_op.constant(True)))) self.assertEqual((-1, -2), @@ -837,14 +896,24 @@ class IfStmtTest(test.TestCase): def test_python(self): def test_fn(cond): - return control_flow.if_stmt( + def body(): + nonlocal i + i = 1 + + def orelse(): + nonlocal i + i = -1 + + i = None + control_flow.if_stmt( cond=cond, - body=lambda: 1, - orelse=lambda: -1, - get_state=lambda: (), - set_state=lambda _: None, - basic_symbol_names=('_',), - composite_symbol_names=()) + body=body, + orelse=orelse, + get_state=None, + set_state=None, + symbol_names=('i',), + nouts=1) + return i self.assertEqual(1, test_fn(True)) self.assertEqual(-1, test_fn(False)) @@ -852,48 +921,75 @@ class IfStmtTest(test.TestCase): def test_python_multiple_returns(self): def test_fn(cond): - return control_flow.if_stmt( + def body(): + nonlocal i, j + i = 1 + j = 2 + + def orelse(): + nonlocal i, j + i = -1 + j = -2 + + i, j = None, None + control_flow.if_stmt( cond=cond, - body=lambda: (1, 2), - orelse=lambda: (-1, -2), - get_state=lambda: (), - set_state=lambda _: None, - basic_symbol_names=('_',), - composite_symbol_names=()) + body=body, + orelse=orelse, + get_state=None, + set_state=None, + symbol_names=('i', 'j'), + nouts=2) + return i, j self.assertEqual((1, 2), test_fn(True)) self.assertEqual((-1, -2), test_fn(False)) - def _basic_cond(self, true_value, false_value): + def _basic_cond(self, body_fn, else_fn): + def body(): + nonlocal x + x = body_fn() + + def orelse(): + nonlocal x + x = else_fn() + + def set_state(cond_vars): + nonlocal x + x, = cond_vars + + x = 0 # Eager cond had different semantics, we don't test those here. with func_graph.FuncGraph('tmp').as_default(): - return control_flow.if_stmt( + control_flow.if_stmt( cond=constant_op.constant(True), - body=true_value, - orelse=false_value, - get_state=lambda: (), - set_state=lambda _: None, - basic_symbol_names=('s',), - composite_symbol_names=()) + body=body, + orelse=orelse, + get_state=lambda: (x,), + set_state=set_state, + symbol_names=('x',), + nouts=1) + return x def test_tensor_none_output(self): with self.assertRaisesRegex( - ValueError, '"s" is None at the end of the TRUE branch'): + ValueError, "'x' is None at the end of the main branch"): self._basic_cond(lambda: None, lambda: 1) with self.assertRaisesRegex( - ValueError, '"s" is None at the end of the FALSE branch'): + ValueError, "'x' is None at the end of the else branch"): self._basic_cond(lambda: 1, lambda: None) def test_tensor_undefined_output(self): with self.assertRaisesRegex( - ValueError, "must also be initialized in the if.*'s'"): - self._basic_cond(lambda: variable_operators.Undefined('s'), lambda: 1) + ValueError, "'x' must also be initialized in the main branch"): + self._basic_cond(lambda: variable_operators.Undefined('x'), lambda: 1) with self.assertRaisesRegex( - ValueError, "must also be initialized in the else.*'s'"): + ValueError, "'x' must also be initialized in the else branch"): self._basic_cond(lambda: 1, lambda: variable_operators.Undefined('s')) def test_tensor_dtype_change(self): - with self.assertRaisesRegex(TypeError, '"s" has dtype int32.*but.*float32'): + with self.assertRaisesRegex( + TypeError, "'x' has dtype int32.*but.*float32"): self._basic_cond(lambda: 1, lambda: 1.0) diff --git a/tensorflow/python/autograph/pyct/qual_names.py b/tensorflow/python/autograph/pyct/qual_names.py index f97e595d1dc..d9491691567 100644 --- a/tensorflow/python/autograph/pyct/qual_names.py +++ b/tensorflow/python/autograph/pyct/qual_names.py @@ -41,21 +41,13 @@ class Symbol(collections.namedtuple('Symbol', ['name'])): """Represents a Python symbol.""" -class StringLiteral(collections.namedtuple('StringLiteral', ['value'])): - """Represents a Python string literal.""" - - def __str__(self): - return '\'%s\'' % self.value - - def __repr__(self): - return str(self) - - -class NumberLiteral(collections.namedtuple('NumberLiteral', ['value'])): +class Literal(collections.namedtuple('Literal', ['value'])): """Represents a Python numeric literal.""" def __str__(self): - return '%s' % self.value + if isinstance(self.value, str): + return "'{}'".format(self.value) + return str(self.value) def __repr__(self): return str(self) @@ -91,7 +83,7 @@ class QN(object): self._has_subscript = True else: - if not isinstance(base, (str, StringLiteral, NumberLiteral)): + if not isinstance(base, (str, Literal)): # TODO(mdan): Require Symbol instead of string. raise ValueError( 'for simple QNs, base must be a string or a Literal object;' @@ -169,12 +161,13 @@ class QN(object): self.has_attr() == other.has_attr()) def __str__(self): + root = self.qn[0] if self.has_subscript(): - return str(self.qn[0]) + '[' + str(self.qn[1]) + ']' + return '{}[{}]'.format(root, self.qn[1]) if self.has_attr(): return '.'.join(map(str, self.qn)) else: - return str(self.qn[0]) + return str(root) def __repr__(self): return str(self) @@ -207,13 +200,11 @@ class QN(object): if isinstance(base, str): return gast.Name( base, ctx=CallerMustSetThis, annotation=None, type_comment=None) - elif isinstance(base, StringLiteral): - return gast.Constant(base.value, kind=None) - elif isinstance(base, NumberLiteral): + elif isinstance(base, Literal): return gast.Constant(base.value, kind=None) else: assert False, ('the constructor should prevent types other than ' - 'str, StringLiteral and NumberLiteral') + 'str and Literal') class QnResolver(gast.NodeTransformer): @@ -243,7 +234,7 @@ class QnResolver(gast.NodeTransformer): # Continuing silently because some demos use these. return node if isinstance(s.value, gast.Constant): - subscript = QN(NumberLiteral(s.value.value)) + subscript = QN(Literal(s.value.value)) else: # The index may be an expression, case in which a name doesn't make sense. if anno.hasanno(node.slice.value, anno.Basic.QN): diff --git a/tensorflow/python/autograph/pyct/qual_names_test.py b/tensorflow/python/autograph/pyct/qual_names_test.py index ce17aecc024..6addb0a7179 100644 --- a/tensorflow/python/autograph/pyct/qual_names_test.py +++ b/tensorflow/python/autograph/pyct/qual_names_test.py @@ -75,9 +75,7 @@ class QNTest(test.TestCase): b_sub_c = QN(b, subscript=c) a_sub_b_sub_c = QN(a, subscript=b_sub_c) self.assertEqual(a_sub_b_sub_c.qn, (a, b_sub_c)) - self.assertTrue(a_sub_b.is_composite()) self.assertTrue(a_sub_b_sub_c.is_composite()) - self.assertTrue(a_sub_b.has_subscript()) self.assertTrue(a_sub_b_sub_c.has_subscript()) self.assertEqual(b_sub_c.qn, (b, c)) self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]') @@ -154,14 +152,17 @@ class QNTest(test.TestCase): def test_literals(self): a = QN('a') - a_sub_str_b = QN(a, subscript=QN(qual_names.StringLiteral('b'))) + a_sub_str_b = QN(a, subscript=QN(qual_names.Literal('b'))) a_sub_b = QN(a, subscript=QN('b')) self.assertNotEqual(a_sub_str_b, a_sub_b) self.assertNotEqual(hash(a_sub_str_b), hash(a_sub_b)) + self.assertEqual(a_sub_str_b.ast().slice.value.value, 'b') + self.assertEqual(str(a_sub_str_b), "a['b']") - a_sub_three = QN(a, subscript=QN(qual_names.NumberLiteral(3))) + a_sub_three = QN(a, subscript=QN(qual_names.Literal(3))) self.assertEqual(a_sub_three.ast().slice.value.value, 3) + self.assertEqual(str(a_sub_three), "a[3]") def test_support_set(self): a = QN('a') diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py index ca68bc9911c..0e19da87451 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/activity.py +++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py @@ -70,6 +70,9 @@ class Scope(object): globals: Set[qual_names.QN], names that are explicitly marked as global in this scope. Note that this doesn't include free read-only vars bound to global symbols. + nonlocals: Set[qual_names.QN], names that are explicitly marked as nonlocal + in this scope. Note that this doesn't include free read-only vars bound to + global symbols. free_vars: Set[qual_names.QN], the free variables in this scope. See https://docs.python.org/3/reference/executionmodel.html for a precise definition. @@ -111,6 +114,7 @@ class Scope(object): self.bound = set() self.globals = set() + self.nonlocals = set() self.annotations = set() self.params = weakref.WeakValueDictionary() @@ -186,6 +190,7 @@ class Scope(object): self.parent.modified.update(self.modified - self.isolated_names) self.parent.bound.update(self.bound - self.isolated_names) self.parent.globals.update(self.globals) + self.parent.nonlocals.update(self.nonlocals) self.parent.annotations.update(self.annotations) else: # TODO(mdan): This is not accurate. @@ -363,6 +368,7 @@ class ActivityAnalyzer(transformer.Base): qn = qual_names.QN(name) self.scope.read.add(qn) self.scope.bound.add(qn) + self.scope.nonlocals.add(qn) self._exit_and_record_scope(node) return node diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py index 64b00fcbeba..ac91b662a47 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py @@ -404,6 +404,46 @@ class ReachingDefinitionsAnalyzerTest(ReachingDefinitionsAnalyzerTestBase): self.assertHasDefinedIn(fn_body[1], ('a',)) + def test_definitions_in_except_block(self): + + def test_fn(): + try: + pass + except ValueError: + a = None + if a: # pylint:disable=using-constant-test + a = None + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body + + self.assertHasDefs(fn_body[1].test, 1) + self.assertHasDefs(fn_body[1].body[0].targets[0], 1) + self.assertHasDefs(fn_body[2].value, 2) + + self.assertHasDefinedIn(fn_body[1], ('a',)) + + def test_definitions_in_except_block_of_raising_try(self): + + def test_fn(): + try: + raise ValueError() + except ValueError: + a = None + if a: # pylint:disable=using-constant-test + a = None + return a + + node = self._parse_and_analyze(test_fn) + fn_body = node.body + + self.assertHasDefs(fn_body[1].test, 1) + self.assertHasDefs(fn_body[1].body[0].targets[0], 1) + self.assertHasDefs(fn_body[2].value, 2) + + self.assertHasDefinedIn(fn_body[1], ('a',)) + def test_global(self): def test_fn(): diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 1c244c1b297..074b50bf69b 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -34,6 +34,7 @@ from tensorflow.core.lib.core import error_codes_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as framework_device_lib @@ -1911,8 +1912,8 @@ class SessionTest(test_util.TensorFlowTestCase): def __str__(self): return self._output + context.set_log_device_placement(True) if context.executing_eagerly(): - context.set_log_device_placement(True) with CaptureStderr() as log: a = constant_op.constant(1) b = constant_op.constant(2) @@ -1939,6 +1940,22 @@ class SessionTest(test_util.TensorFlowTestCase): add_executions = [l for l in str(log).splitlines() if 'AddV2' in l] self.assertEqual(len(add_executions), 2) + @def_function.function + def fn(): + a = constant_op.constant(1) + b = constant_op.constant(2) + c = a + b + d = a + b + return c, d + + with CaptureStderr() as log: + c, d = self.evaluate(fn()) + self.assertEqual(c, 3) + self.assertEqual(d, 3) + # Ensure that we did log device placement. + add_executions = [l for l in str(log).splitlines() if 'AddV2' in l] + self.assertEqual(len(add_executions), 2) + @test_util.run_v1_only('b/120545219') def testLocalMasterSessionTimeout(self): # Test that the timeout passed in a config to the session works correctly. diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 2a21590bb9a..53545c58a2d 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 5, 14) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 5, 27) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index d5d6cb00733..1d5abb9871b 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow:tensorflow.bzl", "tf_py_test") -load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_test") # buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "cuda_py_test") # buildifier: disable=same-origin-load package( default_visibility = ["//tensorflow:internal"], @@ -87,6 +87,17 @@ tf_py_test( ], ) +tf_py_test( + name = "compression_ops_test", + srcs = ["compression_ops_test.py"], + deps = [ + "//tensorflow/python/data/experimental/ops:compression_ops", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", + ], +) + cuda_py_test( name = "copy_to_device_test", size = "small", diff --git a/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py new file mode 100644 index 00000000000..a091bdca8b9 --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/compression_ops_test.py @@ -0,0 +1,81 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for compression ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python.data.experimental.ops import compression_ops +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import structure +from tensorflow.python.framework import combinations +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.platform import test + + +def _test_objects(): + return [ + combinations.NamedObject("int", 1), + combinations.NamedObject("string", "dog"), + combinations.NamedObject("tuple", (1, 1)), + combinations.NamedObject("int_string_tuple", (1, "dog")), + combinations.NamedObject( + "sparse", + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])), + combinations.NamedObject( + "sparse_structured", { + "a": + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 2]], + values=[1, 2], + dense_shape=[3, 4]), + "b": (1, 2, "dog") + }) + ] + + +class CompressionOpsTest(test_base.DatasetTestBase, parameterized.TestCase): + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(element=_test_objects()))) + def testCompression(self, element): + element = element._obj + + compressed = compression_ops.compress(element) + uncompressed = compression_ops.uncompress( + compressed, structure.type_spec_from_value(element)) + self.assertValuesEqual(element, self.evaluate(uncompressed)) + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine(element=_test_objects()))) + def testDatasetCompression(self, element): + element = element._obj + + dataset = dataset_ops.Dataset.from_tensors(element) + element_spec = dataset.element_spec + + dataset = dataset.map(lambda *x: compression_ops.compress(x)) + dataset = dataset.map(lambda x: compression_ops.uncompress(x, element_spec)) + self.assertDatasetProduces(dataset, [element]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD index 50d095e46f6..2adf2a6362d 100644 --- a/tensorflow/python/data/experimental/ops/BUILD +++ b/tensorflow/python/data/experimental/ops/BUILD @@ -33,6 +33,15 @@ py_library( ], ) +py_library( + name = "compression_ops", + srcs = ["compression_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:experimental_dataset_ops_gen", + ], +) + py_library( name = "counter", srcs = ["counter.py"], @@ -475,6 +484,7 @@ py_library( deps = [ ":batching", ":cardinality", + ":compression_ops", ":counter", ":data_service_ops", ":distribute", diff --git a/tensorflow/python/data/experimental/ops/compression_ops.py b/tensorflow/python/data/experimental/ops/compression_ops.py new file mode 100644 index 00000000000..1ef7c8b3f01 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/compression_ops.py @@ -0,0 +1,55 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops for compressing and uncompressing dataset elements.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.util import structure +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops + + +def compress(element): + """Compress a dataset element. + + Args: + element: A nested structure of types supported by Tensorflow. + + Returns: + A variant tensor representing the compressed element. This variant can be + passed to `uncompress` to get back the original element. + """ + element_spec = structure.type_spec_from_value(element) + tensor_list = structure.to_tensor_list(element_spec, element) + return ged_ops.compress_element(tensor_list) + + +def uncompress(element, output_spec): + """Uncompress a compressed dataset element. + + Args: + element: A scalar variant tensor to uncompress. The element should have been + created by calling `compress`. + output_spec: A nested structure of `tf.TypeSpec` representing the type(s) of + the uncompressed element. + + Returns: + The uncompressed element. + """ + flat_types = structure.get_flat_tensor_types(output_spec) + flat_shapes = structure.get_flat_tensor_shapes(output_spec) + tensor_list = ged_ops.uncompress_element( + element, output_types=flat_types, output_shapes=flat_shapes) + return structure.from_tensor_list(output_spec, tensor_list) diff --git a/tensorflow/python/data/experimental/ops/data_service_ops.py b/tensorflow/python/data/experimental/ops/data_service_ops.py index 67dfadb4841..782f438c701 100644 --- a/tensorflow/python/data/experimental/ops/data_service_ops.py +++ b/tensorflow/python/data/experimental/ops/data_service_ops.py @@ -22,6 +22,7 @@ import functools import six from tensorflow.python import tf2 +from tensorflow.python.data.experimental.ops import compression_ops from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes @@ -84,6 +85,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource): if task_refresh_interval_hint_ms is None: task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE + self._input_dataset = input_dataset self._dataset_id = ops.convert_to_tensor( dataset_id, dtype=dtypes.int64, name="dataset_id") self._processing_mode = ops.convert_to_tensor( @@ -201,16 +203,31 @@ def _distribute(processing_mode, protocol = ops.convert_to_tensor( protocol, dtype=dtypes.string, name="protocol") - def _apply_fn(dataset): + def _apply_fn(dataset): # pylint: disable=missing-docstring external_state_policy = dataset.options().experimental_external_state_policy if external_state_policy is None: external_state_policy = ExternalStatePolicy.WARN + + uncompressed_spec = dataset.element_spec + # Compress the dataset elements to reduce the amount of data that needs to + # be sent over the network. + # TODO(b/157105111): Make this an autotuned parallel map when we have a way + # to limit memory usage. + dataset = dataset.map(lambda *x: compression_ops.compress(x)) + # Prefetch one compressed element to reduce latency when requesting data + # from tf.data workers. + # TODO(b/157105111): Set this to autotune when we have a way to limit + # memory usage + dataset = dataset.prefetch(1) + # Apply options so that the dataset executed in the tf.data service will + # be optimized and support autotuning. + dataset = dataset._apply_options() # pylint: disable=protected-access dataset_id = gen_experimental_dataset_ops.register_dataset( dataset._variant_tensor, # pylint: disable=protected-access address=address, protocol=protocol, external_state_policy=external_state_policy.value) - return _DataServiceDataset( + dataset = _DataServiceDataset( input_dataset=dataset, dataset_id=dataset_id, processing_mode=processing_mode, @@ -219,6 +236,11 @@ def _distribute(processing_mode, job_name=job_name, max_outstanding_requests=max_outstanding_requests, task_refresh_interval_hint_ms=task_refresh_interval_hint_ms) + # TODO(b/157105111): Make this an autotuned parallel map when we have a way + # to limit memory usage. + dataset = dataset.map( + lambda x: compression_ops.uncompress(x, output_spec=uncompressed_spec)) + return dataset return _apply_fn diff --git a/tensorflow/python/data/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/kernel_tests/data_service_ops_test.py index 217c586caef..726f0dc1530 100644 --- a/tensorflow/python/data/kernel_tests/data_service_ops_test.py +++ b/tensorflow/python/data/kernel_tests/data_service_ops_test.py @@ -37,12 +37,12 @@ from tensorflow.python.platform import test PROTOCOL = "grpc" -def _make_distributed_dataset(dataset, service, job_name=None): +def _make_distributed_dataset(dataset, address, job_name=None): """Creates a distributed dataset with a short task refresh interval.""" return dataset.apply( data_service_ops._distribute( "parallel_epochs", - service, + "{0}://{1}".format(PROTOCOL, address), job_name=job_name, task_refresh_interval_hint_ms=20)) @@ -56,34 +56,32 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): num_workers: The number of workers in the cluster. Returns: - A target for connecting to the service, e.g. - "grpc+local://localhost:2000". + The address of the master. """ - self._master = server_lib.MasterServer(PROTOCOL) - master_address = self._master.target[len(PROTOCOL + "://"):] - + self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL) self._servers = [] for _ in range(num_workers): self._servers.append( - server_lib.WorkerServer(PROTOCOL, master_address=master_address)) + server_lib.WorkerServer( + port=0, master_address=self._master._address, protocol=PROTOCOL)) - return self._master.target + return self._master._address @combinations.generate(test_base.eager_only_combinations()) def testDistributeBasic(self): num_elements = 10 - service = self.create_cluster(1) + master_address = self.create_cluster(1) ds = dataset_ops.Dataset.range(num_elements) - ds = _make_distributed_dataset(ds, service) + ds = _make_distributed_dataset(ds, master_address) results = [elem.numpy() for elem in ds] self.assertEqual(list(range(num_elements)), results) @combinations.generate(test_base.eager_only_combinations()) def testMultipleEpochs(self): num_elements = 3 - service = self.create_cluster(1) + master_address = self.create_cluster(1) ds = dataset_ops.Dataset.range(num_elements) - ds = _make_distributed_dataset(ds, service) + ds = _make_distributed_dataset(ds, master_address) for _ in range(10): self.assertEqual(list(range(num_elements)), [elem.numpy() for elem in ds]) @@ -91,9 +89,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): def testRepeatedDataset(self): num_elements = 10 num_repetitions = 5 - service = self.create_cluster(1) + master_address = self.create_cluster(1) ds = dataset_ops.Dataset.range(num_elements) - ds = _make_distributed_dataset(ds, service) + ds = _make_distributed_dataset(ds, master_address) ds = ds.repeat(num_repetitions) self.assertDatasetProduces( ds, expected_output=num_repetitions * list(range(num_elements))) @@ -102,12 +100,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): def testConcurrentEpoch(self): num_elements = 10 num_datasets = 3 - service = self.create_cluster(1) + master_address = self.create_cluster(1) iterators = [] results = [] for _ in range(num_datasets): ds = dataset_ops.Dataset.range(num_elements) - ds = _make_distributed_dataset(ds, service) + ds = _make_distributed_dataset(ds, master_address) iterators.append(iter(ds)) results.append([]) @@ -123,9 +121,9 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): self.skipTest("Not yet implemented") num_elements = 10 num_iterators = 3 - service = self.create_cluster(1) + master_address = self.create_cluster(1) ds = dataset_ops.Dataset.range(num_elements) - ds = _make_distributed_dataset(ds, service) + ds = _make_distributed_dataset(ds, master_address) result = [] iterators = [] for _ in range(num_iterators): @@ -147,21 +145,20 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): def testMultiWorker(self): num_workers = 3 num_elements = 10 - service = self.create_cluster(num_workers) + master_address = self.create_cluster(num_workers) ds = dataset_ops.Dataset.range(num_elements) - ds = _make_distributed_dataset(ds, service) + ds = _make_distributed_dataset(ds, master_address) results = [elem.numpy() for elem in ds] self.assertCountEqual(num_workers * list(range(num_elements)), results) @combinations.generate(test_base.eager_only_combinations()) def testAddWorkerMidJob(self): - self._master = server_lib.MasterServer(PROTOCOL) - master_address = self._master.target[len(PROTOCOL + "://"):] + self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL) self._worker = server_lib.WorkerServer( - PROTOCOL, master_address=master_address) + port=0, master_address=self._master._address, protocol=PROTOCOL) num_elements = 100 ds = dataset_ops.Dataset.range(num_elements) - ds = _make_distributed_dataset(ds, self._master.target) + ds = _make_distributed_dataset(ds, self._master._address) iterator = iter(ds) results = [] # Read halfway through the dataset. @@ -169,10 +166,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): results.append(next(iterator).numpy()) self._new_worker = server_lib.WorkerServer( - PROTOCOL, master_address=master_address) + port=0, master_address=self._master._address, protocol=PROTOCOL) # Wait for the new worker to register with the master. - while self._master.num_tasks() < 2: + while self._master._num_workers() < 2: time.sleep(10 / 1000) # 10ms for elem in iterator: @@ -184,13 +181,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): combinations.times(test_base.eager_only_combinations(), combinations.combine(use_same_port=[True, False]))) def testRestartWorker(self, use_same_port): - self._master = server_lib.MasterServer(PROTOCOL) - master_address = self._master.target[len(PROTOCOL + "://"):] + self._master = server_lib.MasterServer(port=0, protocol=PROTOCOL) self._worker = server_lib.WorkerServer( - PROTOCOL, master_address=master_address) + port=0, master_address=self._master._address, protocol=PROTOCOL) num_elements = 100 ds = dataset_ops.Dataset.range(num_elements) - ds = _make_distributed_dataset(ds, self._master.target) + ds = _make_distributed_dataset(ds, self._master._address) iterator = iter(ds) # Read halfway through the dataset. midpoint = num_elements // 2 @@ -200,11 +196,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): # Stop the original worker and start a new one. port = 0 if use_same_port: - worker_address = self._worker.target[len(PROTOCOL + "://"):] - port = int(worker_address.split(":")[1]) - self._worker.stop() + port = int(self._worker._address.split(":")[1]) + self._worker._stop() self._new_worker = server_lib.WorkerServer( - PROTOCOL, master_address=master_address, port=port) + port=port, master_address=self._master._address, protocol=PROTOCOL) # The dataset starts over now that we read from the new worker. for i in range(num_elements): @@ -219,12 +214,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): def testMaxOutstandingRequests(self): num_elements = 10 num_workers = 3 - service = self.create_cluster(num_workers) + address = self.create_cluster(num_workers) ds = dataset_ops.Dataset.range(num_elements) ds = ds.apply( data_service_ops._distribute( "parallel_epochs", - service, + "{0}://{1}".format(PROTOCOL, address), max_outstanding_requests=1, task_refresh_interval_hint_ms=20)) self.assertCountEqual(num_workers * list(range(num_elements)), @@ -234,12 +229,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): def testInsideFunction(self): num_workers = 3 num_elements = 10 - service = self.create_cluster(num_workers) + master_address = self.create_cluster(num_workers) @def_function.function def f(): ds = dataset_ops.Dataset.range(num_elements) - ds = _make_distributed_dataset(ds, service) + ds = _make_distributed_dataset(ds, master_address) result = tensor_array_ops.TensorArray( dtypes.int64, size=num_workers * num_elements, dynamic_size=True) i = 0 @@ -254,10 +249,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.eager_only_combinations()) def testSharedJobName(self): num_elements = 10 - service = self.create_cluster(1) + master_address = self.create_cluster(1) ds = dataset_ops.Dataset.range(num_elements) - ds1 = _make_distributed_dataset(ds, service, job_name="job_name") - ds2 = _make_distributed_dataset(ds, service, job_name="job_name") + ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name") + ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name") iter1 = iter(ds1) iter2 = iter(ds2) results = [] @@ -273,20 +268,20 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.eager_only_combinations()) def testDifferentJobNames(self): num_elements = 10 - service = self.create_cluster(1) + master_address = self.create_cluster(1) ds = dataset_ops.Dataset.range(num_elements) - ds1 = _make_distributed_dataset(ds, service, job_name="job_name1") - ds2 = _make_distributed_dataset(ds, service, job_name="job_name2") + ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name1") + ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name2") self.assertDatasetProduces(ds1, list(range(num_elements))) self.assertDatasetProduces(ds2, list(range(num_elements))) @combinations.generate(test_base.eager_only_combinations()) def testSharedJobNameMultiIteration(self): num_elements = 10 - service = self.create_cluster(1) + master_address = self.create_cluster(1) ds = dataset_ops.Dataset.range(num_elements) - ds1 = _make_distributed_dataset(ds, service, job_name="job_name") - ds2 = _make_distributed_dataset(ds, service, job_name="job_name") + ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name") + ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name") # iteration 1 self.assertDatasetProduces(ds1, list(range(num_elements))) self.assertDatasetProduces(ds2, []) @@ -298,11 +293,11 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): def testSharedJobNameRepeat(self): num_elements = 10 num_repetitions = 3 - service = self.create_cluster(1) + master_address = self.create_cluster(1) ds = dataset_ops.Dataset.range(num_elements) - ds1 = _make_distributed_dataset(ds, service, job_name="job_name") + ds1 = _make_distributed_dataset(ds, master_address, job_name="job_name") ds1 = ds1.repeat(num_repetitions) - ds2 = _make_distributed_dataset(ds, service, job_name="job_name") + ds2 = _make_distributed_dataset(ds, master_address, job_name="job_name") ds2 = ds2.repeat(num_repetitions) results = [] iter1 = iter(ds1) @@ -326,8 +321,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): options.experimental_external_state_policy = external_state_policy ds = ds.with_options(options) - service = self.create_cluster(3) - ds = _make_distributed_dataset(ds, service) + master_address = self.create_cluster(3) + ds = _make_distributed_dataset(ds, master_address) next(iter(ds)) @combinations.generate( @@ -347,12 +342,12 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.eager_only_combinations()) def testDistributeFromInterleave(self): - service = self.create_cluster(1) + master_address = self.create_cluster(1) ds = dataset_ops.Dataset.range(2) def interleave_fn(_): ds = dataset_ops.Dataset.range(2) - _make_distributed_dataset(ds, service) + _make_distributed_dataset(ds, master_address) return ds with self.assertRaisesRegex( diff --git a/tensorflow/python/data/kernel_tests/options_test.py b/tensorflow/python/data/kernel_tests/options_test.py index dea217367dc..27b5a336a6c 100644 --- a/tensorflow/python/data/kernel_tests/options_test.py +++ b/tensorflow/python/data/kernel_tests/options_test.py @@ -107,9 +107,6 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase): for _ in range(999): result = result.concatenate(ds) - options = dataset_ops.Options() - options.experimental_optimization.autotune = False - result = result.with_options(options) self.assertDatasetProduces(result, [0]*1000) diff --git a/tensorflow/python/data/service/BUILD b/tensorflow/python/data/service/BUILD index 19bcaa3b952..18678230205 100644 --- a/tensorflow/python/data/service/BUILD +++ b/tensorflow/python/data/service/BUILD @@ -1,4 +1,6 @@ load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_py_test") package( diff --git a/tensorflow/python/data/service/server_lib.py b/tensorflow/python/data/service/server_lib.py index b8f6e673f2e..df65508e6b2 100644 --- a/tensorflow/python/data/service/server_lib.py +++ b/tensorflow/python/data/service/server_lib.py @@ -24,93 +24,208 @@ from tensorflow.python.data.service import _pywrap_server_lib class MasterServer(object): - """An in-process tf.data service master, for use in testing.""" + """An in-process tf.data service master server. - def __init__(self, protocol): - """Creates and starts a new tf.data master server. + A `tf.data.experimental.service.MasterServer` coordinates a cluster of + `tf.data.experimental.service.WorkerServer`s. When the workers start, they + register themselves with the master. - The server will choose an available port. Use `target()` to get the string - for connecting to the server. + ``` + master_server = tf.data.experimental.service.MasterServer(port=5050) + worker_server = tf.data.experimental.service.WorkerServer( + port=0, master_address="localhost:5050") + dataset = tf.data.Dataset.range(10) + dataset = dataset.apply(tf.data.experimental.service.distribute( + processing_mode="parallel_epochs", service="grpc://localhost:5050")) + ``` + + When starting a dedicated tf.data master process, use join() to block + indefinitely after starting up the server. + + ``` + master_server = tf.data.experimental.service.MasterServer(port=5050) + master_server.join() + ``` + """ + + def __init__(self, port, protocol=None, start=True): + """Creates a new master server. Args: - protocol: A string representing the type of protocol to use when creating - channels. For no security, use "grpc". For local credentials, use - "grpc+local", and make sure your binary links in - `data/service:local_credentials`. + port: Specifies the port to bind to. + protocol: (Optional.) Specifies the protocol to be used by the server. + Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`. + start: (Optional.) Boolean, indicating whether to start the server after + creating it. Defaults to `True`. + + Raises: + tf.errors.OpError: Or one of its subclasses if an error occurs while + creating the TensorFlow server. """ + if protocol is None: + protocol = "grpc" self._protocol = protocol - self._server = _pywrap_server_lib.TF_DATA_NewMasterServer(0, protocol) - self._running = True + self._server = _pywrap_server_lib.TF_DATA_NewMasterServer(port, protocol) + if start: + self._server.start() - @property - def target(self): - """Returns the target for connecting to this server. + def start(self): + """Starts this server. - The returned string will be in the form protocol://address:port, e.g. - "grpc://localhost:1000". + Raises: + tf.errors.OpError: Or one of its subclasses if an error occurs while + starting the server. """ - port = _pywrap_server_lib.TF_DATA_MasterServerBoundPort(self._server) - return "{0}://localhost:{1}".format(self._protocol, port) + self._server.start() - def num_tasks(self): - """Returns the number of tasks on the master.""" - return _pywrap_server_lib.TF_DATA_MasterServerNumTasks(self._server) + def join(self): + """Blocks until the server has shut down. - def stop(self): - """Shuts down and deletes the server. + This is useful when starting a dedicated master process. - This method will block until all outstanding rpcs have completed and the - server has been shut down. + ``` + master_server = tf.data.experimental.service.MasterServer(port=5050) + master_server.join() + ``` + + Raises: + tf.errors.OpError: Or one of its subclasses if an error occurs while + joining the server. """ - if self._running: - self._running = False - _pywrap_server_lib.TF_DATA_DeleteMasterServer(self._server) + self._server.join() + + def _stop(self): + """Stops the server. + + Raises: + tf.errors.OpError: Or one of its subclasses if an error occurs while + stopping the server. + """ + self._server.stop() def __del__(self): - self.stop() + self._stop() + + @property + def _address(self): + """Returns the address of the server. + + The returned string will be in the form address:port, e.g. "localhost:1000". + """ + return "localhost:{0}".format(self._server.bound_port()) + + def _num_workers(self): + """Returns the number of workers registered with the master.""" + return self._server.num_workers() class WorkerServer(object): - """An in-process tf.data service worker, for use in testing.""" + """An in-process tf.data service worker server. - def __init__(self, protocol, master_address, port=0): - """Creates and starts a new tf.data worker server. + A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset` + processing for user-defined datasets, and provides the resulting elements over + RPC. A worker is associated with a single + `tf.data.experimental.service.MasterServer`. - The server will choose an available port. Use `target()` to get the string - for connecting to the server. + ``` + master_server = tf.data.experimental.service.MasterServer(port=5050) + worker_server = tf.data.experimental.service.WorkerServer( + port=0, master_address="localhost:5050") + dataset = tf.data.Dataset.range(10) + dataset = dataset.apply(tf.data.experimental.service.distribute( + processing_mode="parallel_epochs", service="grpc://localhost:5050")) + ``` + + When starting a dedicated tf.data worker process, use join() to block + indefinitely after starting up the server. + + ``` + worker_server = tf.data.experimental.service.WorkerServer( + port=5050, master_address="grpc://localhost:5050") + worker_server.join() + ``` + """ + + def __init__(self, + port, + master_address, + worker_address=None, + protocol=None, + start=True): + """Creates a new worker server. Args: - protocol: A string representing the type of protocol to use when creating - channels. For no security, use "grpc". For local credentials, use - "grpc+local", and make sure your binary links in - `data/service:local_credentials`. - master_address: The address of the tf.data master server to register with. - port: The port to bind to. + port: Specifies the port to bind to. A value of 0 indicates that the + worker can bind to any available port. + master_address: Specifies the address of the master server. + worker_address: (Optional.) Specifies the address of the worker server. + This address is passed to the master server so that the master can tell + clients how to connect to this worker. Defaults to `"localhost:%port%"`, + where `%port%` will be replaced with the port used by the worker. + protocol: (Optional.) Specifies the protocol to be used by the server. + Acceptable values include `"grpc", "grpc+local"`. Defaults to `"grpc"`. + start: (Optional.) Boolean, indicating whether to start the server after + creating it. Defaults to `True`. + + Raises: + tf.errors.OpError: Or one of its subclasses if an error occurs while + creating the TensorFlow server. """ + if worker_address is None: + worker_address = "localhost:%port%" + if protocol is None: + protocol = "grpc" + self._protocol = protocol self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer( - port, protocol, master_address, "localhost:%port%") - self._running = True + port, protocol, master_address, worker_address) + if start: + self._server.start() - @property - def target(self): - """Returns the target for connecting to this server. + def start(self): + """Starts this server. - The returned string will be in the form protocol://address:port, e.g. - "grpc://localhost:1000". + Raises: + tf.errors.OpError: Or one of its subclasses if an error occurs while + starting the server. """ - port = _pywrap_server_lib.TF_DATA_WorkerServerBoundPort(self._server) - return "{0}://localhost:{1}".format(self._protocol, port) + self._server.start() - def stop(self): - """Shuts down and deletes the server. + def join(self): + """Blocks until the server has shut down. - This method will block until all outstanding rpcs have completed and the - server has been shut down. + This is useful when starting a dedicated worker process. + + ``` + worker_server = tf.data.experimental.service.WorkerServer( + port=5050, master_address="grpc://localhost:5050") + worker_server.join() + ``` + + This method currently blocks forever. + + Raises: + tf.errors.OpError: Or one of its subclasses if an error occurs while + joining the server. """ - if self._running: - self._running = False - _pywrap_server_lib.TF_DATA_DeleteWorkerServer(self._server) + self._server.join() + + def _stop(self): + """Stops the server. + + Raises: + tf.errors.OpError: Or one of its subclasses if an error occurs while + stopping the server. + """ + self._server.stop() def __del__(self): - self.stop() + self._stop() + + @property + def _address(self): + """Returns the address of the server. + + The returned string will be in the form address:port, e.g. "localhost:1000". + """ + return "localhost:{0}".format(self._server.bound_port()) diff --git a/tensorflow/python/data/service/server_lib_test.py b/tensorflow/python/data/service/server_lib_test.py index b18262bf52b..59bb731d98e 100644 --- a/tensorflow/python/data/service/server_lib_test.py +++ b/tensorflow/python/data/service/server_lib_test.py @@ -22,20 +22,71 @@ from tensorflow.python.data.service import server_lib from tensorflow.python.platform import test -PROTOCOL = "grpc" - class ServerLibTest(test.TestCase): def testStartMaster(self): - master = server_lib.MasterServer(PROTOCOL) - self.assertRegex(master.target, PROTOCOL + "://.*:.*") + master = server_lib.MasterServer(0, start=False) + master.start() + + def testMultipleStartMaster(self): + master = server_lib.MasterServer(0, start=True) + master.start() def testStartWorker(self): - master = server_lib.MasterServer(PROTOCOL) - worker = server_lib.WorkerServer(PROTOCOL, - master.target[len(PROTOCOL + "://"):]) - self.assertRegex(worker.target, PROTOCOL + "://.*:.*") + master = server_lib.MasterServer(0) + worker = server_lib.WorkerServer(0, master._address, start=False) + worker.start() + + def testMultipleStartWorker(self): + master = server_lib.MasterServer(0) + worker = server_lib.WorkerServer(0, master._address, start=True) + worker.start() + + def testStopMaster(self): + master = server_lib.MasterServer(0) + master._stop() + master._stop() + + def testStopWorker(self): + master = server_lib.MasterServer(0) + worker = server_lib.WorkerServer(0, master._address) + worker._stop() + worker._stop() + + def testStopStartMaster(self): + master = server_lib.MasterServer(0) + master._stop() + with self.assertRaisesRegex( + RuntimeError, "Server cannot be started after it has been stopped"): + master.start() + + def testStopStartWorker(self): + master = server_lib.MasterServer(0) + worker = server_lib.WorkerServer(0, master._address) + worker._stop() + with self.assertRaisesRegex( + RuntimeError, "Server cannot be started after it has been stopped"): + worker.start() + + def testJoinMaster(self): + master = server_lib.MasterServer(0) + master._stop() + master.join() + + def testJoinWorker(self): + master = server_lib.MasterServer(0) + worker = server_lib.WorkerServer(0, master._address) + worker._stop() + worker.join() + + def testMasterNumWorkers(self): + master = server_lib.MasterServer(0) + self.assertEqual(0, master._num_workers()) + worker1 = server_lib.WorkerServer(0, master._address) # pylint: disable=unused-variable + self.assertEqual(1, master._num_workers()) + worker2 = server_lib.WorkerServer(0, master._address) # pylint: disable=unused-variable + self.assertEqual(2, master._num_workers()) if __name__ == "__main__": diff --git a/tensorflow/python/data/service/server_lib_wrapper.cc b/tensorflow/python/data/service/server_lib_wrapper.cc index 8325d74a768..03453a56c7f 100644 --- a/tensorflow/python/data/service/server_lib_wrapper.cc +++ b/tensorflow/python/data/service/server_lib_wrapper.cc @@ -28,8 +28,24 @@ limitations under the License. namespace py = pybind11; PYBIND11_MODULE(_pywrap_server_lib, m) { - py::class_<tensorflow::data::MasterGrpcDataServer>(m, "MasterGrpcDataServer"); - py::class_<tensorflow::data::WorkerGrpcDataServer>(m, "WorkerGrpcDataServer"); + py::class_<tensorflow::data::MasterGrpcDataServer>(m, "MasterGrpcDataServer") + .def("start", &tensorflow::data::MasterGrpcDataServer::Start) + .def("stop", &tensorflow::data::MasterGrpcDataServer::Stop) + .def("join", &tensorflow::data::MasterGrpcDataServer::Join) + .def("bound_port", &tensorflow::data::MasterGrpcDataServer::BoundPort) + .def("num_workers", + [](tensorflow::data::MasterGrpcDataServer* server) -> int { + int num_workers; + tensorflow::Status status = server->NumWorkers(&num_workers); + tensorflow::MaybeRaiseFromStatus(status); + return num_workers; + }); + + py::class_<tensorflow::data::WorkerGrpcDataServer>(m, "WorkerGrpcDataServer") + .def("start", &tensorflow::data::WorkerGrpcDataServer::Start) + .def("stop", &tensorflow::data::WorkerGrpcDataServer::Stop) + .def("join", &tensorflow::data::WorkerGrpcDataServer::Join) + .def("bound_port", &tensorflow::data::WorkerGrpcDataServer::BoundPort); m.def( "TF_DATA_NewMasterServer", @@ -39,27 +55,9 @@ PYBIND11_MODULE(_pywrap_server_lib, m) { tensorflow::Status status = tensorflow::data::NewMasterServer(port, protocol, &server); tensorflow::MaybeRaiseFromStatus(status); - server->Start(); return server; }, py::return_value_policy::reference); - m.def( - "TF_DATA_MasterServerBoundPort", - [](tensorflow::data::MasterGrpcDataServer* server) -> int { - return server->BoundPort(); - }, - py::return_value_policy::copy); - m.def("TF_DATA_DeleteMasterServer", - [](tensorflow::data::MasterGrpcDataServer* server) { server->Stop(); }); - m.def( - "TF_DATA_MasterServerNumTasks", - [](tensorflow::data::MasterGrpcDataServer* server) -> int { - int num_tasks; - tensorflow::Status status = server->NumTasks(&num_tasks); - tensorflow::MaybeRaiseFromStatus(status); - return num_tasks; - }, - py::return_value_policy::copy); m.def( "TF_DATA_NewWorkerServer", @@ -70,16 +68,7 @@ PYBIND11_MODULE(_pywrap_server_lib, m) { tensorflow::Status status = tensorflow::data::NewWorkerServer( port, protocol, master_address, worker_address, &server); tensorflow::MaybeRaiseFromStatus(status); - server->Start(); return server; }, py::return_value_policy::reference); - m.def( - "TF_DATA_WorkerServerBoundPort", - [](tensorflow::data::WorkerGrpcDataServer* server) -> int { - return server->BoundPort(); - }, - py::return_value_policy::copy); - m.def("TF_DATA_DeleteWorkerServer", - [](tensorflow::data::WorkerGrpcDataServer* server) { server->Stop(); }); }; diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 956e90999c7..1ef0504ecb8 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -840,7 +840,6 @@ py_test( python_version = "PY3", srcs_version = "PY2AND3", tags = [ - "no_oss_py38", #TODO(b/151449908) "no_windows", ], deps = [ diff --git a/tensorflow/python/debug/lib/check_numerics_callback.py b/tensorflow/python/debug/lib/check_numerics_callback.py index edcafad201e..796fabae301 100644 --- a/tensorflow/python/debug/lib/check_numerics_callback.py +++ b/tensorflow/python/debug/lib/check_numerics_callback.py @@ -275,7 +275,9 @@ class CheckNumericsCallback(object): output, inputs, graph=graph, - traceback=output.op.traceback)) + traceback=output.op.traceback, + stack_height_limit=self._stack_height_limit, + path_length_limit=self._path_length_limit)) _CHECK_NUMERICS_INPUT_LOOKUP[graph][checked_output.name] = output instrumented_outputs.append(self._get_output_tensor( op_type_bytes, output, checked_output, is_v1_graph_mode)) @@ -410,6 +412,21 @@ def enable_check_numerics(stack_height_limit=30, z = tf.matmul(y, y) ``` + NOTE: If your code is running on TPUs, be sure to call + `tf.config.set_soft_device_placement(True)` before calling + `tf.debugging.enable_check_numerics()` as this API uses automatic outside + compilation on TPUs. For example: + + ```py + tf.config.set_soft_device_placement(True) + tf.debugging.enable_check_numerics() + + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') + strategy = tf.distribute.experimental.TPUStrategy(resolver) + with strategy.scope(): + # ... + ``` + Args: stack_height_limit: Limit to the height of the printed stack trace. Applicable only to ops in `tf.function`s (graphs). diff --git a/tensorflow/python/debug/lib/check_numerics_callback_test.py b/tensorflow/python/debug/lib/check_numerics_callback_test.py index 5f578da03c3..5c0cc6394ac 100644 --- a/tensorflow/python/debug/lib/check_numerics_callback_test.py +++ b/tensorflow/python/debug/lib/check_numerics_callback_test.py @@ -39,6 +39,7 @@ from tensorflow.python.ops import math_grad # pylint: disable=unused-import from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest +from tensorflow.python.platform import test class LimitStringLengthTest(test_util.TensorFlowTestCase): @@ -105,6 +106,27 @@ class CheckNumericsCallbackTest(test_util.TensorFlowTestCase): self.assertAllClose(batches[0], np.log([1.25, 2])) self.assertAllClose(batches[1], np.log([3.25, 5])) + @test_util.run_in_graph_and_eager_modes + def testGraphModeUsesCorrectPathLengthAndStackHeightLimits(self): + check_numerics_callback.enable_check_numerics( + stack_height_limit=123, path_length_limit=1200) + + @def_function.function + def add_fn(x, y): + return x + y + + fake_get_check_numerics_error_message = test.mock.MagicMock( + return_value="dummy_message") + with test.mock.patch.object(check_numerics_callback, + "get_check_numerics_error_message", + fake_get_check_numerics_error_message): + x = constant_op.constant(2.0) + y = constant_op.constant(3.0) + self.assertAllClose(self.evaluate(add_fn(x, y)), 5.0) + (_, call_kwargs) = fake_get_check_numerics_error_message.call_args + self.assertEqual(call_kwargs["stack_height_limit"], 123) + self.assertEqual(call_kwargs["path_length_limit"], 1200) + class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase): """Test for cases in which enable_check_numerics() catches infs or nans.""" @@ -372,6 +394,22 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase): re.search(r"graph op.*\"Xdivy\"", message))) self.assertTrue(re.search(r"dtype.*float32", message)) + def testEagerModeUsesCorrectPathLengthAndStackHeightLimits(self): + check_numerics_callback.enable_check_numerics( + stack_height_limit=123, path_length_limit=1200) + fake_get_check_numerics_error_message = test.mock.MagicMock( + return_value="dummy_message") + with test.mock.patch.object(check_numerics_callback, + "get_check_numerics_error_message", + fake_get_check_numerics_error_message): + x = constant_op.constant(2.0) + y = constant_op.constant(0.0) + self._assertRaisesInvalidArgumentErrorAndGetMessage( + lambda: x / y) # Expected to generate an inf. + (_, call_kwargs) = fake_get_check_numerics_error_message.call_args + self.assertEqual(call_kwargs["stack_height_limit"], 123) + self.assertEqual(call_kwargs["path_length_limit"], 1200) + @test_util.run_in_graph_and_eager_modes def testExpectedNaNOpOutputs(self): """Test calling operations with benign NaN output.""" diff --git a/tensorflow/python/debug/lib/debug_v2_ops_test.py b/tensorflow/python/debug/lib/debug_v2_ops_test.py index c76cbeeac6c..07721920f63 100644 --- a/tensorflow/python/debug/lib/debug_v2_ops_test.py +++ b/tensorflow/python/debug/lib/debug_v2_ops_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_debug_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest @@ -680,6 +681,39 @@ class DebugIdentityV2OpTest(dumping_callback_test_lib.DumpingCallbackTestBase): self.assertAllEqual(tensor_1, tensor_2) self.assertEqual(tensor_id_1, tensor_id_2) + def testCheckNumericsV2OpNegativeAndPositiveInf(self): + """Test that CheckNumericsV2 op distinguishes negative and positive infs.""" + with self.session(graph=ops.Graph()): + t1 = constant_op.constant([-1.0, 1.0]) + t2 = constant_op.constant([0.0, 0.0]) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r"pass through test.*had -Inf and \+Inf values"): + self.evaluate( + array_ops.check_numerics_v2(t1 / t2, message="pass through test")) + + def testCheckNumericsV2OpNegativeAndPositiveInfAndNaN(self): + """CheckNumericsV2 op distinguishes - & + infs when nan is present.""" + with self.session(graph=ops.Graph()): + t1 = constant_op.constant([-1.0, 1.0, 0.0]) + t2 = constant_op.constant([0.0, 0.0, 0.0]) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r"pass through test.*had -Inf, \+Inf, and NaN values"): + self.evaluate( + array_ops.check_numerics_v2(t1 / t2, message="pass through test")) + + def testCheckNumericsV2PositiveInfAndNaN(self): + """Test that CheckNumericsV2 op shows sign of inf when nan is present.""" + with self.session(graph=ops.Graph()): + t1 = constant_op.constant([0.0, 1.0]) + t2 = constant_op.constant([0.0, 0.0]) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r"pass through test.*had \+Inf and NaN values"): + self.evaluate( + array_ops.check_numerics_v2(t1 / t2, message="pass through test")) + if __name__ == "__main__": ops.enable_eager_execution() diff --git a/tensorflow/python/debug/lib/dumping_callback.py b/tensorflow/python/debug/lib/dumping_callback.py index 5f7fe5e7ea4..f012faf5f3c 100644 --- a/tensorflow/python/debug/lib/dumping_callback.py +++ b/tensorflow/python/debug/lib/dumping_callback.py @@ -721,6 +721,22 @@ def enable_dump_debug_info(dump_root, # Code to build, train and run your TensorFlow model... ``` + NOTE: If your code is running on TPUs, be sure to call + `tf.config.set_soft_device_placement(True)` before calling + `tf.debugging.experimental.enable_dump_debug_info()` as this API uses + automatic outside compilation on TPUs. For example: + + ```py + tf.config.set_soft_device_placement(True) + tf.debugging.experimental.enable_dump_debug_info( + logdir, tensor_debug_mode="FULL_HEALTH") + + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') + strategy = tf.distribute.experimental.TPUStrategy(resolver) + with strategy.scope(): + # ... + ``` + Args: dump_root: The directory path where the dumping information will be written. tensor_debug_mode: Debug mode for tensor values, as a string. diff --git a/tensorflow/python/debug/lib/source_utils_test.py b/tensorflow/python/debug/lib/source_utils_test.py index faf2365fc9c..89964a21ba7 100644 --- a/tensorflow/python/debug/lib/source_utils_test.py +++ b/tensorflow/python/debug/lib/source_utils_test.py @@ -18,7 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import ast import os +import sys import tempfile import zipfile @@ -43,7 +45,41 @@ from tensorflow.python.util import tf_inspect def line_number_above(): - return tf_inspect.stack()[1][2] - 1 + """Get lineno of the AST node immediately above this function's call site. + + It is assumed that there is no empty line(s) between the call site and the + preceding AST node. + + Returns: + The lineno of the preceding AST node, at the same level of the AST. + If the preceding AST spans multiple lines: + - In Python 3.8+, the lineno of the first line is returned. + - In older Python versions, the lineno of the last line is returned. + """ + # https://bugs.python.org/issue12458: In Python 3.8, traceback started + # to return the lineno of the first line of a multi-line continuation block, + # instead of that of the last line. Therefore, in Python 3.8+, we use `ast` to + # get the lineno of the first line. + call_site_lineno = tf_inspect.stack()[1][2] + if sys.version_info < (3, 8): + return call_site_lineno - 1 + else: + with open(__file__, "rb") as f: + source_text = f.read().decode("utf-8") + source_tree = ast.parse(source_text) + prev_node = _find_preceding_ast_node(source_tree, call_site_lineno) + return prev_node.lineno + + +def _find_preceding_ast_node(node, lineno): + """Find the ast node immediately before and not including lineno.""" + for i, child_node in enumerate(node.body): + if child_node.lineno == lineno: + return node.body[i - 1] + if hasattr(child_node, "body"): + found_node = _find_preceding_ast_node(child_node, lineno) + if found_node: + return found_node class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index a7e62a2dc7c..01ae1b61f6a 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1181,6 +1181,23 @@ distribute_py_test( ], ) +distribute_py_test( + name = "strategy_reduce_test", + srcs = ["strategy_reduce_test.py"], + main = "strategy_reduce_test.py", + tags = [ + "multi_and_single_gpu", + ], + deps = [ + ":combinations", + ":strategy_combinations", + "//tensorflow/python:errors", + "//tensorflow/python:variables", + "//tensorflow/python/eager:test", + "@absl_py//absl/testing:parameterized", + ], +) + distribute_py_test( name = "minimize_loss_test", srcs = ["minimize_loss_test.py"], @@ -1546,6 +1563,7 @@ cuda_py_test( srcs = ["parameter_server_strategy_test.py"], tags = [ "multi_and_single_gpu", + "no_windows", # TODO(b/156428279): reenable this test once the image is updated. ], # b/141096229: Non-atomic AssignAdd xla_enable_strict_auto_jit = False, diff --git a/tensorflow/python/distribute/checkpointing_test.py b/tensorflow/python/distribute/checkpointing_test.py index 040faf6f6ce..ad646905315 100644 --- a/tensorflow/python/distribute/checkpointing_test.py +++ b/tensorflow/python/distribute/checkpointing_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import core from tensorflow.python.keras.optimizer_v2 import adam +from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import adam as adam_v1 @@ -96,6 +97,41 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase): self.assertEqual((training_continuation + 1) * num_training_steps, root.optimizer_step.numpy()) + @combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_one_cpu, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.tpu_strategy, + strategy_combinations.central_storage_strategy_with_two_gpus, + ], + mode=["eager"])) + def testInitializeFromCheckpoint(self, distribution): + variable_shape = [5] + save_checkpoint = trackable_utils.Checkpoint(v=variables_lib.Variable( + array_ops.ones(variable_shape))) + save_path = save_checkpoint.save( + os.path.join(self.get_temp_dir(), "checkpoint")) + with distribution.scope(): + restore_checkpoint = trackable_utils.Checkpoint() + restore_checkpoint.restore(save_path) + initial_value = restore_checkpoint._preload_simple_restoration( + "v", variable_shape) + v = variables_lib.Variable(initial_value) + # Check that the variable is now tagged as restored. `Checkpoint` then + # knows it doesn't have to restore `v`'s value when it's assigned to an + # object. + self.assertGreater(v._update_uid, 0) + self.assertAllClose(array_ops.ones(variable_shape), v) + v.assign(array_ops.zeros(variable_shape)) + # Assignment to an object should not trigger restoration, since we already + # restored the object through an initializer. This wouldn't be a + # correctness issue, but it would mean that models would use twice as much + # memory when loading (the buffer already assigned to the variable, and + # the new restoration). + restore_checkpoint.v = v + self.assertAllClose(array_ops.zeros(variable_shape), v) + @combinations.generate( combinations.combine( distribution=[ diff --git a/tensorflow/python/distribute/custom_training_loop_models_test.py b/tensorflow/python/distribute/custom_training_loop_models_test.py index 48f2af0349a..5a9384bb7e0 100644 --- a/tensorflow/python/distribute/custom_training_loop_models_test.py +++ b/tensorflow/python/distribute/custom_training_loop_models_test.py @@ -26,6 +26,7 @@ import numpy as np from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import strategy_combinations from tensorflow.python.eager import backprop from tensorflow.python.eager import def_function @@ -448,6 +449,35 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase): train_step(input_iterator) + @combinations.generate( + combinations.combine( + distribution=strategy_combinations.all_strategies, mode=["eager"])) + def test_reduce_loss(self, distribution): + inputs = np.zeros((10, 4), dtype=np.float32) + targets = np.zeros((10, 1), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.batch(10, drop_remainder=False) + input_iterator = iter(distribution.experimental_distribute_dataset(dataset)) + + with distribution.scope(): + x = keras.layers.Input(shape=(4), name="input") + y = keras.layers.Dense(3, name="dense")(x) + model = keras.Model(x, y) + + @def_function.function + def train_step(iterator): + + def step_fn(inputs): + images, targets = inputs + outputs = model(images) + loss = keras.losses.sparse_categorical_crossentropy(targets, outputs) + return loss + + return distribution.run(step_fn, args=(next(iterator),)) + + loss = train_step(input_iterator) + loss = distribution.reduce(reduce_util.ReduceOp.MEAN, loss, axis=0) + @combinations.generate( combinations.combine( distribution=strategy_combinations.tpu_strategies, mode=["eager"])) diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 6baa15f59c1..ecdc4fad159 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -114,6 +114,7 @@ from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import numpy_dataset from tensorflow.python.distribute import reduce_util from tensorflow.python.eager import context as eager_context +from tensorflow.python.eager import def_function from tensorflow.python.eager import monitoring from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -628,6 +629,10 @@ class StrategyBase(object): # a sensible value. extended._retrace_functions_for_each_device = True + # Below are the dicts of axis(int) -> `tf.function`. + self._mean_reduce_helper_fns = {} + self._reduce_sum_fns = {} + @property def extended(self): """`tf.distribute.StrategyExtended` with additional methods.""" @@ -1014,8 +1019,25 @@ class StrategyBase(object): if axis is None: return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access if reduce_op == reduce_util.ReduceOp.SUM: - value = self.run( - lambda v: math_ops.reduce_sum(v, axis=axis), args=(value,)) + + def reduce_sum(v): + return math_ops.reduce_sum(v, axis=axis) + + if eager_context.executing_eagerly(): + # As some strategies (e.g. TPUStrategy) doesn't support pure eager + # execution, wrap the `reduce_sum_fn` with a `tf.function` so it can be + # run from eager mode. Cache the tf.function by `axis` to avoid the + # same function to be traced again. + if axis not in self._reduce_sum_fns: + + def reduce_sum_fn(v): + return self.run(reduce_sum, args=(v,)) + + self._reduce_sum_fns[axis] = def_function.function(reduce_sum_fn) + value = self._reduce_sum_fns[axis](value) + else: + value = self.run(reduce_sum, args=(value,)) + return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access if reduce_op != reduce_util.ReduceOp.MEAN: raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, " @@ -1062,7 +1084,22 @@ class StrategyBase(object): # reduce is complete? return numer, denom - numer, denom = self.run(mean_reduce_helper, args=(value,)) + if eager_context.executing_eagerly(): + # As some strategies (e.g. TPUStrategy) doesn't support pure eager + # execution, wrap the `mean_reduce_helper` with a `tf.function` so it can + # be run from eager mode. Cache the tf.function by `axis` to avoid the + # same function to be traced again. + if axis not in self._mean_reduce_helper_fns: + + def mean_reduce_fn(v): + return self.run(mean_reduce_helper, args=(v,)) + + self._mean_reduce_helper_fns[axis] = def_function.function( + mean_reduce_fn) + numer, denom = self._mean_reduce_helper_fns[axis](value) + else: + numer, denom = self.run(mean_reduce_helper, args=(value,)) + # TODO(josh11b): Should batch reduce here instead of doing two. numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer) # pylint: disable=protected-access denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom) # pylint: disable=protected-access @@ -1772,13 +1809,25 @@ class StrategyExtendedV2(object): kwargs["distribute_strategy"] = strategy # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid - # dereferencing a `Tensor` that is without a `name`. - # TODO(b/138130844): Revisit the following check once - # `CheckpointInitialValue` class is removed. + # dereferencing a `Tensor` that is without a `name`. We still need to + # propagate the metadata it's holding. if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue): + checkpoint_restore_uid = kwargs[ + "initial_value"].checkpoint_position.restore_uid kwargs["initial_value"] = kwargs["initial_value"].wrapped_value + else: + checkpoint_restore_uid = None - return self._create_variable(next_creator, **kwargs) + created = self._create_variable(next_creator, **kwargs) + + if checkpoint_restore_uid is not None: + # pylint: disable=protected-access + # Let the checkpointing infrastructure know that the variable was + # already restored so it doesn't waste memory loading the value again. + created._maybe_initialize_trackable() + created._update_uid = checkpoint_restore_uid + # pylint: enable=protected-access + return created def distributed_getter(getter, *args, **kwargs): if not self._allow_variable_partition(): diff --git a/tensorflow/python/distribute/sharded_variable.py b/tensorflow/python/distribute/sharded_variable.py index 9886e42a8b3..7accc066d8a 100644 --- a/tensorflow/python/distribute/sharded_variable.py +++ b/tensorflow/python/distribute/sharded_variable.py @@ -96,6 +96,10 @@ class ShardedVariable(trackable.Trackable): 'to the order of the `Variable`s in the list passed to ' 'the constructor. Found {}'.format(save_slice_info)) + def __iter__(self): + """Return an iterable for accessing the underlying sharded variables.""" + return iter(self._variables) + @property def variables(self): """The list of `Variable`s that make up the shards of this object.""" diff --git a/tensorflow/python/distribute/strategy_reduce_test.py b/tensorflow/python/distribute/strategy_reduce_test.py new file mode 100644 index 00000000000..a87cce2f0b8 --- /dev/null +++ b/tensorflow/python/distribute/strategy_reduce_test.py @@ -0,0 +1,52 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for `strategy.reduce`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.eager import def_function +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op + + +class StrategyReduceTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.combine( + distribution=strategy_combinations.all_strategies, + mode=["eager"] + )) + def test_reduce_with_axis(self, distribution): + + @def_function.function + def fn(): + return constant_op.constant([1., 2.]) + x = distribution.run(fn) + + x_m = distribution.reduce(reduce_util.ReduceOp.MEAN, x, axis=0) + self.assertEqual(1.5, self.evaluate(x_m)) + x_s = distribution.reduce(reduce_util.ReduceOp.SUM, x, axis=0) + self.assertEqual(3 * distribution.num_replicas_in_sync, self.evaluate(x_s)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index b574c523ccd..a8ffa618064 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -897,7 +897,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): if tensor_util.is_tensor(input_tensor): rank = input_tensor.get_shape().rank else: - rank = np.rank(input_tensor) + rank = np.ndim(input_tensor) maximum_shape = tensor_shape.TensorShape([None] * rank) maximum_shapes.append(maximum_shape) maximum_shapes = nest.pack_sequence_as(replicate_inputs[0], diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index de4c975d5ef..6c93e29c028 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -28,6 +28,7 @@ from tensorflow.python.eager import def_function from tensorflow.python.eager import function from tensorflow.python.eager import remote from tensorflow.python.eager import test +from tensorflow.python.framework import config from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -140,6 +141,9 @@ class TPUStrategyTest(test.TestCase): # for non-local TPU. if FLAGS.tpu: self.skipTest("Recovery fails for non-local TPU, see b/148150981") + + # Disable automatic outside compilation. + config.set_soft_device_placement(False) strategy = get_tpu_strategy() @def_function.function @@ -164,6 +168,28 @@ class TPUStrategyTest(test.TestCase): good_run() + def test_dynamic_shape_with_outside_compilation_failure(self): + # Enable automatic outside compilation. + config.set_soft_device_placement(True) + strategy = get_tpu_strategy() + dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch( + 2, drop_remainder=False) + dataset = strategy.experimental_distribute_dataset(dataset) + iterator = iter(dataset) + + @def_function.function + def train_fn(iterator): + + def step_fn(inputs): + _, inputs = inputs + return math_ops.reduce_sum(inputs) + + return strategy.experimental_local_results( + strategy.run(step_fn, args=(next(iterator),))) + + with self.assertRaisesRegex(errors.InternalError, "Compilation failure"): + logging.info(train_fn(iterator)) + def test_computation_on_subset_cores(self): resolver = get_tpu_cluster_resolver() remote.connect_to_cluster(resolver) diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 444915aa123..d03628f4714 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -43,6 +43,7 @@ from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export +# Utility functions used by the different classes below. def _get_current_replica_id_as_int(): """Returns the current replica ID as an integer, or `None`.""" replica_context = ds_context.get_replica_context() @@ -55,6 +56,59 @@ def _get_current_replica_id_as_int(): return replica_id +def _assign_on_device(device, variable, tensor): + with ops.device(device): + return variable.assign(tensor) + + +def _assign_add_on_device(device, variable, tensor): + with ops.device(device): + return variable.assign_add(tensor) + + +def _assign_sub_on_device(device, variable, tensor): + with ops.device(device): + return variable.assign_sub(tensor) + + +def _assert_replica_context(strategy): + replica_context = ds_context.get_replica_context() + if not replica_context: + raise RuntimeError( + "Replica-local variables may only be assigned in a replica context.") + if replica_context.strategy is not strategy: + raise RuntimeError( + "Replica-local variables may only be assigned in a replica context.") + + +def _apply_aggregation(strategy, value, aggregation, destinations): + if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: + return strategy.extended.broadcast_to( + strategy.experimental_local_results(value)[0], + destinations=destinations) + reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation) + return strategy.extended.reduce_to(reduce_op, value, destinations) + + +_aggregation_error_msg = ( + "You must specify an aggregation method to update a " + "{variable_type} in Replica Context. You can do so by passing " + "an explicit value for argument `aggregation` to tf.Variable(..)." + "e.g. `tf.Variable(..., aggregation=tf.VariableAggregation.SUM)`" + "`tf.VariableAggregation` lists the possible aggregation methods." + "This is required because {variable_type} should always be " + "kept in sync. When updating them or assigning to them in a " + "replica context, we automatically try to aggregate the values " + "before updating the variable. For this aggregation, we need to " + "know the aggregation method. " + "Another alternative is to not try to update such " + "{variable_type} in replica context, but in cross replica " + "context. You can enter cross replica context by calling " + "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`." + "Inside `merge_fn`, you can then update the {variable_type} " + "using `tf.distribute.StrategyExtended.update()`.") + + @tf_export("distribute.DistributedValues", v1=[]) class DistributedValues(object): """Base class for representing distributed values. @@ -139,7 +193,7 @@ class DistributedValues(object): "This method should be overridden by sub-classes which support cross-" "replica accesses.") - def _get_closest(self): + def _get_on_device_or_primary(self): """Returns value in same replica or device if possible, else the _primary.""" replica_id = _get_current_replica_id_as_int() if replica_id is None: @@ -379,7 +433,7 @@ class Mirrored(DistributedDelegate): """Holds a map from replica to values which are kept in sync.""" def _get_cross_replica(self): - return self._get_closest() + return self._get_on_device_or_primary() def _as_graph_element(self): obj = self._get() @@ -389,21 +443,6 @@ class Mirrored(DistributedDelegate): return obj -def _assign_on_device(device, variable, tensor): - with ops.device(device): - return variable.assign(tensor) - - -def _assign_add_on_device(device, variable, tensor): - with ops.device(device): - return variable.assign_add(tensor) - - -def _assign_sub_on_device(device, variable, tensor): - with ops.device(device): - return variable.assign_sub(tensor) - - class DistributedVarOp(object): """A class that looks like `tf.Operation`.""" @@ -480,11 +519,11 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, return init_op def initialized_value(self): - return self._get_closest().initialized_value() + return self._get_on_device_or_primary().initialized_value() @property def initial_value(self): - return self._get_closest().initial_value + return self._get_on_device_or_primary().initial_value @property def constraint(self): @@ -537,7 +576,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, return self._values[replica_id].handle def eval(self, session=None): - return self._get_closest().eval(session) + return self._get_on_device_or_primary().eval(session) @property def _save_slice_info(self): @@ -552,7 +591,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, @property def device(self): - return self._get_closest().device + return self._get_on_device_or_primary().device @property def trainable(self): @@ -587,7 +626,7 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, return array_ops.identity(self._get()) def value(self): - return self._get_closest().value() + return self._get_on_device_or_primary().value() def numpy(self): if context.executing_eagerly(): @@ -743,59 +782,6 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable, pass -def _validate_colocate_extended(v, extended): - variable_strategy = v._distribute_strategy # pylint: disable=protected-access - if variable_strategy.extended is not extended: - raise ValueError( - "`colocate_vars_with` must only be passed a variable created in this " - "tf.distribute.Strategy.scope(), not %s created in scope: %s" % - (v, variable_strategy)) - - -def validate_colocate_distributed_variable(v, extended): - if not isinstance(v, DistributedVariable): - raise ValueError( - "`colocate_vars_with` must only be passed a variable created in this " - "tf.distribute.Strategy.scope(), not: %r" % (v,)) - _validate_colocate_extended(v, extended) - - -def validate_colocate(v, extended): - if not hasattr(v, "_distribute_strategy"): - raise ValueError( - "`colocate_vars_with` must only be passed a variable created in this " - "tf.distribute.Strategy.scope(), not: %r" % (v,)) - _validate_colocate_extended(v, extended) - - -def _apply_aggregation(strategy, value, aggregation, destinations): - if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: - return strategy.extended.broadcast_to( - strategy.experimental_local_results(value)[0], - destinations=destinations) - reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation) - return strategy.extended.reduce_to(reduce_op, value, destinations) - - -_aggregation_error_msg = ( - "You must specify an aggregation method to update a " - "{variable_type} in Replica Context. You can do so by passing " - "an explicit value for argument `aggregation` to tf.Variable(..)." - "e.g. `tf.Variable(..., aggregation=tf.VariableAggregation.SUM)`" - "`tf.VariableAggregation` lists the possible aggregation methods." - "This is required because {variable_type} should always be " - "kept in sync. When updating them or assigning to them in a " - "replica context, we automatically try to aggregate the values " - "before updating the variable. For this aggregation, we need to " - "know the aggregation method. " - "Another alternative is to not try to update such " - "{variable_type} in replica context, but in cross replica " - "context. You can enter cross replica context by calling " - "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`." - "Inside `merge_fn`, you can then update the {variable_type} " - "using `tf.distribute.StrategyExtended.update()`.") - - class _MirroredSaveable(saveable_object_util.ResourceVariableSaveable): """Class for defining how to restore a MirroredVariable.""" @@ -812,6 +798,276 @@ class _MirroredSaveable(saveable_object_util.ResourceVariableSaveable): for v in self._mirrored_variable.values)) +class MirroredVariable(DistributedVariable, Mirrored): + """Holds a map from replica to variables whose values are kept in sync.""" + + def _update_replica(self, update_fn, value, **kwargs): + if self.aggregation == vs.VariableAggregation.NONE: + raise ValueError( + _aggregation_error_msg.format(variable_type="MirroredVariable")) + + def merge_fn(strategy, value, **kwargs): + """Aggregate values and update all variables in cross replica context.""" + # Don't allow MEAN with non float dtype, since it may cause unexpected + # precision loss. Python3 and NumPy automatically upcast integers to + # float in division, but we should always preserve the type. + # + # Note that to be backward compatible we allow the case when the value + # is *always* the same on each replica. I.E. value is not a + # PerReplica. Refer to regroup() to see how values are grouped. + if self._aggregation == vs.VariableAggregation.MEAN and ( + not self.dtype.is_floating) and isinstance(value, PerReplica): + raise ValueError( + "Cannot update non-float variables with " + "tf.VariableAggregation.MEAN aggregation in replica context. " + "Either change the variable dtype to float or update it in " + "cross-replica context.") + + assert strategy == self.distribute_strategy + v = _apply_aggregation(strategy, value, self.aggregation, self) + return self._update_cross_replica(update_fn, v, **kwargs) + + return ds_context.get_replica_context().merge_call( + merge_fn, args=(value,), kwargs=kwargs) + + def scatter_min(self, *args, **kwargs): + if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and + self._aggregation != vs.VariableAggregation.NONE): + raise NotImplementedError("scatter_min is only supported for mirrored " + "variable (variable created within certain " + "`tf.distribute.Strategy` scope) with NONE or " + "`ONLY_FIRST_REPLICA` aggregation, got: %s" % + self._aggregation) + return super(MirroredVariable, self).scatter_min(*args, **kwargs) + + def scatter_max(self, *args, **kwargs): + if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and + self._aggregation != vs.VariableAggregation.NONE): + raise NotImplementedError("scatter_max is only supported for mirrored " + "variable (variable created within certain " + "`tf.distribute.Strategy` scope) with NONE or " + "`ONLY_FIRST_REPLICA` aggregation, got: %s" % + self._aggregation) + return super(MirroredVariable, self).scatter_max(*args, **kwargs) + + def scatter_update(self, *args, **kwargs): + if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and + self._aggregation != vs.VariableAggregation.NONE): + raise NotImplementedError("scatter_update is only supported for mirrored " + "variable (variable created within certain " + "`tf.distribute.Strategy` scope) with NONE or " + "`ONLY_FIRST_REPLICA` aggregation, got: %s" % + self._aggregation) + return super(MirroredVariable, self).scatter_update(*args, **kwargs) + + def _get_cross_replica(self): + # Return identity, to avoid directly exposing the variable to the user and + # allowing it to be modified by mistake. + return array_ops.identity(Mirrored._get_cross_replica(self)) + + def _as_graph_element(self): + return self._get_on_device_or_primary()._as_graph_element() # pylint: disable=protected-access + + def _gather_saveables_for_checkpoint(self): + """Overrides Trackable method. + + This allows both name-based and object-based save and restore of + MirroredVariables. + + Returns: + A dictionary mapping attribute names to `SaveableObject` factories. + """ + + def _saveable_factory(name=self._common_name): + return _MirroredSaveable(self, self._primary, name) + + return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} + + def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): + """Converts a variable to a tensor.""" + # Try to avoid assignments to and other mutations of MirroredVariable + # state except through a DistributionStrategy.extended.update() call. + if as_ref: + # A TF 1.x case where the variable is a boolean variable and used like: + # tf.cond(v, true_fn, false_fn). + raise ValueError( + "You may be using variable created under distribute strategy in TF " + "1.x control flows. Try explicitly converting the variable to Tensor " + "using variable.read_value(), or switch to TF 2.x.") + return ops.convert_to_tensor( + self._get(), dtype=dtype, name=name, as_ref=as_ref) + + +class _SyncOnReadSaveable(saveable_object.SaveableObject): + """Class for defining how to restore a SyncOnReadVariable.""" + + def __init__(self, sync_on_read_variable, name): + self._sync_on_read_variable = sync_on_read_variable + + # We use a callable so that we don't have to evaluate this expression + # in the case where we are trying to restore instead of save. + def tensor(): + strategy = sync_on_read_variable._distribute_strategy # pylint: disable=protected-access + return strategy.extended.read_var(sync_on_read_variable) + + spec = saveable_object.SaveSpec( + tensor=tensor, + slice_spec="", + name=name, + dtype=sync_on_read_variable.dtype, + device=sync_on_read_variable._primary.device) # pylint: disable=protected-access + + super(_SyncOnReadSaveable, self).__init__(tensor, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into all variables.""" + # To preserve the sum across save and restore, we have to divide the + # total across all devices when restoring a variable that was summed + # when saving. + tensor, = restored_tensors + if self._sync_on_read_variable.aggregation == vs.VariableAggregation.SUM: + tensor = math_ops.cast(tensor / len(self._sync_on_read_variable._devices), # pylint: disable=protected-access + self._sync_on_read_variable.dtype) + return control_flow_ops.group( + tuple( + _assign_on_device(v.device, v, tensor) + for v in self._sync_on_read_variable.values)) + + +class SyncOnReadVariable(DistributedVariable): + """Holds a map from replica to variables whose values are reduced on save.""" + + def _update_replica(self, update_fn, value, **kwargs): + return update_fn(self._get_on_device_or_primary(), value, **kwargs) + + def _assign_on_each_device(self, assign_func, value, read_value): + update = control_flow_ops.group( + tuple( + assign_func(v.device, v, value) + for v in self._values)) + if not read_value: + return update + with ops.control_dependencies([update] if update else []): + return self.read_value() + + # TODO(b/154017756): Make assign behaivor in cross replica context consistent + # with MirroredVariable. + def assign_sub(self, value, use_locking=False, name=None, read_value=True): + with ds_context.enter_or_assert_strategy(self._distribute_strategy): + if ds_context.in_cross_replica_context(): + if self._aggregation == vs.VariableAggregation.SUM: + raise ValueError( + "SyncOnReadVariable does not support `assign_sub` in " + "cross-replica context when aggregation is set to " + "`tf.VariableAggregation.SUM`.") + return self._assign_on_each_device(_assign_sub_on_device, value, + read_value) + else: + return super(SyncOnReadVariable, + self).assign_sub(value, use_locking, name, read_value) + + def assign_add(self, value, use_locking=False, name=None, read_value=True): + with ds_context.enter_or_assert_strategy(self._distribute_strategy): + if ds_context.in_cross_replica_context(): + if self._aggregation == vs.VariableAggregation.SUM: + raise ValueError( + "SyncOnReadVariable does not support `assign_add` in " + "cross-replica context when aggregation is set to " + "`tf.VariableAggregation.SUM`.") + return self._assign_on_each_device(_assign_add_on_device, value, + read_value) + else: + return super(SyncOnReadVariable, + self).assign_add(value, use_locking, name, read_value) + + def assign(self, value, use_locking=False, name=None, read_value=True): + with ds_context.enter_or_assert_strategy(self._distribute_strategy): + if ds_context.in_cross_replica_context(): + # To preserve the sum across save and restore, we have to divide the + # total across all devices when restoring a variable that was summed + # when saving. + if self._aggregation == vs.VariableAggregation.SUM: + value = math_ops.cast(value / len(self._values), self.dtype) + return self._assign_on_each_device(_assign_on_device, value, + read_value) + else: + return super(SyncOnReadVariable, + self).assign(value, use_locking, name, read_value) + + def _scatter_not_implemented(self, method): + raise NotImplementedError( + "Variables with `synchronization=ON_READ` doesn't support `%s`" % + method) + + def scatter_sub(self, *args, **kwargs): + self._scatter_not_implemented("scatter_sub") + + def scatter_add(self, *args, **kwargs): + self._scatter_not_implemented("scatter_add") + + def scatter_mul(self, *args, **kwargs): + self._scatter_not_implemented("scatter_mul") + + def scatter_div(self, *args, **kwargs): + self._scatter_not_implemented("scatter_div") + + def scatter_min(self, *args, **kwargs): + self._scatter_not_implemented("scatter_min") + + def scatter_max(self, *args, **kwargs): + self._scatter_not_implemented("scatter_max") + + def scatter_update(self, *args, **kwargs): + self._scatter_not_implemented("scatter_update") + + def value(self): + with ds_context.enter_or_assert_strategy(self._distribute_strategy): + if ds_context.in_cross_replica_context(): + return self._get_cross_replica() + else: + # _get_on_device_or_primary() returns a Variable. + return self._get_on_device_or_primary().value() + + def _get_cross_replica(self): + if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: + return self._primary + + with ds_context.enter_or_assert_strategy(self._distribute_strategy): + return self._distribute_strategy.reduce( + reduce_util.ReduceOp.from_variable_aggregation(self.aggregation), + self, + axis=None) + + def _as_graph_element(self): + # pylint: disable=protected-access + with ds_context.enter_or_assert_strategy(self._distribute_strategy): + if ds_context.in_cross_replica_context(): + return ops.convert_to_tensor(self._get_cross_replica()) + return self._get()._as_graph_element() + + def _gather_saveables_for_checkpoint(self): + """Overrides Trackable method. + + This allows both name-based and object-based save and restore of + `SyncOnReadVariable`s. + + Returns: + A dictionary mapping attribute names to `SaveableObject` factories. + """ + + def _saveable_factory(name=self._common_name): + return _SyncOnReadSaveable(self, name) + + return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} + + def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): + """Converts a variable to a tensor.""" + with ds_context.enter_or_assert_strategy(self._distribute_strategy): + return ops.convert_to_tensor( + self._get(), dtype=dtype, name=name, as_ref=as_ref) + + +# Variable creation function for sync strategies. def create_mirrored_variable( # pylint: disable=missing-docstring strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs): # Figure out what collections this variable should be added to. @@ -893,108 +1149,9 @@ def create_mirrored_variable( # pylint: disable=missing-docstring return result -class MirroredVariable(DistributedVariable, Mirrored): - """Holds a map from replica to variables whose values are kept in sync.""" - - def _update_replica(self, update_fn, value, **kwargs): - if self.aggregation == vs.VariableAggregation.NONE: - raise ValueError( - _aggregation_error_msg.format(variable_type="MirroredVariable")) - - def merge_fn(strategy, value, **kwargs): - """Aggregate values and update all variables in cross replica context.""" - # Don't allow MEAN with non float dtype, since it may cause unexpected - # precision loss. Python3 and NumPy automatically upcast integers to - # float in division, but we should always preserve the type. - # - # Note that to be backward compatible we allow the case when the value - # is *always* the same on each replica. I.E. value is not a - # PerReplica. Refer to regroup() to see how values are grouped. - if self._aggregation == vs.VariableAggregation.MEAN and ( - not self.dtype.is_floating) and isinstance(value, PerReplica): - raise ValueError( - "Cannot update non-float variables with " - "tf.VariableAggregation.MEAN aggregation in replica context. " - "Either change the variable dtype to float or update it in " - "cross-replica context.") - - assert strategy == self.distribute_strategy - v = _apply_aggregation(strategy, value, self.aggregation, self) - return self._update_cross_replica(update_fn, v, **kwargs) - - return ds_context.get_replica_context().merge_call( - merge_fn, args=(value,), kwargs=kwargs) - - def scatter_min(self, *args, **kwargs): - if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and - self._aggregation != vs.VariableAggregation.NONE): - raise NotImplementedError("scatter_min is only supported for mirrored " - "variable (variable created within certain " - "`tf.distribute.Strategy` scope) with NONE or " - "`ONLY_FIRST_REPLICA` aggregation, got: %s" % - self._aggregation) - return super(MirroredVariable, self).scatter_min(*args, **kwargs) - - def scatter_max(self, *args, **kwargs): - if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and - self._aggregation != vs.VariableAggregation.NONE): - raise NotImplementedError("scatter_max is only supported for mirrored " - "variable (variable created within certain " - "`tf.distribute.Strategy` scope) with NONE or " - "`ONLY_FIRST_REPLICA` aggregation, got: %s" % - self._aggregation) - return super(MirroredVariable, self).scatter_max(*args, **kwargs) - - def scatter_update(self, *args, **kwargs): - if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and - self._aggregation != vs.VariableAggregation.NONE): - raise NotImplementedError("scatter_update is only supported for mirrored " - "variable (variable created within certain " - "`tf.distribute.Strategy` scope) with NONE or " - "`ONLY_FIRST_REPLICA` aggregation, got: %s" % - self._aggregation) - return super(MirroredVariable, self).scatter_update(*args, **kwargs) - - def _get_cross_replica(self): - # Return identity, to avoid directly exposing the variable to the user and - # allowing it to be modified by mistake. - return array_ops.identity(Mirrored._get_cross_replica(self)) - - def _as_graph_element(self): - return self._get_closest()._as_graph_element() # pylint: disable=protected-access - - def _gather_saveables_for_checkpoint(self): - """Overrides Trackable method. - - This allows both name-based and object-based save and restore of - MirroredVariables. - - Returns: - A dictionary mapping attribute names to `SaveableObject` factories. - """ - - def _saveable_factory(name=self._common_name): - return _MirroredSaveable(self, self._primary, name) - - return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} - - def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): - """Converts a variable to a tensor.""" - # Try to avoid assignments to and other mutations of MirroredVariable - # state except through a DistributionStrategy.extended.update() call. - if as_ref: - # A TF 1.x case where the variable is a boolean variable and used like: - # tf.cond(v, true_fn, false_fn). - raise ValueError( - "You may be using variable created under distribute strategy in TF " - "1.x control flows. Try explicitly converting the variable to Tensor " - "using variable.read_value(), or switch to TF 2.x.") - return ops.convert_to_tensor( - self._get(), dtype=dtype, name=name, as_ref=as_ref) - - -# Register a conversion function which reads the value of the variable, +# Register a conversion functions which reads the value of the variable, # allowing instances of the class to be used as tensors. +# MirroredVariables def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False): return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access @@ -1003,6 +1160,7 @@ ops.register_tensor_conversion_function(MirroredVariable, _tensor_conversion_mirrored) +# Mirrored Values def _tensor_conversion_mirrored_val(value, dtype=None, name=None, as_ref=False): return ops.convert_to_tensor( value._get(), dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access @@ -1012,184 +1170,7 @@ ops.register_tensor_conversion_function(Mirrored, _tensor_conversion_mirrored_val) -def is_distributed_variable(v): - """Determine if a variable is ds variable or TPU mirrored variable.""" - return isinstance(v, DistributedVariable) - - -class _SyncOnReadSaveable(saveable_object.SaveableObject): - """Class for defining how to restore a SyncOnReadVariable.""" - - def __init__(self, sync_on_read_variable, name): - self._sync_on_read_variable = sync_on_read_variable - - # We use a callable so that we don't have to evaluate this expression - # in the case where we are trying to restore instead of save. - def tensor(): - strategy = sync_on_read_variable._distribute_strategy # pylint: disable=protected-access - return strategy.extended.read_var(sync_on_read_variable) - - spec = saveable_object.SaveSpec( - tensor=tensor, - slice_spec="", - name=name, - dtype=sync_on_read_variable.dtype, - device=sync_on_read_variable._primary.device) # pylint: disable=protected-access - - super(_SyncOnReadSaveable, self).__init__(tensor, [spec], name) - - def restore(self, restored_tensors, restored_shapes): - """Restore the same value into all variables.""" - # To preserve the sum across save and restore, we have to divide the - # total across all devices when restoring a variable that was summed - # when saving. - tensor, = restored_tensors - if self._sync_on_read_variable.aggregation == vs.VariableAggregation.SUM: - tensor = math_ops.cast(tensor / len(self._sync_on_read_variable._devices), # pylint: disable=protected-access - self._sync_on_read_variable.dtype) - return control_flow_ops.group( - tuple( - _assign_on_device(v.device, v, tensor) - for v in self._sync_on_read_variable.values)) - - -def _assert_replica_context(strategy): - replica_context = ds_context.get_replica_context() - if not replica_context: - raise RuntimeError( - "Replica-local variables may only be assigned in a replica context.") - if replica_context.strategy is not strategy: - raise RuntimeError( - "Replica-local variables may only be assigned in a replica context.") - - -class SyncOnReadVariable(DistributedVariable): - """Holds a map from replica to variables whose values are reduced on save.""" - - def _update_replica(self, update_fn, value, **kwargs): - return update_fn(self._get_closest(), value, **kwargs) - - # TODO(b/154017756): Make assign behaivor in cross replica context consistent - # with MirroredVariable. - def assign_sub(self, *args, **kwargs): - with ds_context.enter_or_assert_strategy(self._distribute_strategy): - if ds_context.in_cross_replica_context(): - if self._aggregation == vs.VariableAggregation.SUM: - raise ValueError( - "SyncOnReadVariable does not support `assign_sub` in " - "cross-replica context when aggregation is set to " - "`tf.VariableAggregation.SUM`.") - return control_flow_ops.group( - tuple( - _assign_sub_on_device(v.device, v, args[0]) - for v in self._values)) - else: - return super(SyncOnReadVariable, self).assign_sub(*args, **kwargs) - - def assign_add(self, *args, **kwargs): - with ds_context.enter_or_assert_strategy(self._distribute_strategy): - if ds_context.in_cross_replica_context(): - if self._aggregation == vs.VariableAggregation.SUM: - raise ValueError( - "SyncOnReadVariable does not support `assign_add` in " - "cross-replica context when aggregation is set to " - "`tf.VariableAggregation.SUM`.") - return control_flow_ops.group( - tuple( - _assign_add_on_device(v.device, v, args[0]) - for v in self._values)) - else: - return super(SyncOnReadVariable, self).assign_add(*args, **kwargs) - - def assign(self, *args, **kwargs): - with ds_context.enter_or_assert_strategy(self._distribute_strategy): - if ds_context.in_cross_replica_context(): - # To preserve the sum across save and restore, we have to divide the - # total across all devices when restoring a variable that was summed - # when saving. - tensor = args[0] - if self._aggregation == vs.VariableAggregation.SUM: - tensor = math_ops.cast(tensor / len(self._values), self.dtype) - return control_flow_ops.group( - tuple(_assign_on_device(v.device, v, tensor) for v in self._values)) - else: - return super(SyncOnReadVariable, self).assign(*args, **kwargs) - - def _scatter_not_implemented(self, method): - raise NotImplementedError( - "Variables with `synchronization=ON_READ` doesn't support `%s`" % - method) - - def scatter_sub(self, *args, **kwargs): - self._scatter_not_implemented("scatter_sub") - - def scatter_add(self, *args, **kwargs): - self._scatter_not_implemented("scatter_add") - - def scatter_mul(self, *args, **kwargs): - self._scatter_not_implemented("scatter_mul") - - def scatter_div(self, *args, **kwargs): - self._scatter_not_implemented("scatter_div") - - def scatter_min(self, *args, **kwargs): - self._scatter_not_implemented("scatter_min") - - def scatter_max(self, *args, **kwargs): - self._scatter_not_implemented("scatter_max") - - def scatter_update(self, *args, **kwargs): - self._scatter_not_implemented("scatter_update") - - def value(self): - with ds_context.enter_or_assert_strategy(self._distribute_strategy): - if ds_context.in_cross_replica_context(): - return self._get_cross_replica() - else: - # _get_closest() returns a Variable. - return self._get_closest().value() - - def _get_cross_replica(self): - if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA: - return self._primary - - with ds_context.enter_or_assert_strategy(self._distribute_strategy): - return self._distribute_strategy.reduce( - reduce_util.ReduceOp.from_variable_aggregation(self.aggregation), - self, - axis=None) - - def _as_graph_element(self): - # pylint: disable=protected-access - with ds_context.enter_or_assert_strategy(self._distribute_strategy): - if ds_context.in_cross_replica_context(): - return ops.convert_to_tensor(self._get_cross_replica()) - return self._get()._as_graph_element() - - def _gather_saveables_for_checkpoint(self): - """Overrides Trackable method. - - This allows both name-based and object-based save and restore of - `SyncOnReadVariable`s. - - Returns: - A dictionary mapping attribute names to `SaveableObject` factories. - """ - - def _saveable_factory(name=self._common_name): - return _SyncOnReadSaveable(self, name) - - return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} - - def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): - """Converts a variable to a tensor.""" - with ds_context.enter_or_assert_strategy(self._distribute_strategy): - return ops.convert_to_tensor( - self._get(), dtype=dtype, name=name, as_ref=as_ref) - - -# Register a conversion function for SyncOnReadVariable which allows as_ref to -# be true. +# SyncOnReadVariables def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False): return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access @@ -1379,6 +1360,37 @@ def value_container(val): return val +def is_distributed_variable(v): + """Determine if a variable is ds variable or TPU mirrored variable.""" + return isinstance(v, DistributedVariable) + + +def _validate_colocate_extended(v, extended): + variable_strategy = v._distribute_strategy # pylint: disable=protected-access + if variable_strategy.extended is not extended: + raise ValueError( + "`colocate_vars_with` must only be passed a variable created in this " + "tf.distribute.Strategy.scope(), not %s created in scope: %s" % + (v, variable_strategy)) + + +def validate_colocate_distributed_variable(v, extended): + if not isinstance(v, DistributedVariable): + raise ValueError( + "`colocate_vars_with` must only be passed a variable created in this " + "tf.distribute.Strategy.scope(), not: %r" % (v,)) + _validate_colocate_extended(v, extended) + + +def validate_colocate(v, extended): + if not hasattr(v, "_distribute_strategy"): + raise ValueError( + "`colocate_vars_with` must only be passed a variable created in this " + "tf.distribute.Strategy.scope(), not: %r" % (v,)) + _validate_colocate_extended(v, extended) + + +# Variable used in PSStrategy TF 1 and CentralStorageStrategy. class AggregatingVariable(variables_lib.Variable, core.Tensor): """A wrapper around a variable that aggregates updates across replicas.""" diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index 67ed86b4047..bbff6c631cf 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -651,7 +651,10 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase): self.assertIsInstance(v.assign_add(delta), core.Tensor) # In cross replica context we return a PerReplica which is not Tensor like - # yet. + # all the time yet. + if (synchronization == variables_lib.VariableSynchronization.ON_READ and + aggregation != variables_lib.VariableAggregation.SUM): + assert_is_tensor_like(v) # In replica context. distribution.run(assert_is_tensor_like, args=(v,)) @@ -1610,10 +1613,16 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): variables_lib.VariableAggregation.MEAN, variables_lib.VariableAggregation.ONLY_FIRST_REPLICA, ] - options = ( # VariableAggregation.SUM in cross-replica mode is tested below - [x for x in itertools.product(updates, aggregations, [True, False]) - if not(x[1] == variables_lib.VariableAggregation.SUM and x[2])]) + options = list( + x for x in itertools.product(updates, aggregations, [True, False])) for update, aggregation, cross_replica in options: + # VariableAggregation.SUM in cross-replica mode is tested below, + # VariableAggregation.NONE in cross-replica mode is not supported. + if cross_replica and aggregation in [ + variables_lib.VariableAggregation.SUM, + variables_lib.VariableAggregation.NONE, + ]: + continue with distribution.scope(): v = variable_scope.variable( 0., @@ -1647,10 +1656,16 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): variables_lib.VariableAggregation.MEAN, variables_lib.VariableAggregation.ONLY_FIRST_REPLICA, ] - options = ( # VariableAggregation.SUM in cross-replica mode is tested below - [x for x in itertools.product(updates, aggregations, [True, False]) - if not(x[1] == variables_lib.VariableAggregation.SUM and x[2])]) + options = list( + x for x in itertools.product(updates, aggregations, [True, False])) for update, aggregation, cross_replica in options: + # VariableAggregation.SUM in cross-replica mode is tested below, + # VariableAggregation.NONE in cross-replica mode is not supported. + if cross_replica and aggregation in [ + variables_lib.VariableAggregation.SUM, + variables_lib.VariableAggregation.NONE, + ]: + continue with distribution.scope(): v = variable_scope.variable( 0., @@ -1722,8 +1737,8 @@ class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase): experimental_run_tf_function): aggregations = [ variables_lib.VariableAggregation.SUM, - variables_lib.VariableAggregation.MEAN, - variables_lib.VariableAggregation.ONLY_FIRST_REPLICA, + # variables_lib.VariableAggregation.MEAN, + # variables_lib.VariableAggregation.ONLY_FIRST_REPLICA, ] for aggregation in aggregations: if isinstance(distribution, _TPU_STRATEGIES): diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index c08cb8cc1c3..adc30eab5e1 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -1,7 +1,7 @@ -load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") # buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension") load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test") load( @@ -432,6 +432,7 @@ cuda_py_test( srcs = ["function_test.py"], python_version = "PY3", shard_count = 15, + tags = ["nomac"], # b/157056289 deps = [ ":backprop", ":cancellation", diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index fb7c4055136..dc7bb7c4b11 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -241,6 +241,11 @@ def implicit_val_and_grad(f): "function was being computed.") sources = [v.handle for v in variables] + for s in sources: + if getattr(s, "is_packed", False): + raise ValueError( + "GradientTape.gradient is not supported on packed EagerTensors yet." + ) grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node), sources) return end_node, list(zip(grad, variables)) @@ -548,6 +553,10 @@ def make_vjp(f, params=None, persistent=True): ] args = _ensure_unique_tensor_objects(parameter_positions, args) for i in parameter_positions: + if getattr(args[i], "is_packed", False): + raise ValueError( + "GradientTape.gradient is not supported on packed EagerTensors" + "yet.") sources.append(args[i]) tape.watch(this_tape, args[i]) result = f(*args) @@ -873,7 +882,7 @@ class GradientTape(object): Raises: ValueError: if it encounters something that is not a tensor. """ - for t in nest.flatten(tensor): + for t in nest.flatten(tensor, expand_composites=True): if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)): raise ValueError("Passed in object of type {}, not tf.Tensor".format( type(t))) @@ -1032,6 +1041,10 @@ class GradientTape(object): logging.WARN, "The dtype of the source tensor must be " "floating (e.g. tf.float32) when calling GradientTape.gradient, " "got %r", t.dtype) + if getattr(t, "is_packed", False): + raise ValueError( + "GradientTape.gradient is not supported on packed EagerTensors yet." + ) if output_gradients is not None: output_gradients = [None if x is None else ops.convert_to_tensor(x) diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index b28aaa3a626..a0f98fc0a44 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util @@ -48,6 +49,7 @@ from tensorflow.python.ops import nn_grad # pylint: disable=unused-import from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variables from tensorflow.python.training import training @@ -1484,6 +1486,19 @@ class BackpropTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegexp(ValueError, 'ndarray'): g.watch(np.array(1.)) + def testWatchComposite(self): + """Test that tape.watch expands composites and watches component Tensors.""" + with backprop.GradientTape() as t: + values = constant_op.constant([1.0, 2.0], dtypes.float32) + s = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2]], + values=values, + dense_shape=[3, 4]) + t.watch(s) + z = sparse_ops.sparse_reduce_sum_v2(s) + result = t.gradient(z, values) + self.assertAllEqual(result, [1.0, 1.0]) + def testWatchedVariablesAfterNonPersistentGradientCall(self): with backprop.GradientTape(persistent=False) as tape: x = resource_variable_ops.ResourceVariable(1.0) diff --git a/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py b/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py index 9d049a6d59d..30e2585e842 100644 --- a/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py +++ b/tensorflow/python/eager/benchmarks/resnet50/resnet50_test.py @@ -104,24 +104,24 @@ class ResNet50Test(tf.test.TestCase): context.async_wait() self.assertEqual((2, 1000), output.shape) - @test_util.disable_tfrt('b/155260334') + @test_util.disable_tfrt('Flaky test. b/157103729') def test_apply(self): self._apply(defun=False) @test_util.disable_tfrt( - 'TFE_ContextGetExecutorForThread not implemented for tfrt') + 'TFE_ContextGetExecutorForThread not implemented b/156188669') def test_apply_async(self): self._apply(defun=False, execution_mode=context.ASYNC) - @test_util.disable_tfrt('Graph is not supported yet.') + @test_util.disable_tfrt('Graph is not supported yet. b/156187905') def test_apply_with_defun(self): self._apply(defun=True) - @test_util.disable_tfrt('Graph is not supported yet.') + @test_util.disable_tfrt('Graph is not supported yet. b/156187905') def test_apply_with_defun_async(self): self._apply(defun=True, execution_mode=context.ASYNC) - @test_util.disable_tfrt('b/155260334') + @test_util.disable_tfrt('Flaky test. b/157103729') def test_apply_no_top(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False) @@ -132,7 +132,7 @@ class ResNet50Test(tf.test.TestCase): if data_format == 'channels_first' else (2, 1, 1, 2048)) self.assertEqual(output_shape, output.shape) - @test_util.disable_tfrt('b/155260334') + @test_util.disable_tfrt('Flaky test. b/157103729') def test_apply_with_pooling(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False, pooling='avg') @@ -141,7 +141,7 @@ class ResNet50Test(tf.test.TestCase): output = model(images, training=False) self.assertEqual((2, 2048), output.shape) - @test_util.disable_tfrt('b/155260334') + @test_util.disable_tfrt('Flaky test. b/157103729') def test_apply_no_average_pooling(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50( @@ -153,7 +153,7 @@ class ResNet50Test(tf.test.TestCase): (2, 7, 7, 2048)) self.assertEqual(output_shape, output.shape) - @test_util.disable_tfrt('b/155260334') + @test_util.disable_tfrt('Flaky test. b/157103729') def test_apply_block3_strides(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50( @@ -165,7 +165,7 @@ class ResNet50Test(tf.test.TestCase): (2, 1, 1, 2048)) self.assertEqual(output_shape, output.shape) - @test_util.disable_tfrt('b/155260334') + @test_util.disable_tfrt('Flaky test. b/157103729') def test_apply_retrieve_intermediates(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50( @@ -220,15 +220,15 @@ class ResNet50Test(tf.test.TestCase): self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'loss') - @test_util.disable_tfrt('b/155260334') + @test_util.disable_tfrt('Flaky test. b/157103729') def test_train(self): self._test_train() - @test_util.disable_tfrt('b/155260334') + @test_util.disable_tfrt('TFE_ContextGetExecutorForThread missing b/156188669') def test_train_async(self): self._test_train(execution_mode=context.ASYNC) - @test_util.disable_tfrt('b/155260334') + @test_util.disable_tfrt('Flaky test. b/157103729') def test_no_garbage(self): device, data_format = resnet50_test_util.device_and_data_format() model = resnet50.ResNet50(data_format) @@ -337,7 +337,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): defun=False, execution_mode=context.ASYNC) - @test_util.disable_tfrt('Graph is not supported yet.') + @test_util.disable_tfrt('Graph is not supported yet. b/156187905') def benchmark_eager_apply_with_defun(self): self._benchmark_eager_apply( 'eager_apply_with_defun', @@ -397,7 +397,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): defun=False, execution_mode=context.ASYNC) - @test_util.disable_tfrt('Graph is not supported yet.') + @test_util.disable_tfrt('Graph is not supported yet. b/156187905') def benchmark_eager_train_with_defun(self): self._benchmark_eager_train( 'eager_train_with_defun', MockIterator, @@ -416,7 +416,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): resnet50_test_util.device_and_data_format(), defun=False) - @test_util.disable_tfrt('Graph is not supported yet.') + @test_util.disable_tfrt('Graph is not supported yet. b/156187905') def benchmark_eager_train_datasets_with_defun(self): def make_iterator(tensors): diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 227fca5ea6f..223b62ededa 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -120,6 +120,10 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._num_iters_2_by_2 = 30000 self._num_iters_100_by_784 = 30000 + # used for conv2d benchmarks + self._m_8_28_28_3 = random_ops.random_uniform((8, 28, 28, 3)) + self._m_1_3_3_1 = random_ops.random_uniform((1, 3, 3, 1)) + def _get_benchmark_name(self): """Mostly copied from benchmark.py _get_name().""" stack = tf_inspect.stack() @@ -305,6 +309,10 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): func = lambda: m * m self._run(func, num_iters) + def _benchmark_tf_conv2d(self, m1, m2, num_iters): + func = lambda: nn_ops.conv2d(m1, m2, strides=[1, 1, 1, 1], padding="VALID") + self._run(func, num_iters) + def _benchmark_tf_multiply_op(self, m, num_iters): func = lambda: math_ops.multiply(m, m) self._run(func, num_iters) @@ -339,6 +347,21 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): m = self._m_2.gpu() self._benchmark_tf_multiply_op(m, 30000) + def benchmark_tf_conv2d_CPU(self): + with context.device(CPU): + m1 = self._m_8_28_28_3.cpu() + m2 = self._m_1_3_3_1.cpu() + self._benchmark_tf_conv2d(m1, m2, 30000) + + @test_util.disable_tfrt("copy to GPU not supported") + def benchmark_tf_conv2d_GPU(self): + if not context.num_gpus(): + return + with context.device(GPU): + m1 = self._m_8_28_28_3.gpu() + m2 = self._m_1_3_3_1.gpu() + self._benchmark_tf_conv2d(m1, m2, 30000) + def benchmark_tf_identity(self): m = self._m_2 self._run(lambda: gen_array_ops.identity(m), 30000) @@ -595,7 +618,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_tfe_py_execute_matmul( m, transpose_b=False, num_iters=self._num_iters_2_by_2) - @test_util.disable_tfrt("defun not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_defun_matmul_2_by_2_GPU(self): if not context.num_gpus(): return @@ -616,7 +639,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): num_iters=self._num_iters_2_by_2, execution_mode=context.ASYNC) - @test_util.disable_tfrt("function not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_nested_defun_matmul_2_by_2(self): m = self._m_2_by_2.cpu() self._benchmark_nested_defun_matmul( @@ -664,7 +687,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_tfe_py_execute_matmul( m, transpose_b=True, num_iters=self._num_iters_100_by_784) - @test_util.disable_tfrt("function not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_defun_matmul_100_by_784_CPU(self): with context.device(CPU): m = self._m_100_by_784.cpu() @@ -792,35 +815,35 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): func() self._run(func, 3000) - @test_util.disable_tfrt("defun not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_forwardprop_matmul_256_by_2096_CPU(self): self._benchmark_forwardprop_matmul_CPU(shape=(256, 2096)) - @test_util.disable_tfrt("defun not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_forwardprop_in_defun_matmul_256_by_2096_CPU(self): self._benchmark_forwardprop_in_defun_matmul_CPU(shape=(256, 2096)) - @test_util.disable_tfrt("defun not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_forwardprop_in_defun_of_defun_matmul_256_by_2096_CPU(self): self._benchmark_forwardprop_in_defun_of_defun_matmul_CPU(shape=(256, 2096)) - @test_util.disable_tfrt("defun not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_forwardprop_of_defun_matmul_256_by_2096_CPU(self): self._benchmark_forwardprop_of_defun_matmul_CPU(shape=(256, 2096)) - @test_util.disable_tfrt("defun not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_forwardprop_matmul_100_by_784_CPU(self): self._benchmark_forwardprop_matmul_CPU(shape=(100, 784)) - @test_util.disable_tfrt("defun not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_forwardprop_in_defun_matmul_100_by_784_CPU(self): self._benchmark_forwardprop_in_defun_matmul_CPU(shape=(100, 784)) - @test_util.disable_tfrt("defun not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_forwardprop_in_defun_of_defun_matmul_100_by_784_CPU(self): self._benchmark_forwardprop_in_defun_of_defun_matmul_CPU(shape=(100, 784)) - @test_util.disable_tfrt("defun not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_forwardprop_of_defun_matmul_100_by_784_CPU(self): self._benchmark_forwardprop_of_defun_matmul_CPU(shape=(100, 784)) @@ -1074,7 +1097,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): m = resource_variable_ops.ResourceVariable(self._m_2_by_2) self._benchmark_transpose(m, num_iters=self._num_iters_2_by_2) - @test_util.disable_tfrt("defun not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_defun_without_signature(self): def func(t1, t2, t3, t4, t5, t6, t7, t8): @@ -1086,7 +1109,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): cache_computation = lambda: defined(t, t, t, t, t, t, t, t) self._run(cache_computation, 30000) - @test_util.disable_tfrt("defun not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_defun_without_signature_and_with_kwargs(self): def func(t1, t2, t3, t4, t5, t6, t7, t8): @@ -1099,7 +1122,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): return defined(t1=t, t2=t, t3=t, t4=t, t5=t, t6=t, t7=t, t8=t) self._run(cache_computation, 30000) - @test_util.disable_tfrt("defun not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_defun_with_signature(self): def func(t1, t2, t3, t4, t5, t6, t7, t8): @@ -1112,7 +1135,7 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): signature_computation = lambda: defined(t, t, t, t, t, t, t, t) self._run(signature_computation, 30000) - @test_util.disable_tfrt("defun not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_defun_with_signature_and_kwargs(self): def func(t1, t2, t3, t4, t5, t6, t7, t8): @@ -1194,6 +1217,46 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._run(fn, 10000) + def _benchmark_convert_constant(self, value, cached): + global GLOBAL_TEST_VALUE + GLOBAL_TEST_VALUE = value + + def cached_func(): + ops.convert_to_tensor(value) + + def uncached_func(): + global GLOBAL_TEST_VALUE + GLOBAL_TEST_VALUE += 1 + ops.convert_to_tensor(GLOBAL_TEST_VALUE) + + func = cached_func if cached else uncached_func + + self._run(func, 10000) + + def benchmark_convert_python_int(self): + self._benchmark_convert_constant(42, cached=True) + + def benchmark_convert_python_int_uncached(self): + self._benchmark_convert_constant(42, cached=False) + + def benchmark_convert_python_float(self): + self._benchmark_convert_constant(42.0, cached=True) + + def benchmark_convert_python_float_uncached(self): + self._benchmark_convert_constant(42.0, cached=False) + + def benchmark_convert_numpy_int(self): + self._benchmark_convert_constant(np.array(42), cached=True) + + def benchmark_convert_numpy_int_uncached(self): + self._benchmark_convert_constant(np.array(42), cached=False) + + def benchmark_convert_numpy_float(self): + self._benchmark_convert_constant(np.array(42.0), cached=True) + + def benchmark_convert_numpy_float_uncached(self): + self._benchmark_convert_constant(np.array(42.0), cached=False) + @test_util.disable_tfrt("convert to tensor not supported") def benchmark_convert_3x_list_to_tensor(self): xs = [1, 2, 3] @@ -1242,11 +1305,11 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): resources.append(resource_variable_ops.ResourceVariable(self._m_2)) self._run(lambda: add_all(resources), num_iters) - @test_util.disable_tfrt("funtion not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmarkFunctionWithFiveResourceInputs(self): self._benchmarkFunctionWithResourceInputs(5, 1000) - @test_util.disable_tfrt("funtion not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmarkFunctionWithFiveHundredResourceInputs(self): self._benchmarkFunctionWithResourceInputs(500, 100) @@ -1281,15 +1344,15 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): with context.device(CPU): self._run(benchmark_fn, 10) - @test_util.disable_tfrt("funtion not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmarkTenThousandResourceReadsInCondInInnerFunc(self): self._benchmarkResourceReadsInCondInInnerFunc(10000) - @test_util.disable_tfrt("funtion not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmarkHundredResourceReadsInCondInInnerFunc(self): self._benchmarkResourceReadsInCondInInnerFunc(100) - @test_util.disable_tfrt("funtion not supported") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmarkTenResourceReadsInCondInInnerFunc(self): self._benchmarkResourceReadsInCondInInnerFunc(10) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 86b3d5cf95f..604a960afd5 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -1123,6 +1123,22 @@ class Context(object): pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule, device_name, device_info_capsule) + def pack_eager_tensors(self, tensors): + """Pack multiple `EagerTensor`s of the same dtype and shape. + + Args: + tensors: a list of EagerTensors to pack. + + Returns: + A packed EagerTensor. + """ + self.ensure_initialized() + if self._lazy_remote_inputs_copy is not None and ( + not self._lazy_remote_inputs_copy): + raise ValueError("Packing eager tensors is not supported when " + "lazy_remote_inputs_copy is disabled.") + return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors) + def remove_function(self, name): """Remove a function from the context. diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index 13b46491d9f..b63a3b434d4 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -290,6 +290,101 @@ class DefFunctionTest(test.TestCase): y = f(x) tape.gradient(y, x) + def testTensorListConcatV2(self): + + def f(x): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, size=2, element_shape=[3]) + ta = ta.write(0, 2 * x) + ta = ta.write(1, 3 * x) + return ta.concat() + + compiled_f = def_function.function(experimental_compile=True)(f) + + inputs = constant_op.constant([3.14, 2.68, 7.69]) + + self.assertAllClose([6.28, 5.36, 15.38, 9.42, 8.04, 23.07], f(inputs)) + + self.assertAllClose(compiled_f(inputs), f(inputs)) + + def testTensorListConcatV2Multidim(self): + + def f(x): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, size=2, element_shape=[3, 2]) + ta = ta.write(0, 2 * x) + ta = ta.write(1, 3 * x) + return ta.concat() + + compiled_f = def_function.function(experimental_compile=True)(f) + + inputs = constant_op.constant([[3.14, 21.1], [2.68, 22.2], [7.69, 23.3]]) + self.assertAllClose(f(inputs), compiled_f(inputs)) + + def testTensorListConcatV2Scalars(self): + + def f(x): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, size=2, element_shape=[1]) + ta = ta.write(0, 2 * x) + ta = ta.write(1, 3 * x) + return ta.concat() + + compiled_f = def_function.function(experimental_compile=True)(f) + inputs = constant_op.constant([3.14]) + self.assertAllClose(f(inputs), compiled_f(inputs)) + + def testTensorListConcatGrad(self): + + def f(x): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, size=2, element_shape=[3]) + ta = ta.write(0, 2 * x) + ta = ta.write(1, 3 * x) + return ta.concat() + + def g(): + x = constant_op.constant([3.14, 2.68, 7.69]) + with backprop.GradientTape() as tape: + tape.watch(x) + y = f(x) + return tape.gradient(y, x) + + compiled_g = def_function.function(experimental_compile=True)(g) + + self.assertAllClose([5.0, 5.0, 5.0], g()) + self.assertAllClose(compiled_g(), g()) + + def testTensorListConcatGradNestedCompile(self): + + @def_function.function(experimental_compile=True) + def f(x): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, size=2, element_shape=[3]) + ta = ta.write(0, 2 * x) + ta = ta.write(1, 3 * x) + return ta.concat() + + @def_function.function(experimental_compile=True) + def g(): + x = constant_op.constant([3.14, 2.68, 7.69]) + with backprop.GradientTape() as tape: + tape.watch(x) + y = f(x) + out = tape.gradient(y, x) + return out + + self.assertAllClose([5.0, 5.0, 5.0], g()) + + def testCumsum(self): + + @def_function.function(experimental_compile=True) + def f(x): + return math_ops.cumsum(x) + + f64_input = constant_op.constant([1.1, 2.2, 3.3], dtype=dtypes.float64) + self.assertAllClose([1.1, 3.3, 6.6], f(f64_input)) + if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py index 4ddba6b9be3..dd0bad30cb8 100644 --- a/tensorflow/python/eager/forwardprop_test.py +++ b/tensorflow/python/eager/forwardprop_test.py @@ -199,7 +199,6 @@ def _test_gradients(testcase, # And the symbolic computations should be much closer. testcase.assertAllClose(sym_jac_back, sym_jac_fwd) - class ForwardpropTest(test.TestCase, parameterized.TestCase): def testJVPFunction(self): @@ -361,14 +360,17 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase): _test_gradients(self, f, [constant_op.constant([1., 2.])], order=3) - @test_util.assert_no_new_pyobjects_executing_eagerly - def testCustomGradientRecomputeGrad(self): + # TODO(allenl): investigate why assert_no_new_pyobjects_executing_eagerly fails around this test? + def testExceptionCustomGradientRecomputeGradForward(self): @custom_gradient.recompute_grad def f(x): return math_ops.reduce_prod(math_ops.tanh(x)**2) - _test_gradients(self, f, [constant_op.constant([1.])], order=3) + with self.assertRaisesRegexp(NotImplementedError, + "recompute_grad tried to transpose"): + primals = [constant_op.constant([1.])] + sym_jac_fwd = _jacfwd(f, primals) def testExceptionInCustomGradientNotSwallowed(self): diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 97708f056c2..ce495d772d0 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -1831,9 +1831,9 @@ class ConcreteFunction(object): `args` and `kwargs`. """ return self._call_flat( - (t for t in nest.flatten((args, kwargs), expand_composites=True) + [t for t in nest.flatten((args, kwargs), expand_composites=True) if isinstance(t, (ops.Tensor, - resource_variable_ops.BaseResourceVariable))), + resource_variable_ops.BaseResourceVariable))], captured_inputs=self.captured_inputs, cancellation_manager=cancellation_manager) @@ -1854,7 +1854,6 @@ class ConcreteFunction(object): Raises: ValueError: If `args` contains anything other than Tensors or Variables. """ - args = list(args) ctx = context.context() executing_eagerly = ctx.executing_eagerly() diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 4e68f1460d9..078ca8b8878 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -186,6 +186,43 @@ class FunctionTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegexp(AttributeError, 'no attribute'): add(c) + def testPackedVariable(self): + with ops.device('/cpu:0'): + v0_0 = resource_variable_ops.ResourceVariable(1.0) + with ops.device('/cpu:1'): + v0_1 = resource_variable_ops.ResourceVariable(2.0) + v1_0 = resource_variable_ops.ResourceVariable(3.0) + with ops.device('/cpu:2'): + v1_1 = resource_variable_ops.ResourceVariable(4.0) + + packed_var_0 = ops.pack_eager_tensors([v0_0.handle, v0_1.handle]) + packed_var_1 = ops.pack_eager_tensors([v1_0.handle, v1_1.handle]) + + # TODO(b/145922293): use ResourceVariable.assign_add and + # ResourceVariable.read_value directly once we support packing multiple + # ResourceVariable into one ResourceVariable. + @def_function.function + def read_var(): + resource_variable_ops.assign_add_variable_op( + packed_var_0, constant_op.constant(5.0)) + resource_variable_ops.assign_add_variable_op( + packed_var_1, constant_op.constant(6.0)) + with ops.device('/cpu:0'): + read0 = resource_variable_ops.read_variable_op( + packed_var_0, dtype=dtypes.float32) + with ops.device('/cpu:1'): + read1 = resource_variable_ops.read_variable_op( + packed_var_0, dtype=dtypes.float32) + read2 = resource_variable_ops.read_variable_op( + packed_var_1, dtype=dtypes.float32) + with ops.device('/cpu:2'): + read3 = resource_variable_ops.read_variable_op( + packed_var_1, dtype=dtypes.float32) + + return read0, read1, read2, read3 + + self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6)) + def testImplementsAttributeBasic(self): v = def_function.function( experimental_implements='func')(lambda x, y: x + y) diff --git a/tensorflow/python/eager/gradient_input_output_exclusions.py b/tensorflow/python/eager/gradient_input_output_exclusions.py index 94962bf6135..442151f667e 100644 --- a/tensorflow/python/eager/gradient_input_output_exclusions.py +++ b/tensorflow/python/eager/gradient_input_output_exclusions.py @@ -253,7 +253,8 @@ def _live_tensors(f, attr_name="inputs"): # Not a number, assuming it can be anything. return _ALL subscript_val, = subscript.qn - if not isinstance(subscript_val, qual_names.NumberLiteral): + if (not isinstance(subscript_val, qual_names.Literal) and + not isinstance(subscript_val.value, int)): # Not a number, assuming it can be anything. return _ALL input_output_indices.add(subscript_val.value) diff --git a/tensorflow/python/eager/monitoring.py b/tensorflow/python/eager/monitoring.py index 26d4d8a55b3..74d98558192 100644 --- a/tensorflow/python/eager/monitoring.py +++ b/tensorflow/python/eager/monitoring.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function import collections +import functools +import time from tensorflow.core.framework import summary_pb2 from tensorflow.python import pywrap_tfe @@ -428,3 +430,46 @@ class Sampler(Metric): def get_cell(self, *labels): """Retrieves the cell.""" return SamplerCell(super(Sampler, self).get_cell(*labels)) + + +class MonitoredTimer(object): + """A context manager to measure the walltime and increment a Counter cell.""" + + def __init__(self, cell): + """Creates a new MonitoredTimer. + + Args: + cell: the cell associated with the time metric that will be inremented. + """ + self.cell = cell + + def __enter__(self): + self.t = time.time() + return self + + def __exit__(self, exception_type, exception_value, traceback): + del exception_type, exception_value, traceback + micro_seconds = (time.time() - self.t) * 1000000 + self.cell.increase_by(int(micro_seconds)) + + +def monitored_timer(cell): + """A function decorator for adding MonitoredTimer support. + + Arguments: + cell: the cell associated with the time metric that will be inremented. + Returns: + A decorator that measure the function runtime and increment the specified + counter cell. + """ + + def actual_decorator(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with MonitoredTimer(cell): + return func(*args, **kwargs) + + return wrapper + + return actual_decorator diff --git a/tensorflow/python/eager/monitoring_test.py b/tensorflow/python/eager/monitoring_test.py index 3f601735ef2..7cb8c0c2cd1 100644 --- a/tensorflow/python/eager/monitoring_test.py +++ b/tensorflow/python/eager/monitoring_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import time + from tensorflow.python.eager import monitoring from tensorflow.python.eager import test from tensorflow.python.framework import errors @@ -100,6 +102,26 @@ class MonitoringTest(test_util.TensorFlowTestCase): self.assertEqual(histogram_proto1.num, 2.0) self.assertEqual(histogram_proto1.sum, 6.0) + def test_context_manager(self): + counter = monitoring.Counter('test/ctxmgr', 'test context manager', 'slot') + with monitoring.MonitoredTimer(counter.get_cell('short')): + time.sleep(0.001) + with monitoring.MonitoredTimer(counter.get_cell('long')): + time.sleep(0.02) + self.assertGreater( + counter.get_cell('long').value(), + counter.get_cell('short').value()) + + def test_function_decorator(self): + counter = monitoring.Counter('test/funcdecorator', 'test func decorator') + + @monitoring.monitored_timer(counter.get_cell()) + def timed_function(seconds): + time.sleep(seconds) + + timed_function(0.001) + self.assertGreater(counter.get_cell().value(), 1000) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index a72f74b38b8..b209ddb6162 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -345,6 +345,8 @@ typedef struct EagerTensor { char unused[kMaxEagerTensorParentSize]; TFE_TensorHandle* handle; int64_t id; + // Indicates whether it's a packed tensor or not. + bool is_packed; // This mirrors tensorflow.core.framework.ops.Tensor._handle_data Which will // be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE // tensors, this will contain a serialized HandleData proto with shape @@ -418,6 +420,7 @@ bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) { int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { self->id = get_uid(); self->handle = nullptr; + self->is_packed = false; Py_INCREF(Py_None); self->handle_data = Py_None; Py_INCREF(Py_None); @@ -647,6 +650,11 @@ static PyObject* EagerTensor_backing_device(EagerTensor* self) { #endif } +// Getter `is_packed`. +static PyObject* EagerTensor_is_packed(EagerTensor* self) { + return PyBool_FromLong(self->is_packed); +} + static PyGetSetDef EagerTensor_getsetters[] = { {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr, const_cast<char*>("Tensor ID."), nullptr}, @@ -655,6 +663,9 @@ static PyGetSetDef EagerTensor_getsetters[] = { {const_cast<char*>("backing_device"), (getter)EagerTensor_backing_device, nullptr, const_cast<char*>("Device on which tensor's memory is resident."), nullptr}, + {const_cast<char*>("is_packed"), (getter)EagerTensor_is_packed, nullptr, + const_cast<char*>("Whether the EagerTensor is a packed tensor or not."), + nullptr}, {const_cast<char*>("_handle_data"), (getter)EagerTensor_handle_data, (setter)EagerTensor_sethandle_data, const_cast<char*>("Shape/DType data if the EagerTensor is a DT_RESOURCE"), @@ -813,7 +824,8 @@ TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) { return reinterpret_cast<const EagerTensor*>(o)->handle; } -PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) { +PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle, + const bool is_packed) { if (handle == nullptr) { return nullptr; } @@ -821,6 +833,7 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) { EagerTensorType->tp_new(EagerTensorType, EmptyTuple(), EmptyDict())); if (t != nullptr) { t->id = get_uid(); + t->is_packed = is_packed; Py_INCREF(Py_None); t->handle_data = Py_None; Py_INCREF(Py_None); diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index 92a0a200e3d..a5c9c181539 100755 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -129,7 +129,8 @@ void TFE_DeleteContextCapsule(PyObject* context); bool EagerTensor_CheckExact(const PyObject* o); // Helper function to construct a new EagerTensor from a TFE_TensorHandle. -PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle); +PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle, + const bool is_packed = false); // Extracts the handle inside EagerTensor object `o`. Returns nullptr on error. TFE_TensorHandle* EagerTensor_Handle(const PyObject* o); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 2d96ed57246..639f623bd1a 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -852,6 +852,8 @@ void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name, TFE_CancellationManager* cancellation_manager, TFE_OutputTensorHandles* outputs, TF_Status* out_status) { + tensorflow::profiler::TraceMe activity( + "TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo); TFE_Op* op = GetOp(ctx, op_name, device_name, out_status); auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); }); if (!out_status->status.ok()) return; diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py index 32fe6372f77..710e7bf5f9d 100644 --- a/tensorflow/python/eager/remote_test.py +++ b/tensorflow/python/eager/remote_test.py @@ -40,6 +40,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.training import server_lib from tensorflow.python.training.server_lib import ClusterSpec @@ -324,6 +325,36 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0]) + def testMultiDeviceFunctionWithPackedVariable(self): + with ops.device('/job:worker/replica:0/task:0/device:CPU:0'): + var0 = resource_variable_ops.ResourceVariable(1.0) + with ops.device('/job:worker/replica:0/task:1/device:CPU:0'): + var1 = resource_variable_ops.ResourceVariable(2.0) + + packed_var = ops.pack_eager_tensors([var0.handle, var1.handle]) + self.assertEqual(packed_var.device, + '/job:localhost/replica:0/task:0/device:COMPOSITE:0') + self.assertEqual(packed_var.backing_device, + '/job:localhost/replica:0/task:0/device:COMPOSITE:0') + + @def_function.function + def add_variables(): + with ops.device('/job:worker/replica:0/task:0/device:CPU:0'): + read0 = resource_variable_ops.read_variable_op( + packed_var, dtype=dtypes.float32) + with ops.device('/job:worker/replica:0/task:1/device:CPU:0'): + read1 = resource_variable_ops.read_variable_op( + packed_var, dtype=dtypes.float32) + + return read0 + read1 + + # Run the function on a remote device + with ops.device('/job:worker/replica:0/task:0'): + self.assertAllEqual(add_variables().numpy(), 3.0) + + # Run the function on a local worker + self.assertAllEqual(add_variables().numpy(), 3.0) + @test_util.eager_lazy_remote_copy_on_and_off def testMultiDeviceFunctionOnRemoteDeviceWithWait(self): with ops.device('/job:worker/replica:0/task:1'): diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD index d67cdf9cc06..786c26c009a 100644 --- a/tensorflow/python/feature_column/BUILD +++ b/tensorflow/python/feature_column/BUILD @@ -55,8 +55,6 @@ py_library( py_library( name = "feature_column_v2", srcs = [ - "dense_features.py", - "dense_features_v2.py", "feature_column_v2.py", "sequence_feature_column.py", "serialization.py", @@ -126,15 +124,6 @@ tf_py_test( ], ) -tf_py_test( - name = "dense_features_test", - srcs = ["dense_features_test.py"], - tags = ["no_pip"], - deps = [ - ":feature_column_test_main_lib", - ], -) - py_library( name = "feature_column_test_main_lib", srcs = ["feature_column_test.py"], @@ -177,15 +166,6 @@ tf_py_test( deps = [":feature_column_v2_test_main_lib"], ) -tf_py_test( - name = "dense_features_v2_test", - srcs = ["dense_features_v2_test.py"], - tags = ["no_pip"], - deps = [ - ":feature_column_v2_test_main_lib", - ], -) - py_library( name = "feature_column_v2_test_main_lib", srcs = ["feature_column_v2_test.py"], diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 87420d0e850..07df4e914c9 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -2546,7 +2546,7 @@ class _EmbeddingColumn( embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and sparse_id_rank <= 2): - embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse + embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2 # Return embedding lookup result. return embedding_lookup_sparse( embedding_weights, @@ -2696,7 +2696,7 @@ class _SharedEmbeddingColumn( embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and sparse_id_rank <= 2): - embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse + embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2 # Return embedding lookup result. return embedding_lookup_sparse( embedding_weights, diff --git a/tensorflow/python/feature_column/feature_column_lib.py b/tensorflow/python/feature_column/feature_column_lib.py index afe14f55bfc..bda20ff3f2c 100644 --- a/tensorflow/python/feature_column/feature_column_lib.py +++ b/tensorflow/python/feature_column/feature_column_lib.py @@ -19,13 +19,13 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long,wildcard-import,g-bad-import-order -# We import dense_features_v2 first so that the V1 DenseFeatures is the default -# if users directly import feature_column_lib. -from tensorflow.python.feature_column.dense_features_v2 import * -from tensorflow.python.feature_column.dense_features import * from tensorflow.python.feature_column.feature_column import * from tensorflow.python.feature_column.feature_column_v2 import * from tensorflow.python.feature_column.sequence_feature_column import * from tensorflow.python.feature_column.serialization import * +# We import dense_features_v2 first so that the V1 DenseFeatures is the default +# if users directly import feature_column_lib. +from tensorflow.python.keras.feature_column.dense_features_v2 import * +from tensorflow.python.keras.feature_column.dense_features import * from tensorflow.python.keras.feature_column.sequence_feature_column import * # pylint: enable=unused-import,line-too-long diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index 21def9cfa2c..38800fc2162 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections import copy +from absl.testing import parameterized import numpy as np from tensorflow.core.example import example_pb2 @@ -852,9 +853,9 @@ class HashedCategoricalColumnTest(test.TestCase): 'aaa': inputs }), weight_collections=('my_weights',)) - self.assertItemsEqual( - [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) - self.assertItemsEqual([], ops.get_collection('my_weights')) + self.assertCountEqual([], + ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) + self.assertCountEqual([], ops.get_collection('my_weights')) @test_util.run_deprecated_v1 def test_get_sparse_tensors_dense_input(self): @@ -1714,10 +1715,10 @@ class LinearModelTest(test.TestCase): # We check the mapping by checking that we have the right keys, # and that the values (output_tensors) were indeed the ones used to # form the input layer. - self.assertItemsEqual(all_cols, cols_to_output_tensors.keys()) + self.assertCountEqual(all_cols, cols_to_output_tensors.keys()) input_layer_inputs = [tensor for tensor in input_layer.op.inputs[:-1]] output_tensors = [tensor for tensor in cols_to_output_tensors.values()] - self.assertItemsEqual(input_layer_inputs, output_tensors) + self.assertCountEqual(input_layer_inputs, output_tensors) def test_dense_collection(self): price = fc._numeric_column('price') @@ -2841,7 +2842,7 @@ class FunctionalInputLayerTest(test.TestCase): cols_to_vars = {} all_cols = [price1, dense_feature_bucketized, some_embedding_column] fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars) - self.assertItemsEqual(list(cols_to_vars.keys()), all_cols) + self.assertCountEqual(list(cols_to_vars.keys()), all_cols) self.assertEqual(0, len(cols_to_vars[price1])) self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized])) self.assertEqual(1, len(cols_to_vars[some_embedding_column])) @@ -2891,7 +2892,7 @@ class FunctionalInputLayerTest(test.TestCase): shared_embedding_a, shared_embedding_b ] fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars) - self.assertItemsEqual(list(cols_to_vars.keys()), all_cols) + self.assertCountEqual(list(cols_to_vars.keys()), all_cols) self.assertEqual(0, len(cols_to_vars[price1])) self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized])) self.assertEqual(1, len(cols_to_vars[some_embedding_column])) @@ -2927,7 +2928,7 @@ class FunctionalInputLayerTest(test.TestCase): 'input_from_feature_columns', partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0)): fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars) - self.assertItemsEqual(list(cols_to_vars.keys()), all_cols) + self.assertCountEqual(list(cols_to_vars.keys()), all_cols) self.assertEqual(0, len(cols_to_vars[price1])) self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized])) self.assertEqual(3, len(cols_to_vars[some_embedding_column])) @@ -3043,7 +3044,7 @@ class FunctionalInputLayerTest(test.TestCase): 'input_layer/sparse_feature_embedding/embedding_weights:0', 'input_layer_1/sparse_feature_embedding/embedding_weights:0' ] - self.assertItemsEqual( + self.assertCountEqual( expected_var_names, [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) @@ -3077,7 +3078,7 @@ class FunctionalInputLayerTest(test.TestCase): # Make sure that only 1 variable gets created in this case. self.assertEqual(1, len( ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) - self.assertItemsEqual( + self.assertCountEqual( ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) @@ -3129,7 +3130,7 @@ class FunctionalInputLayerTest(test.TestCase): # Make sure that only 1 variable gets created in this case. self.assertEqual(1, len( ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) - self.assertItemsEqual( + self.assertCountEqual( ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) @@ -3618,9 +3619,9 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): 'aaa': inputs }), weight_collections=('my_weights',)) - self.assertItemsEqual( - [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) - self.assertItemsEqual([], ops.get_collection('my_weights')) + self.assertCountEqual([], + ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) + self.assertCountEqual([], ops.get_collection('my_weights')) @test_util.run_deprecated_v1 def test_get_sparse_tensors_dense_input(self): @@ -4058,9 +4059,9 @@ class VocabularyListCategoricalColumnTest(test.TestCase): 'aaa': inputs }), weight_collections=('my_weights',)) - self.assertItemsEqual( - [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) - self.assertItemsEqual([], ops.get_collection('my_weights')) + self.assertCountEqual([], + ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) + self.assertCountEqual([], ops.get_collection('my_weights')) @test_util.run_deprecated_v1 def test_get_sparse_tensors_dense_input(self): @@ -4363,9 +4364,9 @@ class IdentityCategoricalColumnTest(test.TestCase): 'aaa': inputs }), weight_collections=('my_weights',)) - self.assertItemsEqual( - [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) - self.assertItemsEqual([], ops.get_collection('my_weights')) + self.assertCountEqual([], + ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) + self.assertCountEqual([], ops.get_collection('my_weights')) @test_util.run_deprecated_v1 def test_get_sparse_tensors_dense_input(self): @@ -4820,7 +4821,7 @@ class IndicatorColumnTest(test.TestCase): self.assertAllClose([[0., 1., 1., 0.]], self.evaluate(net)) -class EmbeddingColumnTest(test.TestCase): +class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): @test_util.run_deprecated_v1 def test_defaults(self): @@ -4956,10 +4957,29 @@ class EmbeddingColumnTest(test.TestCase): _assert_sparse_tensor_value(self, self.evaluate(output_a), self.evaluate(output_embedded)) + @parameterized.named_parameters( + { + 'testcase_name': 'use_safe_embedding_lookup', + 'use_safe_embedding_lookup': True, + 'partition_variables': False, + }, { + 'testcase_name': 'dont_use_safe_embedding_lookup', + 'use_safe_embedding_lookup': False, + 'partition_variables': False, + }, { + 'testcase_name': 'use_safe_embedding_lookup_partitioned', + 'use_safe_embedding_lookup': True, + 'partition_variables': True, + }, { + 'testcase_name': 'dont_use_safe_embedding_lookup_partitioned', + 'use_safe_embedding_lookup': False, + 'partition_variables': True, + }) @test_util.run_deprecated_v1 - def test_get_dense_tensor(self): + def test_get_dense_tensor(self, use_safe_embedding_lookup, + partition_variables): # Inputs. - vocabulary_size = 3 + vocabulary_size = 4 sparse_input = sparse_tensor.SparseTensorValue( # example 0, ids [2] # example 1, ids [0, 1] @@ -4974,12 +4994,20 @@ class EmbeddingColumnTest(test.TestCase): embedding_values = ( (1., 2.), # id 0 (3., 5.), # id 1 - (7., 11.) # id 2 + (7., 11.), # id 2 + (9., 13.) # id 3 ) - def _initializer(shape, dtype, partition_info): - self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + + def _initializer(shape, dtype, partition_info=None): + if partition_variables: + self.assertEqual([vocabulary_size, embedding_dimension], + partition_info.full_shape) + self.assertAllEqual((2, embedding_dimension), shape) + else: + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertIsNone(partition_info) + self.assertEqual(dtypes.float32, dtype) - self.assertIsNone(partition_info) return embedding_values # Expected lookup result, using combiner='mean'. @@ -4997,25 +5025,43 @@ class EmbeddingColumnTest(test.TestCase): # Build columns. categorical_column = fc._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) - embedding_column = fc._embedding_column( - categorical_column, - dimension=embedding_dimension, - initializer=_initializer) + partitioner = None + if partition_variables: + partitioner = partitioned_variables.fixed_size_partitioner(2, axis=0) + with variable_scope.variable_scope('vars', partitioner=partitioner): + embedding_column = fc._embedding_column( + categorical_column, + dimension=embedding_dimension, + initializer=_initializer, + use_safe_embedding_lookup=use_safe_embedding_lookup) - # Provide sparse input and get dense result. - embedding_lookup = embedding_column._get_dense_tensor( - _LazyBuilder({ - 'aaa': sparse_input - })) + # Provide sparse input and get dense result. + embedding_lookup = embedding_column._get_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('embedding_weights:0',), - tuple([v.name for v in global_vars])) + if partition_variables: + self.assertCountEqual(('vars/embedding_weights/part_0:0', + 'vars/embedding_weights/part_1:0'), + tuple([v.name for v in global_vars])) + else: + self.assertCountEqual(('vars/embedding_weights:0',), + tuple([v.name for v in global_vars])) + for v in global_vars: + self.assertIsInstance(v, variables_lib.Variable) with _initialized_session(): self.assertAllEqual(embedding_values, global_vars[0].eval()) self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup)) + if use_safe_embedding_lookup: + self.assertIn('SparseFillEmptyRows', + [x.type for x in ops.get_default_graph().get_operations()]) + else: + self.assertNotIn( + 'SparseFillEmptyRows', + [x.type for x in ops.get_default_graph().get_operations()]) + @test_util.run_deprecated_v1 def test_get_dense_tensor_3d(self): # Inputs. @@ -5072,7 +5118,7 @@ class EmbeddingColumnTest(test.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('embedding_weights:0',), + self.assertCountEqual(('embedding_weights:0',), tuple([v.name for v in global_vars])) with _initialized_session(): self.assertAllEqual(embedding_values, global_vars[0].eval()) @@ -5102,11 +5148,11 @@ class EmbeddingColumnTest(test.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('embedding_weights:0',), + self.assertCountEqual(('embedding_weights:0',), tuple([v.name for v in global_vars])) my_vars = ops.get_collection('my_vars') - self.assertItemsEqual( - ('embedding_weights:0',), tuple([v.name for v in my_vars])) + self.assertCountEqual(('embedding_weights:0',), + tuple([v.name for v in my_vars])) @test_util.run_deprecated_v1 def test_get_dense_tensor_placeholder_inputs(self): @@ -5169,8 +5215,8 @@ class EmbeddingColumnTest(test.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( - ('embedding_weights:0',), tuple([v.name for v in global_vars])) + self.assertCountEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) with _initialized_session(): self.assertAllEqual(embedding_values, global_vars[0].eval()) self.assertAllEqual(expected_lookups, embedding_lookup.eval( @@ -5233,8 +5279,8 @@ class EmbeddingColumnTest(test.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( - ('embedding_weights:0',), tuple([v.name for v in global_vars])) + self.assertCountEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) with _initialized_session(): self.assertAllEqual(embedding_values, global_vars[0].eval()) self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup)) @@ -5280,14 +5326,14 @@ class EmbeddingColumnTest(test.TestCase): 'linear_model/aaa_embedding/weights:0', 'linear_model/aaa_embedding/embedding_weights:0', ) - self.assertItemsEqual( + self.assertCountEqual( expected_var_names, [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) trainable_vars = { v.name: v for v in ops.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES) } - self.assertItemsEqual(expected_var_names, trainable_vars.keys()) + self.assertCountEqual(expected_var_names, trainable_vars.keys()) bias = trainable_vars['linear_model/bias_weights:0'] embedding_weights = trainable_vars[ 'linear_model/aaa_embedding/embedding_weights:0'] @@ -5361,14 +5407,14 @@ class EmbeddingColumnTest(test.TestCase): 'linear_model/aaa_embedding/weights:0', 'linear_model/aaa_embedding/embedding_weights:0', ) - self.assertItemsEqual( + self.assertCountEqual( expected_var_names, [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) trainable_vars = { v.name: v for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) } - self.assertItemsEqual(expected_var_names, trainable_vars.keys()) + self.assertCountEqual(expected_var_names, trainable_vars.keys()) bias = trainable_vars['linear_model/bias_weights:0'] embedding_weights = trainable_vars[ 'linear_model/aaa_embedding/embedding_weights:0'] @@ -5450,13 +5496,11 @@ class EmbeddingColumnTest(test.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( - ('input_layer/aaa_embedding/embedding_weights:0',), - tuple([v.name for v in global_vars])) + self.assertCountEqual(('input_layer/aaa_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - self.assertItemsEqual( - ('input_layer/aaa_embedding/embedding_weights:0',), - tuple([v.name for v in trainable_vars])) + self.assertCountEqual(('input_layer/aaa_embedding/embedding_weights:0',), + tuple([v.name for v in trainable_vars])) with _initialized_session(): self.assertAllEqual(embedding_values, trainable_vars[0].eval()) self.assertAllEqual(expected_lookups, self.evaluate(input_layer)) @@ -5513,17 +5557,16 @@ class EmbeddingColumnTest(test.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( - ('input_layer/aaa_embedding/embedding_weights:0',), - tuple([v.name for v in global_vars])) - self.assertItemsEqual( - [], ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) + self.assertCountEqual(('input_layer/aaa_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) + self.assertCountEqual([], + ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) with _initialized_session(): self.assertAllEqual(embedding_values, global_vars[0].eval()) self.assertAllEqual(expected_lookups, self.evaluate(input_layer)) -class SharedEmbeddingColumnTest(test.TestCase): +class SharedEmbeddingColumnTest(test.TestCase, parameterized.TestCase): @test_util.run_deprecated_v1 def test_defaults(self): @@ -5772,33 +5815,59 @@ class SharedEmbeddingColumnTest(test.TestCase): _assert_sparse_tensor_value(self, self.evaluate(output_b), self.evaluate(output_b_embedded)) + @parameterized.named_parameters( + { + 'testcase_name': 'use_safe_embedding_lookup', + 'use_safe_embedding_lookup': True, + 'partition_variables': False, + }, { + 'testcase_name': 'dont_use_safe_embedding_lookup', + 'use_safe_embedding_lookup': False, + 'partition_variables': False, + }, { + 'testcase_name': 'use_safe_embedding_lookup_partitioned', + 'use_safe_embedding_lookup': True, + 'partition_variables': True, + }, { + 'testcase_name': 'dont_use_safe_embedding_lookup_partitioned', + 'use_safe_embedding_lookup': False, + 'partition_variables': True, + }) @test_util.run_deprecated_v1 - def test_get_dense_tensor(self): + def test_get_dense_tensor(self, use_safe_embedding_lookup, + partition_variables): # Inputs. - vocabulary_size = 3 + vocabulary_size = 4 # -1 values are ignored. - input_a = np.array( - [[2, -1, -1], # example 0, ids [2] - [0, 1, -1]]) # example 1, ids [0, 1] - input_b = np.array( - [[0, -1, -1], # example 0, ids [0] - [-1, -1, -1]]) # example 1, ids [] - input_features = { - 'aaa': input_a, - 'bbb': input_b - } + input_a = np.array([ + [2, -1, -1], # example 0, ids [2] + [0, 1, -1] + ]) # example 1, ids [0, 1] + input_b = np.array([ + [0, -1, -1], # example 0, ids [0] + [-1, -1, -1] + ]) # example 1, ids [] + input_features = {'aaa': input_a, 'bbb': input_b} # Embedding variable. embedding_dimension = 2 embedding_values = ( (1., 2.), # id 0 (3., 5.), # id 1 - (7., 11.) # id 2 + (7., 11.), # id 2 + (9., 13.) # id 3 ) - def _initializer(shape, dtype, partition_info): - self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + + def _initializer(shape, dtype, partition_info=None): + if partition_variables: + self.assertEqual([vocabulary_size, embedding_dimension], + partition_info.full_shape) + self.assertAllEqual((2, embedding_dimension), shape) + else: + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertIsNone(partition_info) + self.assertEqual(dtypes.float32, dtype) - self.assertIsNone(partition_info) return embedding_values # Expected lookup result, using combiner='mean'. @@ -5808,38 +5877,65 @@ class SharedEmbeddingColumnTest(test.TestCase): # example 1: (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] ) - expected_lookups_b = ( - # example 0: - (1., 2.), # ids [0], embedding = [1, 2] - # example 1: - (0., 0.), # ids [], embedding = [0, 0] - ) + if use_safe_embedding_lookup: + expected_lookups_b = ( + # example 0: + (1., 2.), # ids [0], embedding = [1, 2] + # example 1: + (0., 0.), # ids [], embedding = [0, 0] + ) + else: + expected_lookups_b = ( + # example 0: + (1., 2.), # ids [0], embedding = [1, 2] + ) # Build columns. categorical_column_a = fc._categorical_column_with_identity( key='aaa', num_buckets=vocabulary_size) categorical_column_b = fc._categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_a, embedding_column_b = fc_new.shared_embedding_columns( - [categorical_column_a, categorical_column_b], - dimension=embedding_dimension, - initializer=_initializer) - # Provide sparse input and get dense result. - embedding_lookup_a = embedding_column_a._get_dense_tensor( - _LazyBuilder(input_features)) - embedding_lookup_b = embedding_column_b._get_dense_tensor( - _LazyBuilder(input_features)) + partitioner = None + if partition_variables: + partitioner = partitioned_variables.fixed_size_partitioner(2, axis=0) + with variable_scope.variable_scope('vars', partitioner=partitioner): + embedding_column_a, embedding_column_b = fc_new.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer, + use_safe_embedding_lookup=use_safe_embedding_lookup) + # Provide sparse input and get dense result. + embedding_lookup_a = embedding_column_a._get_dense_tensor( + _LazyBuilder(input_features)) + embedding_lookup_b = embedding_column_b._get_dense_tensor( + _LazyBuilder(input_features)) # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('embedding_weights:0',), - tuple([v.name for v in global_vars])) + if partition_variables: + self.assertCountEqual(('vars/embedding_weights/part_0:0', + 'vars/embedding_weights/part_1:0'), + tuple([v.name for v in global_vars])) + else: + self.assertCountEqual(('vars/embedding_weights:0',), + tuple([v.name for v in global_vars])) embedding_var = global_vars[0] - with _initialized_session(): - self.assertAllEqual(embedding_values, self.evaluate(embedding_var)) - self.assertAllEqual(expected_lookups_a, self.evaluate(embedding_lookup_a)) - self.assertAllEqual(expected_lookups_b, self.evaluate(embedding_lookup_b)) + + self.evaluate(variables_lib.global_variables_initializer()) + self.evaluate(lookup_ops.tables_initializer()) + + self.assertAllEqual(embedding_values, self.evaluate(embedding_var)) + self.assertAllEqual(expected_lookups_a, self.evaluate(embedding_lookup_a)) + self.assertAllEqual(expected_lookups_b, self.evaluate(embedding_lookup_b)) + + if use_safe_embedding_lookup: + self.assertIn('SparseFillEmptyRows', + [x.type for x in ops.get_default_graph().get_operations()]) + else: + self.assertNotIn( + 'SparseFillEmptyRows', + [x.type for x in ops.get_default_graph().get_operations()]) @test_util.run_deprecated_v1 def test_get_dense_tensor_weight_collections(self): @@ -5886,11 +5982,11 @@ class SharedEmbeddingColumnTest(test.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( + self.assertCountEqual( ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), tuple(v.name for v in global_vars)) my_vars = ops.get_collection('my_vars') - self.assertItemsEqual( + self.assertCountEqual( ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), tuple(v.name for v in my_vars)) @@ -5997,14 +6093,14 @@ class SharedEmbeddingColumnTest(test.TestCase): 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0', 'linear_model/aaa_bbb_shared_embedding_1/weights:0', ) - self.assertItemsEqual( + self.assertCountEqual( expected_var_names, [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) trainable_vars = { v.name: v for v in ops.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES) } - self.assertItemsEqual(expected_var_names, trainable_vars.keys()) + self.assertCountEqual(expected_var_names, trainable_vars.keys()) bias = trainable_vars['linear_model/bias_weights:0'] embedding_weights = trainable_vars[ 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0'] @@ -6091,14 +6187,14 @@ class SharedEmbeddingColumnTest(test.TestCase): 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0', 'linear_model/aaa_bbb_shared_embedding_1/weights:0', ) - self.assertItemsEqual( + self.assertCountEqual( expected_var_names, [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) trainable_vars = { v.name: v for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) } - self.assertItemsEqual(expected_var_names, trainable_vars.keys()) + self.assertCountEqual(expected_var_names, trainable_vars.keys()) bias = trainable_vars['linear_model/bias_weights:0'] embedding_weights = trainable_vars[ 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0'] @@ -6195,16 +6291,16 @@ class SharedEmbeddingColumnTest(test.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( + self.assertCountEqual( ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], tuple([v.name for v in global_vars])) trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) if trainable: - self.assertItemsEqual( + self.assertCountEqual( ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'], tuple([v.name for v in trainable_vars])) else: - self.assertItemsEqual([], tuple([v.name for v in trainable_vars])) + self.assertCountEqual([], tuple([v.name for v in trainable_vars])) shared_embedding_vars = global_vars with _initialized_session(): self.assertAllEqual(embedding_values, shared_embedding_vars[0].eval()) diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py index 23a9861eb1b..a03e4da0fae 100644 --- a/tensorflow/python/feature_column/feature_column_v2.py +++ b/tensorflow/python/feature_column/feature_column_v2.py @@ -145,8 +145,6 @@ from tensorflow.python.framework import tensor_shape # TODO(b/118385027): Dependency on keras can be problematic if Keras moves out # of the main repo. from tensorflow.python.keras import initializers -from tensorflow.python.keras.engine import training as keras_training -from tensorflow.python.keras.engine.base_layer import Layer from tensorflow.python.keras.utils import generic_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -154,7 +152,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops @@ -383,376 +380,6 @@ class _StateManagerImplV2(_StateManagerImpl): return var -class _BaseFeaturesLayer(Layer): - """Base class for DenseFeatures and SequenceFeatures. - - Defines common methods and helpers. - - Args: - feature_columns: An iterable containing the FeatureColumns to use as - inputs to your model. - expected_column_type: Expected class for provided feature columns. - trainable: Boolean, whether the layer's variables will be updated via - gradient descent during training. - name: Name to give to the DenseFeatures. - **kwargs: Keyword arguments to construct a layer. - - Raises: - ValueError: if an item in `feature_columns` doesn't match - `expected_column_type`. - """ - - def __init__(self, - feature_columns, - expected_column_type, - trainable, - name, - partitioner=None, - **kwargs): - super(_BaseFeaturesLayer, self).__init__( - name=name, trainable=trainable, **kwargs) - self._feature_columns = _normalize_feature_columns(feature_columns) - self._state_manager = _StateManagerImpl(self, self.trainable) - self._partitioner = partitioner - for column in self._feature_columns: - if not isinstance(column, expected_column_type): - raise ValueError( - 'Items of feature_columns must be a {}. ' - 'You can wrap a categorical column with an ' - 'embedding_column or indicator_column. Given: {}'.format( - expected_column_type, column)) - - def build(self, _): - for column in self._feature_columns: - with variable_scope._pure_variable_scope( # pylint: disable=protected-access - self.name, - partitioner=self._partitioner): - with variable_scope._pure_variable_scope( # pylint: disable=protected-access - _sanitize_column_name_for_variable_scope(column.name)): - column.create_state(self._state_manager) - super(_BaseFeaturesLayer, self).build(None) - - def _output_shape(self, input_shape, num_elements): - """Computes expected output shape of the layer or a column's dense tensor. - - Args: - input_shape: Tensor or array with batch shape. - num_elements: Size of the last dimension of the output. - - Returns: - Tuple with output shape. - """ - raise NotImplementedError('Calling an abstract method.') - - def compute_output_shape(self, input_shape): - total_elements = 0 - for column in self._feature_columns: - total_elements += column.variable_shape.num_elements() - return self._target_shape(input_shape, total_elements) - - def _process_dense_tensor(self, column, tensor): - """Reshapes the dense tensor output of a column based on expected shape. - - Args: - column: A DenseColumn or SequenceDenseColumn object. - tensor: A dense tensor obtained from the same column. - - Returns: - Reshaped dense tensor.""" - num_elements = column.variable_shape.num_elements() - target_shape = self._target_shape(array_ops.shape(tensor), num_elements) - return array_ops.reshape(tensor, shape=target_shape) - - def _verify_and_concat_tensors(self, output_tensors): - """Verifies and concatenates the dense output of several columns.""" - _verify_static_batch_size_equality(output_tensors, self._feature_columns) - return array_ops.concat(output_tensors, -1) - - def get_config(self): - # Import here to avoid circular imports. - from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top - column_configs = serialization.serialize_feature_columns( - self._feature_columns) - config = {'feature_columns': column_configs} - config['partitioner'] = generic_utils.serialize_keras_object( - self._partitioner) - - base_config = super( # pylint: disable=bad-super-call - _BaseFeaturesLayer, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - @classmethod - def from_config(cls, config, custom_objects=None): - # Import here to avoid circular imports. - from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top - config_cp = config.copy() - config_cp['feature_columns'] = serialization.deserialize_feature_columns( - config['feature_columns'], custom_objects=custom_objects) - config_cp['partitioner'] = generic_utils.deserialize_keras_object( - config['partitioner'], custom_objects) - - return cls(**config_cp) - - -class _LinearModelLayer(Layer): - """Layer that contains logic for `LinearModel`.""" - - def __init__(self, - feature_columns, - units=1, - sparse_combiner='sum', - trainable=True, - name=None, - **kwargs): - super(_LinearModelLayer, self).__init__( - name=name, trainable=trainable, **kwargs) - - self._feature_columns = _normalize_feature_columns(feature_columns) - for column in self._feature_columns: - if not isinstance(column, (DenseColumn, CategoricalColumn)): - raise ValueError( - 'Items of feature_columns must be either a ' - 'DenseColumn or CategoricalColumn. Given: {}'.format(column)) - - self._units = units - self._sparse_combiner = sparse_combiner - - self._state_manager = _StateManagerImpl(self, self.trainable) - self.bias = None - - def build(self, _): - # We need variable scopes for now because we want the variable partitioning - # information to percolate down. We also use _pure_variable_scope's here - # since we want to open up a name_scope in the `call` method while creating - # the ops. - with variable_scope._pure_variable_scope(self.name): # pylint: disable=protected-access - for column in self._feature_columns: - with variable_scope._pure_variable_scope( # pylint: disable=protected-access - _sanitize_column_name_for_variable_scope(column.name)): - # Create the state for each feature column - column.create_state(self._state_manager) - - # Create a weight variable for each column. - if isinstance(column, CategoricalColumn): - first_dim = column.num_buckets - else: - first_dim = column.variable_shape.num_elements() - self._state_manager.create_variable( - column, - name='weights', - dtype=dtypes.float32, - shape=(first_dim, self._units), - initializer=initializers.zeros(), - trainable=self.trainable) - - # Create a bias variable. - self.bias = self.add_variable( - name='bias_weights', - dtype=dtypes.float32, - shape=[self._units], - initializer=initializers.zeros(), - trainable=self.trainable, - use_resource=True, - # TODO(rohanj): Get rid of this hack once we have a mechanism for - # specifying a default partitioner for an entire layer. In that case, - # the default getter for Layers should work. - getter=variable_scope.get_variable) - - super(_LinearModelLayer, self).build(None) - - def call(self, features): - if not isinstance(features, dict): - raise ValueError('We expected a dictionary here. Instead we got: {}' - .format(features)) - with ops.name_scope(self.name): - transformation_cache = FeatureTransformationCache(features) - weighted_sums = [] - for column in self._feature_columns: - with ops.name_scope( - _sanitize_column_name_for_variable_scope(column.name)): - # All the weights used in the linear model are owned by the state - # manager associated with this Linear Model. - weight_var = self._state_manager.get_variable(column, 'weights') - - weighted_sum = _create_weighted_sum( - column=column, - transformation_cache=transformation_cache, - state_manager=self._state_manager, - sparse_combiner=self._sparse_combiner, - weight_var=weight_var) - weighted_sums.append(weighted_sum) - - _verify_static_batch_size_equality(weighted_sums, self._feature_columns) - predictions_no_bias = math_ops.add_n( - weighted_sums, name='weighted_sum_no_bias') - predictions = nn_ops.bias_add( - predictions_no_bias, self.bias, name='weighted_sum') - return predictions - - def get_config(self): - # Import here to avoid circular imports. - from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top - column_configs = serialization.serialize_feature_columns( - self._feature_columns) - config = { - 'feature_columns': column_configs, - 'units': self._units, - 'sparse_combiner': self._sparse_combiner - } - - base_config = super( # pylint: disable=bad-super-call - _LinearModelLayer, self).get_config() - return dict(list(base_config.items()) + list(config.items())) - - @classmethod - def from_config(cls, config, custom_objects=None): - # Import here to avoid circular imports. - from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top - config_cp = config.copy() - columns = serialization.deserialize_feature_columns( - config_cp['feature_columns'], custom_objects=custom_objects) - - del config_cp['feature_columns'] - return cls(feature_columns=columns, **config_cp) - - -# TODO(tanzheny): Cleanup it with respect to Premade model b/132690565. -class LinearModel(keras_training.Model): - """Produces a linear prediction `Tensor` based on given `feature_columns`. - - This layer generates a weighted sum based on output dimension `units`. - Weighted sum refers to logits in classification problems. It refers to the - prediction itself for linear regression problems. - - Note on supported columns: `LinearLayer` treats categorical columns as - `indicator_column`s. To be specific, assume the input as `SparseTensor` looks - like: - - ```python - shape = [2, 2] - { - [0, 0]: "a" - [1, 0]: "b" - [1, 1]: "c" - } - ``` - `linear_model` assigns weights for the presence of "a", "b", "c' implicitly, - just like `indicator_column`, while `input_layer` explicitly requires wrapping - each of categorical columns with an `embedding_column` or an - `indicator_column`. - - Example of usage: - - ```python - price = numeric_column('price') - price_buckets = bucketized_column(price, boundaries=[0., 10., 100., 1000.]) - keywords = categorical_column_with_hash_bucket("keywords", 10K) - keywords_price = crossed_column('keywords', price_buckets, ...) - columns = [price_buckets, keywords, keywords_price ...] - linear_model = LinearLayer(columns) - - features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) - prediction = linear_model(features) - ``` - """ - - def __init__(self, - feature_columns, - units=1, - sparse_combiner='sum', - trainable=True, - name=None, - **kwargs): - """Constructs a LinearLayer. - - Args: - feature_columns: An iterable containing the FeatureColumns to use as - inputs to your model. All items should be instances of classes derived - from `_FeatureColumn`s. - units: An integer, dimensionality of the output space. Default value is 1. - sparse_combiner: A string specifying how to reduce if a categorical column - is multivalent. Except `numeric_column`, almost all columns passed to - `linear_model` are considered as categorical columns. It combines each - categorical column independently. Currently "mean", "sqrtn" and "sum" - are supported, with "sum" the default for linear model. "sqrtn" often - achieves good accuracy, in particular with bag-of-words columns. - * "sum": do not normalize features in the column - * "mean": do l1 normalization on features in the column - * "sqrtn": do l2 normalization on features in the column - For example, for two features represented as the categorical columns: - - ```python - # Feature 1 - - shape = [2, 2] - { - [0, 0]: "a" - [0, 1]: "b" - [1, 0]: "c" - } - - # Feature 2 - - shape = [2, 3] - { - [0, 0]: "d" - [1, 0]: "e" - [1, 1]: "f" - [1, 2]: "g" - } - ``` - - with `sparse_combiner` as "mean", the linear model outputs conceptually - are - ``` - y_0 = 1.0 / 2.0 * ( w_a + w_ b) + w_c + b_0 - y_1 = w_d + 1.0 / 3.0 * ( w_e + w_ f + w_g) + b_1 - ``` - where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight - assigned to the presence of `x` in the input features. - trainable: If `True` also add the variable to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - name: Name to give to the Linear Model. All variables and ops created will - be scoped by this name. - **kwargs: Keyword arguments to construct a layer. - - Raises: - ValueError: if an item in `feature_columns` is neither a `DenseColumn` - nor `CategoricalColumn`. - """ - - super(LinearModel, self).__init__(name=name, **kwargs) - self.layer = _LinearModelLayer( - feature_columns, - units, - sparse_combiner, - trainable, - name=self.name, - **kwargs) - - def call(self, features): - """Returns a `Tensor` the represents the predictions of a linear model. - - Args: - features: A mapping from key to tensors. `_FeatureColumn`s look up via - these keys. For example `numeric_column('price')` will look at 'price' - key in this dict. Values are `Tensor` or `SparseTensor` depending on - corresponding `_FeatureColumn`. - - Returns: - A `Tensor` which represents predictions/logits of a linear model. Its - shape is (batch_size, units) and its dtype is `float32`. - - Raises: - ValueError: If features are not a dictionary. - """ - return self.layer(features) - - @property - def bias(self): - return self.layer.bias - - def _transform_features_v2(features, feature_columns, state_manager): """Returns transformed features based on features columns passed in. @@ -3263,7 +2890,7 @@ class EmbeddingColumn( embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and sparse_id_rank <= 2): - embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse + embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2 # Return embedding lookup result. return embedding_lookup_sparse( embedding_weights, @@ -3558,7 +3185,7 @@ class SharedEmbeddingColumn( embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and sparse_id_rank <= 2): - embedding_lookup_sparse = (embedding_ops.embedding_lookup_sparse) + embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2 # Return embedding lookup result. return embedding_lookup_sparse( embedding_weights, diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py index fe769850fb0..91fb7eadb89 100644 --- a/tensorflow/python/feature_column/feature_column_v2_test.py +++ b/tensorflow/python/feature_column/feature_column_v2_test.py @@ -31,7 +31,6 @@ from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.feature_column import dense_features as df from tensorflow.python.feature_column import feature_column as fc_old from tensorflow.python.feature_column import feature_column_v2 as fc from tensorflow.python.feature_column import serialization @@ -49,7 +48,6 @@ from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test -from tensorflow.python.training import rmsprop def _initialized_session(config=None): @@ -440,36 +438,6 @@ class NumericColumnTest(test.TestCase): 'aaa', shape=[1, 2], default_value=np.array([[3., 2.]])) self.assertEqual(a.default_value, ((3., 2.),)) - @test_util.run_deprecated_v1 - def test_linear_model(self): - price = fc.numeric_column('price') - with ops.Graph().as_default(): - features = {'price': [[1.], [5.]]} - model = fc.LinearModel([price]) - predictions = model(features) - price_var, bias = model.variables - with _initialized_session() as sess: - self.assertAllClose([0.], self.evaluate(bias)) - self.assertAllClose([[0.]], self.evaluate(price_var)) - self.assertAllClose([[0.], [0.]], self.evaluate(predictions)) - sess.run(price_var.assign([[10.]])) - self.assertAllClose([[10.], [50.]], self.evaluate(predictions)) - - @test_util.run_deprecated_v1 - def test_linear_model_sanitizes_scope_names(self): - price = fc.numeric_column('price > 100') - with ops.Graph().as_default(): - features = {'price > 100': [[1.], [5.]]} - model = fc.LinearModel([price]) - predictions = model(features) - price_var, bias = model.variables - with _initialized_session() as sess: - self.assertAllClose([0.], self.evaluate(bias)) - self.assertAllClose([[0.]], self.evaluate(price_var)) - self.assertAllClose([[0.], [0.]], self.evaluate(predictions)) - sess.run(price_var.assign([[10.]])) - self.assertAllClose([[10.], [50.]], self.evaluate(predictions)) - def test_old_linear_model(self): price = fc.numeric_column('price') with ops.Graph().as_default(): @@ -706,63 +674,6 @@ class BucketizedColumnTest(test.TestCase): self.assertAllEqual(a_bucketized_copy.variable_shape, (2, 3)) self.assertEqual(a_bucketized_copy.boundaries, (0, 1)) - def test_linear_model_one_input_value(self): - """Tests linear_model() for input with shape=[1].""" - price = fc.numeric_column('price', shape=[1]) - bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6]) - with ops.Graph().as_default(): - features = {'price': [[-1.], [1.], [5.], [6.]]} - model = fc.LinearModel([bucketized_price]) - predictions = model(features) - bucketized_price_var, bias = model.variables - with _initialized_session() as sess: - self.assertAllClose([0.], self.evaluate(bias)) - # One weight variable per bucket, all initialized to zero. - self.assertAllClose([[0.], [0.], [0.], [0.], [0.]], - self.evaluate(bucketized_price_var)) - self.assertAllClose([[0.], [0.], [0.], [0.]], - self.evaluate(predictions)) - sess.run( - bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.]])) - # price -1. is in the 0th bucket, whose weight is 10. - # price 1. is in the 1st bucket, whose weight is 20. - # price 5. is in the 3rd bucket, whose weight is 40. - # price 6. is in the 4th bucket, whose weight is 50. - self.assertAllClose([[10.], [20.], [40.], [50.]], - self.evaluate(predictions)) - sess.run(bias.assign([1.])) - self.assertAllClose([[11.], [21.], [41.], [51.]], - self.evaluate(predictions)) - - def test_linear_model_two_input_values(self): - """Tests linear_model() for input with shape=[2].""" - price = fc.numeric_column('price', shape=[2]) - bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6]) - with ops.Graph().as_default(): - features = {'price': [[-1., 1.], [5., 6.]]} - model = fc.LinearModel([bucketized_price]) - predictions = model(features) - bucketized_price_var, bias = model.variables - with _initialized_session() as sess: - self.assertAllClose([0.], self.evaluate(bias)) - # One weight per bucket per input column, all initialized to zero. - self.assertAllClose( - [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]], - self.evaluate(bucketized_price_var)) - self.assertAllClose([[0.], [0.]], self.evaluate(predictions)) - sess.run( - bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.], - [60.], [70.], [80.], [90.], [100.]])) - # 1st example: - # price -1. is in the 0th bucket, whose weight is 10. - # price 1. is in the 6th bucket, whose weight is 70. - # 2nd example: - # price 5. is in the 3rd bucket, whose weight is 40. - # price 6. is in the 9th bucket, whose weight is 100. - self.assertAllClose([[80.], [140.]], self.evaluate(predictions)) - sess.run(bias.assign([1.])) - self.assertAllClose([[81.], [141.]], self.evaluate(predictions)) - def test_old_linear_model_one_input_value(self): """Tests linear_model() for input with shape=[1].""" price = fc.numeric_column('price', shape=[1]) @@ -1071,32 +982,6 @@ class HashedCategoricalColumnTest(test.TestCase): self.assertEqual( transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor) - @test_util.run_deprecated_v1 - def test_linear_model(self): - wire_column = fc.categorical_column_with_hash_bucket('wire', 4) - self.assertEqual(4, wire_column.num_buckets) - with ops.Graph().as_default(): - model = fc.LinearModel((wire_column,)) - predictions = model({ - wire_column.name: - sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=('marlo', 'skywalker', 'omar'), - dense_shape=(2, 2)) - }) - wire_var, bias = model.variables - - self.evaluate(variables_lib.global_variables_initializer()) - self.evaluate(lookup_ops.tables_initializer()) - - self.assertAllClose((0.,), self.evaluate(bias)) - self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), self.evaluate(wire_var)) - self.assertAllClose(((0.,), (0.,)), self.evaluate(predictions)) - self.evaluate(wire_var.assign(((1.,), (2.,), (3.,), (4.,)))) - # 'marlo' -> 3: wire_var[3] = 4 - # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6 - self.assertAllClose(((4.,), (6.,)), self.evaluate(predictions)) - def test_old_linear_model(self): wire_column = fc.categorical_column_with_hash_bucket('wire', 4) self.assertEqual(4, wire_column.num_buckets) @@ -1365,101 +1250,6 @@ class CrossedColumnTest(test.TestCase): self.assertAllEqual(expected_values, id_tensor_eval.values) self.assertAllEqual((2, 4), id_tensor_eval.dense_shape) - @test_util.run_deprecated_v1 - def test_linear_model(self): - """Tests linear_model. - - Uses data from test_get_sparse_tensors_simple. - """ - a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,)) - b = fc.bucketized_column(a, boundaries=(0, 1)) - crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5) - with ops.Graph().as_default(): - model = fc.LinearModel((crossed,)) - predictions = model({ - 'a': - constant_op.constant(((-1., .5), (.5, 1.))), - 'c': - sparse_tensor.SparseTensor( - indices=((0, 0), (1, 0), (1, 1)), - values=['cA', 'cB', 'cC'], - dense_shape=(2, 2)), - }) - crossed_var, bias = model.variables - with _initialized_session() as sess: - self.assertAllClose((0.,), self.evaluate(bias)) - self.assertAllClose(((0.,), (0.,), (0.,), (0.,), (0.,)), - self.evaluate(crossed_var)) - self.assertAllClose(((0.,), (0.,)), self.evaluate(predictions)) - sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,)))) - # Expected ids after cross = (1, 0, 1, 3, 4, 2) - self.assertAllClose(((3.,), (14.,)), self.evaluate(predictions)) - sess.run(bias.assign((.1,))) - self.assertAllClose(((3.1,), (14.1,)), self.evaluate(predictions)) - - def test_linear_model_with_weights(self): - - class _TestColumnWithWeights(BaseFeatureColumnForTests, - fc.CategoricalColumn): - """Produces sparse IDs and sparse weights.""" - - @property - def _is_v2_column(self): - return True - - @property - def name(self): - return 'test_column' - - @property - def parse_example_spec(self): - return { - self.name: - parsing_ops.VarLenFeature(dtypes.int32), - '{}_weights'.format(self.name): - parsing_ops.VarLenFeature(dtypes.float32), - } - - @property - def num_buckets(self): - return 5 - - def transform_feature(self, transformation_cache, state_manager): - return (transformation_cache.get(self.name, state_manager), - transformation_cache.get('{}_weights'.format(self.name), - state_manager)) - - def get_sparse_tensors(self, transformation_cache, state_manager): - """Populates both id_tensor and weight_tensor.""" - ids_and_weights = transformation_cache.get(self, state_manager) - return fc.CategoricalColumn.IdWeightPair( - id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1]) - - t = _TestColumnWithWeights() - crossed = fc.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5) - with ops.Graph().as_default(): - with self.assertRaisesRegexp( - ValueError, - 'crossed_column does not support weight_tensor.*{}'.format(t.name)): - model = fc.LinearModel((crossed,)) - model({ - t.name: - sparse_tensor.SparseTensor( - indices=((0, 0), (1, 0), (1, 1)), - values=[0, 1, 2], - dense_shape=(2, 2)), - '{}_weights'.format(t.name): - sparse_tensor.SparseTensor( - indices=((0, 0), (1, 0), (1, 1)), - values=[1., 10., 2.], - dense_shape=(2, 2)), - 'c': - sparse_tensor.SparseTensor( - indices=((0, 0), (1, 0), (1, 1)), - values=['cA', 'cB', 'cC'], - dense_shape=(2, 2)), - }) - def test_old_linear_model(self): """Tests linear_model. @@ -1644,668 +1434,6 @@ class CrossedColumnTest(test.TestCase): self.assertIs(b, new_crossed.keys[0]) -class LinearModelTest(test.TestCase): - - def test_raises_if_empty_feature_columns(self): - with self.assertRaisesRegexp(ValueError, - 'feature_columns must not be empty'): - fc.LinearModel(feature_columns=[]) - - def test_should_be_feature_column(self): - with self.assertRaisesRegexp(ValueError, 'must be a FeatureColumn'): - fc.LinearModel(feature_columns='NotSupported') - - def test_should_be_dense_or_categorical_column(self): - - class NotSupportedColumn(BaseFeatureColumnForTests): - - @property - def _is_v2_column(self): - return True - - @property - def name(self): - return 'NotSupportedColumn' - - def transform_feature(self, transformation_cache, state_manager): - pass - - @property - def parse_example_spec(self): - pass - - with self.assertRaisesRegexp( - ValueError, 'must be either a DenseColumn or CategoricalColumn'): - fc.LinearModel(feature_columns=[NotSupportedColumn()]) - - def test_does_not_support_dict_columns(self): - with self.assertRaisesRegexp( - ValueError, 'Expected feature_columns to be iterable, found dict.'): - fc.LinearModel(feature_columns={'a': fc.numeric_column('a')}) - - def test_raises_if_duplicate_name(self): - with self.assertRaisesRegexp( - ValueError, 'Duplicate feature column name found for columns'): - fc.LinearModel( - feature_columns=[fc.numeric_column('a'), - fc.numeric_column('a')]) - - def test_not_dict_input_features(self): - price = fc.numeric_column('price') - with ops.Graph().as_default(): - features = [[1.], [5.]] - model = fc.LinearModel([price]) - with self.assertRaisesRegexp(ValueError, 'We expected a dictionary here'): - model(features) - - def test_dense_bias(self): - price = fc.numeric_column('price') - with ops.Graph().as_default(): - features = {'price': [[1.], [5.]]} - model = fc.LinearModel([price]) - predictions = model(features) - price_var, bias = model.variables - with _initialized_session() as sess: - self.assertAllClose([0.], self.evaluate(bias)) - sess.run(price_var.assign([[10.]])) - sess.run(bias.assign([5.])) - self.assertAllClose([[15.], [55.]], self.evaluate(predictions)) - - def test_sparse_bias(self): - wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) - with ops.Graph().as_default(): - wire_tensor = sparse_tensor.SparseTensor( - values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] - indices=[[0, 0], [1, 0], [1, 1]], - dense_shape=[2, 2]) - features = {'wire_cast': wire_tensor} - model = fc.LinearModel([wire_cast]) - predictions = model(features) - wire_cast_var, bias = model.variables - with _initialized_session() as sess: - self.assertAllClose([0.], self.evaluate(bias)) - self.assertAllClose([[0.], [0.], [0.], [0.]], - self.evaluate(wire_cast_var)) - sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) - sess.run(bias.assign([5.])) - self.assertAllClose([[1005.], [10015.]], self.evaluate(predictions)) - - def test_dense_and_sparse_bias(self): - wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) - price = fc.numeric_column('price') - with ops.Graph().as_default(): - wire_tensor = sparse_tensor.SparseTensor( - values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] - indices=[[0, 0], [1, 0], [1, 1]], - dense_shape=[2, 2]) - features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]} - model = fc.LinearModel([wire_cast, price]) - predictions = model(features) - price_var, wire_cast_var, bias = model.variables - with _initialized_session() as sess: - sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) - sess.run(bias.assign([5.])) - sess.run(price_var.assign([[10.]])) - self.assertAllClose([[1015.], [10065.]], self.evaluate(predictions)) - - def test_dense_and_sparse_column(self): - """When the column is both dense and sparse, uses sparse tensors.""" - - class _DenseAndSparseColumn(BaseFeatureColumnForTests, fc.DenseColumn, - fc.CategoricalColumn): - - @property - def _is_v2_column(self): - return True - - @property - def name(self): - return 'dense_and_sparse_column' - - @property - def parse_example_spec(self): - return {self.name: parsing_ops.VarLenFeature(self.dtype)} - - def transform_feature(self, transformation_cache, state_manager): - return transformation_cache.get(self.name, state_manager) - - @property - def variable_shape(self): - raise ValueError('Should not use this method.') - - def get_dense_tensor(self, transformation_cache, state_manager): - raise ValueError('Should not use this method.') - - @property - def num_buckets(self): - return 4 - - def get_sparse_tensors(self, transformation_cache, state_manager): - sp_tensor = sparse_tensor.SparseTensor( - indices=[[0, 0], [1, 0], [1, 1]], - values=[2, 0, 3], - dense_shape=[2, 2]) - return fc.CategoricalColumn.IdWeightPair(sp_tensor, None) - - dense_and_sparse_column = _DenseAndSparseColumn() - with ops.Graph().as_default(): - sp_tensor = sparse_tensor.SparseTensor( - values=['omar', 'stringer', 'marlo'], - indices=[[0, 0], [1, 0], [1, 1]], - dense_shape=[2, 2]) - features = {dense_and_sparse_column.name: sp_tensor} - model = fc.LinearModel([dense_and_sparse_column]) - predictions = model(features) - dense_and_sparse_column_var, bias = model.variables - with _initialized_session() as sess: - sess.run( - dense_and_sparse_column_var.assign([[10.], [100.], [1000.], - [10000.]])) - sess.run(bias.assign([5.])) - self.assertAllClose([[1005.], [10015.]], self.evaluate(predictions)) - - def test_dense_multi_output(self): - price = fc.numeric_column('price') - with ops.Graph().as_default(): - features = {'price': [[1.], [5.]]} - model = fc.LinearModel([price], units=3) - predictions = model(features) - price_var, bias = model.variables - with _initialized_session() as sess: - self.assertAllClose(np.zeros((3,)), self.evaluate(bias)) - self.assertAllClose(np.zeros((1, 3)), self.evaluate(price_var)) - sess.run(price_var.assign([[10., 100., 1000.]])) - sess.run(bias.assign([5., 6., 7.])) - self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]], - self.evaluate(predictions)) - - def test_sparse_multi_output(self): - wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) - with ops.Graph().as_default(): - wire_tensor = sparse_tensor.SparseTensor( - values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] - indices=[[0, 0], [1, 0], [1, 1]], - dense_shape=[2, 2]) - features = {'wire_cast': wire_tensor} - model = fc.LinearModel([wire_cast], units=3) - predictions = model(features) - wire_cast_var, bias = model.variables - with _initialized_session() as sess: - self.assertAllClose(np.zeros((3,)), self.evaluate(bias)) - self.assertAllClose(np.zeros((4, 3)), self.evaluate(wire_cast_var)) - sess.run( - wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.], - [1000., 1100., 1200.], - [10000., 11000., 12000.]])) - sess.run(bias.assign([5., 6., 7.])) - self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]], - self.evaluate(predictions)) - - def test_dense_multi_dimension(self): - price = fc.numeric_column('price', shape=2) - with ops.Graph().as_default(): - features = {'price': [[1., 2.], [5., 6.]]} - model = fc.LinearModel([price]) - predictions = model(features) - price_var, _ = model.variables - with _initialized_session() as sess: - self.assertAllClose([[0.], [0.]], self.evaluate(price_var)) - sess.run(price_var.assign([[10.], [100.]])) - self.assertAllClose([[210.], [650.]], self.evaluate(predictions)) - - def test_sparse_multi_rank(self): - wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) - with ops.Graph().as_default(): - wire_tensor = array_ops.sparse_placeholder(dtypes.string) - wire_value = sparse_tensor.SparseTensorValue( - values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2] - indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]], - dense_shape=[2, 2, 2]) - features = {'wire_cast': wire_tensor} - model = fc.LinearModel([wire_cast]) - predictions = model(features) - wire_cast_var, _ = model.variables - with _initialized_session() as sess: - self.assertAllClose(np.zeros((4, 1)), self.evaluate(wire_cast_var)) - self.assertAllClose( - np.zeros((2, 1)), - predictions.eval(feed_dict={wire_tensor: wire_value})) - sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) - self.assertAllClose( - [[1010.], [11000.]], - predictions.eval(feed_dict={wire_tensor: wire_value})) - - def test_sparse_combiner(self): - wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) - with ops.Graph().as_default(): - wire_tensor = sparse_tensor.SparseTensor( - values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] - indices=[[0, 0], [1, 0], [1, 1]], - dense_shape=[2, 2]) - features = {'wire_cast': wire_tensor} - model = fc.LinearModel([wire_cast], sparse_combiner='mean') - predictions = model(features) - wire_cast_var, bias = model.variables - with _initialized_session() as sess: - sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) - sess.run(bias.assign([5.])) - self.assertAllClose([[1005.], [5010.]], self.evaluate(predictions)) - - def test_sparse_combiner_sqrtn(self): - wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) - with ops.Graph().as_default(): - wire_tensor = sparse_tensor.SparseTensor( - values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] - indices=[[0, 0], [1, 0], [1, 1]], - dense_shape=[2, 2]) - features = {'wire_cast': wire_tensor} - model = fc.LinearModel([wire_cast], sparse_combiner='sqrtn') - predictions = model(features) - wire_cast_var, bias = model.variables - with _initialized_session() as sess: - self.evaluate(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) - self.evaluate(bias.assign([5.])) - self.assertAllClose([[1005.], [7083.139]], self.evaluate(predictions)) - - def test_sparse_combiner_with_negative_weights(self): - wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) - wire_cast_weights = fc.weighted_categorical_column(wire_cast, 'weights') - - with ops.Graph().as_default(): - wire_tensor = sparse_tensor.SparseTensor( - values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3] - indices=[[0, 0], [1, 0], [1, 1]], - dense_shape=[2, 2]) - features = { - 'wire_cast': wire_tensor, - 'weights': constant_op.constant([[1., 1., -1.0]]) - } - model = fc.LinearModel([wire_cast_weights], sparse_combiner='sum') - predictions = model(features) - wire_cast_var, bias = model.variables - with _initialized_session() as sess: - sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]])) - sess.run(bias.assign([5.])) - self.assertAllClose([[1005.], [-9985.]], self.evaluate(predictions)) - - def test_dense_multi_dimension_multi_output(self): - price = fc.numeric_column('price', shape=2) - with ops.Graph().as_default(): - features = {'price': [[1., 2.], [5., 6.]]} - model = fc.LinearModel([price], units=3) - predictions = model(features) - price_var, bias = model.variables - with _initialized_session() as sess: - self.assertAllClose(np.zeros((3,)), self.evaluate(bias)) - self.assertAllClose(np.zeros((2, 3)), self.evaluate(price_var)) - sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]])) - sess.run(bias.assign([2., 3., 4.])) - self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]], - self.evaluate(predictions)) - - def test_raises_if_shape_mismatch(self): - price = fc.numeric_column('price', shape=2) - with ops.Graph().as_default(): - features = {'price': [[1.], [5.]]} - with self.assertRaisesRegexp( - Exception, - r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'): - model = fc.LinearModel([price]) - model(features) - - def test_dense_reshaping(self): - price = fc.numeric_column('price', shape=[1, 2]) - with ops.Graph().as_default(): - features = {'price': [[[1., 2.]], [[5., 6.]]]} - model = fc.LinearModel([price]) - predictions = model(features) - price_var, bias = model.variables - with _initialized_session() as sess: - self.assertAllClose([0.], self.evaluate(bias)) - self.assertAllClose([[0.], [0.]], self.evaluate(price_var)) - self.assertAllClose([[0.], [0.]], self.evaluate(predictions)) - sess.run(price_var.assign([[10.], [100.]])) - self.assertAllClose([[210.], [650.]], self.evaluate(predictions)) - - def test_dense_multi_column(self): - price1 = fc.numeric_column('price1', shape=2) - price2 = fc.numeric_column('price2') - with ops.Graph().as_default(): - features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]} - model = fc.LinearModel([price1, price2]) - predictions = model(features) - price1_var, price2_var, bias = model.variables - with _initialized_session() as sess: - self.assertAllClose([0.], self.evaluate(bias)) - self.assertAllClose([[0.], [0.]], self.evaluate(price1_var)) - self.assertAllClose([[0.]], self.evaluate(price2_var)) - self.assertAllClose([[0.], [0.]], self.evaluate(predictions)) - sess.run(price1_var.assign([[10.], [100.]])) - sess.run(price2_var.assign([[1000.]])) - sess.run(bias.assign([7.])) - self.assertAllClose([[3217.], [4657.]], self.evaluate(predictions)) - - def test_dense_trainable_default(self): - price = fc.numeric_column('price') - with ops.Graph().as_default() as g: - features = {'price': [[1.], [5.]]} - model = fc.LinearModel([price]) - model(features) - price_var, bias = model.variables - trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - self.assertIn(bias, trainable_vars) - self.assertIn(price_var, trainable_vars) - - def test_sparse_trainable_default(self): - wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) - with ops.Graph().as_default() as g: - wire_tensor = sparse_tensor.SparseTensor( - values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) - features = {'wire_cast': wire_tensor} - model = fc.LinearModel([wire_cast]) - model(features) - trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - wire_cast_var, bias = model.variables - self.assertIn(bias, trainable_vars) - self.assertIn(wire_cast_var, trainable_vars) - - def test_dense_trainable_false(self): - price = fc.numeric_column('price') - with ops.Graph().as_default() as g: - features = {'price': [[1.], [5.]]} - model = fc.LinearModel([price], trainable=False) - model(features) - trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - self.assertEqual([], trainable_vars) - - def test_sparse_trainable_false(self): - wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) - with ops.Graph().as_default() as g: - wire_tensor = sparse_tensor.SparseTensor( - values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) - features = {'wire_cast': wire_tensor} - model = fc.LinearModel([wire_cast], trainable=False) - model(features) - trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - self.assertEqual([], trainable_vars) - - def test_column_order(self): - price_a = fc.numeric_column('price_a') - price_b = fc.numeric_column('price_b') - wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4) - with ops.Graph().as_default(): - features = { - 'price_a': [[1.]], - 'price_b': [[3.]], - 'wire_cast': - sparse_tensor.SparseTensor( - values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) - } - model = fc.LinearModel([price_a, wire_cast, price_b]) - model(features) - - my_vars = model.variables - self.assertIn('price_a', my_vars[0].name) - self.assertIn('price_b', my_vars[1].name) - self.assertIn('wire_cast', my_vars[2].name) - - with ops.Graph().as_default(): - features = { - 'price_a': [[1.]], - 'price_b': [[3.]], - 'wire_cast': - sparse_tensor.SparseTensor( - values=['omar'], indices=[[0, 0]], dense_shape=[1, 1]) - } - model = fc.LinearModel([wire_cast, price_b, price_a]) - model(features) - - my_vars = model.variables - self.assertIn('price_a', my_vars[0].name) - self.assertIn('price_b', my_vars[1].name) - self.assertIn('wire_cast', my_vars[2].name) - - def test_variable_names(self): - price1 = fc.numeric_column('price1') - dense_feature = fc.numeric_column('dense_feature') - dense_feature_bucketized = fc.bucketized_column( - dense_feature, boundaries=[0.]) - some_sparse_column = fc.categorical_column_with_hash_bucket( - 'sparse_feature', hash_bucket_size=5) - some_embedding_column = fc.embedding_column( - some_sparse_column, dimension=10) - all_cols = [price1, dense_feature_bucketized, some_embedding_column] - - with ops.Graph().as_default(): - model = fc.LinearModel(all_cols) - features = { - 'price1': [[3.], [4.]], - 'dense_feature': [[-1.], [4.]], - 'sparse_feature': [['a'], ['x']], - } - model(features) - for var in model.variables: - self.assertIsInstance(var, variables_lib.VariableV1) - variable_names = [var.name for var in model.variables] - self.assertItemsEqual([ - 'linear_model/dense_feature_bucketized/weights:0', - 'linear_model/price1/weights:0', - 'linear_model/sparse_feature_embedding/embedding_weights:0', - 'linear_model/sparse_feature_embedding/weights:0', - 'linear_model/bias_weights:0', - ], variable_names) - - def test_fit_and_predict(self): - columns = [fc.numeric_column('a')] - - model = fc.LinearModel(columns) - model.compile( - optimizer=rmsprop.RMSPropOptimizer(1e-3), - loss='binary_crossentropy', - metrics=['accuracy']) - - x = {'a': np.random.random((10, 1))} - y = np.random.randint(0, 2, size=(10, 1)) - model.fit(x, y, epochs=1, batch_size=5) - model.fit(x, y, epochs=1, batch_size=5) - model.evaluate(x, y, batch_size=5) - model.predict(x, batch_size=5) - - def test_static_batch_size_mismatch(self): - price1 = fc.numeric_column('price1') - price2 = fc.numeric_column('price2') - with ops.Graph().as_default(): - features = { - 'price1': [[1.], [5.], [7.]], # batchsize = 3 - 'price2': [[3.], [4.]] # batchsize = 2 - } - with self.assertRaisesRegexp( - ValueError, - r'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string - model = fc.LinearModel([price1, price2]) - model(features) - - def test_subset_of_static_batch_size_mismatch(self): - price1 = fc.numeric_column('price1') - price2 = fc.numeric_column('price2') - price3 = fc.numeric_column('price3') - with ops.Graph().as_default(): - features = { - 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3 - 'price2': [[3.], [4.]], # batchsize = 2 - 'price3': [[3.], [4.], [5.]] # batchsize = 3 - } - with self.assertRaisesRegexp( - ValueError, - r'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string - model = fc.LinearModel([price1, price2, price3]) - model(features) - - def test_runtime_batch_size_mismatch(self): - price1 = fc.numeric_column('price1') - price2 = fc.numeric_column('price2') - with ops.Graph().as_default(): - features = { - 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3 - 'price2': [[3.], [4.]] # batchsize = 2 - } - model = fc.LinearModel([price1, price2]) - predictions = model(features) - with _initialized_session() as sess: - with self.assertRaisesRegexp(errors.OpError, - 'must have the same size and shape'): - sess.run( - predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]}) - - def test_runtime_batch_size_matches(self): - price1 = fc.numeric_column('price1') - price2 = fc.numeric_column('price2') - with ops.Graph().as_default(): - features = { - 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2 - 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2 - } - model = fc.LinearModel([price1, price2]) - predictions = model(features) - with _initialized_session() as sess: - sess.run( - predictions, - feed_dict={ - features['price1']: [[1.], [5.]], - features['price2']: [[1.], [5.]], - }) - - @test_util.run_deprecated_v1 - def test_with_1d_sparse_tensor(self): - price = fc.numeric_column('price') - price_buckets = fc.bucketized_column( - price, boundaries=[ - 0., - 10., - 100., - ]) - body_style = fc.categorical_column_with_vocabulary_list( - 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) - - # Provides 1-dim tensor and dense tensor. - features = { - 'price': - constant_op.constant([ - -1., - 12., - ]), - 'body-style': - sparse_tensor.SparseTensor( - indices=((0,), (1,)), - values=('sedan', 'hardtop'), - dense_shape=(2,)), - } - self.assertEqual(1, features['price'].shape.ndims) - self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0]) - - model = fc.LinearModel([price_buckets, body_style]) - net = model(features) - with _initialized_session() as sess: - body_style_var, price_buckets_var, bias = model.variables - - sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]])) - sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) - sess.run(bias.assign([5.])) - - self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], - self.evaluate(net)) - - @test_util.run_deprecated_v1 - def test_with_1d_unknown_shape_sparse_tensor(self): - price = fc.numeric_column('price') - price_buckets = fc.bucketized_column( - price, boundaries=[ - 0., - 10., - 100., - ]) - body_style = fc.categorical_column_with_vocabulary_list( - 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan']) - country = fc.categorical_column_with_vocabulary_list( - 'country', vocabulary_list=['US', 'JP', 'CA']) - - # Provides 1-dim tensor and dense tensor. - features = { - 'price': array_ops.placeholder(dtypes.float32), - 'body-style': array_ops.sparse_placeholder(dtypes.string), - 'country': array_ops.placeholder(dtypes.string), - } - self.assertIsNone(features['price'].shape.ndims) - self.assertIsNone(features['body-style'].get_shape().ndims) - - price_data = np.array([-1., 12.]) - body_style_data = sparse_tensor.SparseTensorValue( - indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,)) - country_data = np.array(['US', 'CA']) - - model = fc.LinearModel([price_buckets, body_style, country]) - net = model(features) - body_style_var, _, price_buckets_var, bias = model.variables - with _initialized_session() as sess: - sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]])) - sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]])) - sess.run(bias.assign([5.])) - - self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], - sess.run( - net, - feed_dict={ - features['price']: price_data, - features['body-style']: body_style_data, - features['country']: country_data - })) - - @test_util.run_deprecated_v1 - def test_with_rank_0_feature(self): - price = fc.numeric_column('price') - features = { - 'price': constant_op.constant(0), - } - self.assertEqual(0, features['price'].shape.ndims) - - # Static rank 0 should fail - with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'): - model = fc.LinearModel([price]) - model(features) - - # Dynamic rank 0 should fail - features = { - 'price': array_ops.placeholder(dtypes.float32), - } - model = fc.LinearModel([price]) - net = model(features) - self.assertEqual(1, net.shape[1]) - with _initialized_session() as sess: - with self.assertRaisesOpError('Feature .* cannot have rank 0'): - sess.run(net, feed_dict={features['price']: np.array(1)}) - - def test_multiple_linear_models(self): - price = fc.numeric_column('price') - with ops.Graph().as_default(): - features1 = {'price': [[1.], [5.]]} - features2 = {'price': [[2.], [10.]]} - model1 = fc.LinearModel([price]) - model2 = fc.LinearModel([price]) - predictions1 = model1(features1) - predictions2 = model2(features2) - price_var1, bias1 = model1.variables - price_var2, bias2 = model2.variables - with _initialized_session() as sess: - self.assertAllClose([0.], self.evaluate(bias1)) - sess.run(price_var1.assign([[10.]])) - sess.run(bias1.assign([5.])) - self.assertAllClose([[15.], [55.]], self.evaluate(predictions1)) - self.assertAllClose([0.], self.evaluate(bias2)) - sess.run(price_var2.assign([[10.]])) - sess.run(bias2.assign([5.])) - self.assertAllClose([[25.], [105.]], self.evaluate(predictions2)) - - class OldLinearModelTest(test.TestCase): def test_raises_if_empty_feature_columns(self): @@ -2731,10 +1859,10 @@ class OldLinearModelTest(test.TestCase): # We check the mapping by checking that we have the right keys, # and that the values (output_tensors) were indeed the ones used to # form the input layer. - self.assertItemsEqual(all_cols, cols_to_output_tensors.keys()) + self.assertCountEqual(all_cols, cols_to_output_tensors.keys()) input_layer_inputs = [tensor for tensor in input_layer.op.inputs[:-1]] output_tensors = [tensor for tensor in cols_to_output_tensors.values()] - self.assertItemsEqual(input_layer_inputs, output_tensors) + self.assertCountEqual(input_layer_inputs, output_tensors) def test_dense_collection(self): price = fc.numeric_column('price') @@ -3411,7 +2539,7 @@ class FunctionalInputLayerTest(test.TestCase): cols_to_vars = {} all_cols = [price1, dense_feature_bucketized, some_embedding_column] fc_old.input_layer(features, all_cols, cols_to_vars=cols_to_vars) - self.assertItemsEqual(list(cols_to_vars.keys()), all_cols) + self.assertCountEqual(list(cols_to_vars.keys()), all_cols) self.assertEqual(0, len(cols_to_vars[price1])) self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized])) self.assertEqual(1, len(cols_to_vars[some_embedding_column])) @@ -3461,7 +2589,7 @@ class FunctionalInputLayerTest(test.TestCase): shared_embedding_a, shared_embedding_b ] fc_old.input_layer(features, all_cols, cols_to_vars=cols_to_vars) - self.assertItemsEqual(list(cols_to_vars.keys()), all_cols) + self.assertCountEqual(list(cols_to_vars.keys()), all_cols) self.assertEqual(0, len(cols_to_vars[price1])) self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized])) self.assertEqual(1, len(cols_to_vars[some_embedding_column])) @@ -3497,7 +2625,7 @@ class FunctionalInputLayerTest(test.TestCase): 'input_from_feature_columns', partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0)): fc_old.input_layer(features, all_cols, cols_to_vars=cols_to_vars) - self.assertItemsEqual(list(cols_to_vars.keys()), all_cols) + self.assertCountEqual(list(cols_to_vars.keys()), all_cols) self.assertEqual(0, len(cols_to_vars[price1])) self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized])) self.assertEqual(3, len(cols_to_vars[some_embedding_column])) @@ -3616,7 +2744,7 @@ class FunctionalInputLayerTest(test.TestCase): 'input_layer/sparse_feature_embedding/embedding_weights:0', 'input_layer_1/sparse_feature_embedding/embedding_weights:0' ] - self.assertItemsEqual( + self.assertCountEqual( expected_var_names, [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) @@ -4362,36 +3490,6 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): dense_shape=inputs.dense_shape), self.evaluate(id_weight_pair.id_tensor)) - @test_util.run_deprecated_v1 - def test_linear_model(self): - wire_column = fc.categorical_column_with_vocabulary_file( - key='wire', - vocabulary_file=self._wire_vocabulary_file_name, - vocabulary_size=self._wire_vocabulary_size, - num_oov_buckets=1) - self.assertEqual(4, wire_column.num_buckets) - with ops.Graph().as_default(): - model = fc.LinearModel((wire_column,)) - predictions = model({ - wire_column.name: - sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=('marlo', 'skywalker', 'omar'), - dense_shape=(2, 2)) - }) - wire_var, bias = model.variables - - self.evaluate(variables_lib.global_variables_initializer()) - self.evaluate(lookup_ops.tables_initializer()) - - self.assertAllClose((0.,), self.evaluate(bias)) - self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), self.evaluate(wire_var)) - self.assertAllClose(((0.,), (0.,)), self.evaluate(predictions)) - self.evaluate(wire_var.assign(((1.,), (2.,), (3.,), (4.,)))) - # 'marlo' -> 2: wire_var[2] = 3 - # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5 - self.assertAllClose(((3.,), (5.,)), self.evaluate(predictions)) - def test_old_linear_model(self): wire_column = fc.categorical_column_with_vocabulary_file( key='wire', @@ -4828,35 +3926,6 @@ class VocabularyListCategoricalColumnTest(test.TestCase): dense_shape=inputs.dense_shape), self.evaluate(id_weight_pair.id_tensor)) - @test_util.run_deprecated_v1 - def test_linear_model(self): - wire_column = fc.categorical_column_with_vocabulary_list( - key='aaa', - vocabulary_list=('omar', 'stringer', 'marlo'), - num_oov_buckets=1) - self.assertEqual(4, wire_column.num_buckets) - with ops.Graph().as_default(): - model = fc.LinearModel((wire_column,)) - predictions = model({ - wire_column.name: - sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=('marlo', 'skywalker', 'omar'), - dense_shape=(2, 2)) - }) - wire_var, bias = model.variables - - self.evaluate(variables_lib.global_variables_initializer()) - self.evaluate(lookup_ops.tables_initializer()) - - self.assertAllClose((0.,), self.evaluate(bias)) - self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), self.evaluate(wire_var)) - self.assertAllClose(((0.,), (0.,)), self.evaluate(predictions)) - self.evaluate(wire_var.assign(((1.,), (2.,), (3.,), (4.,)))) - # 'marlo' -> 2: wire_var[2] = 3 - # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5 - self.assertAllClose(((3.,), (5.,)), self.evaluate(predictions)) - def test_old_linear_model(self): wire_column = fc.categorical_column_with_vocabulary_list( key='aaa', @@ -5196,32 +4265,6 @@ class IdentityCategoricalColumnTest(test.TestCase): input_shape: (2, 2), })) - @test_util.run_deprecated_v1 - def test_linear_model(self): - column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) - self.assertEqual(3, column.num_buckets) - with ops.Graph().as_default(): - model = fc.LinearModel((column,)) - predictions = model({ - column.name: - sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=(0, 2, 1), - dense_shape=(2, 2)) - }) - weight_var, bias = model.variables - - self.evaluate(variables_lib.global_variables_initializer()) - self.evaluate(lookup_ops.tables_initializer()) - - self.assertAllClose((0.,), self.evaluate(bias)) - self.assertAllClose(((0.,), (0.,), (0.,)), self.evaluate(weight_var)) - self.assertAllClose(((0.,), (0.,)), self.evaluate(predictions)) - self.evaluate(weight_var.assign(((1.,), (2.,), (3.,)))) - # weight_var[0] = 1 - # weight_var[2] + weight_var[1] = 3+2 = 5 - self.assertAllClose(((1.,), (5.,)), self.evaluate(predictions)) - def test_old_linear_model(self): column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) self.assertEqual(3, column.num_buckets) @@ -5514,30 +4557,6 @@ class IndicatorColumnTest(test.TestCase): self.assertAllEqual([[0., 1., 1.]], self.evaluate(indicator_tensor)) - @test_util.run_deprecated_v1 - def test_linear_model(self): - animal = fc.indicator_column( - fc.categorical_column_with_identity('animal', num_buckets=4)) - with ops.Graph().as_default(): - features = { - 'animal': - sparse_tensor.SparseTensor( - indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2]) - } - - model = fc.LinearModel([animal]) - predictions = model(features) - weight_var, _ = model.variables - - self.evaluate(variables_lib.global_variables_initializer()) - self.evaluate(lookup_ops.tables_initializer()) - - # All should be zero-initialized. - self.assertAllClose([[0.], [0.], [0.], [0.]], self.evaluate(weight_var)) - self.assertAllClose([[0.]], self.evaluate(predictions)) - self.evaluate(weight_var.assign([[1.], [2.], [3.], [4.]])) - self.assertAllClose([[2. + 3.]], self.evaluate(predictions)) - def test_old_linear_model(self): animal = fc.indicator_column( fc.categorical_column_with_identity('animal', num_buckets=4)) @@ -5582,23 +4601,6 @@ class IndicatorColumnTest(test.TestCase): self.evaluate(weight_var.assign([[1.], [2.], [3.], [4.]])) self.assertAllClose([[2. + 3.]], self.evaluate(predictions)) - @test_util.run_deprecated_v1 - def test_dense_features(self): - animal = fc.indicator_column( - fc.categorical_column_with_identity('animal', num_buckets=4)) - with ops.Graph().as_default(): - features = { - 'animal': - sparse_tensor.SparseTensor( - indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2]) - } - net = df.DenseFeatures([animal])(features) - - self.evaluate(variables_lib.global_variables_initializer()) - self.evaluate(lookup_ops.tables_initializer()) - - self.assertAllClose([[0., 1., 1., 0.]], self.evaluate(net)) - @test_util.run_deprecated_v1 def test_input_layer(self): animal = fc.indicator_column( @@ -5904,7 +4906,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('embedding_weights:0',), + self.assertCountEqual(('embedding_weights:0',), tuple([v.name for v in global_vars])) self.evaluate(variables_lib.global_variables_initializer()) @@ -5968,7 +4970,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('embedding_weights:0',), + self.assertCountEqual(('embedding_weights:0',), tuple([v.name for v in global_vars])) self.evaluate(variables_lib.global_variables_initializer()) @@ -6036,7 +5038,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('embedding_weights:0',), + self.assertCountEqual(('embedding_weights:0',), tuple([v.name for v in global_vars])) self.evaluate(variables_lib.global_variables_initializer()) @@ -6109,7 +5111,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('embedding_weights:0',), + self.assertCountEqual(('embedding_weights:0',), tuple([v.name for v in global_vars])) self.evaluate(variables_lib.global_variables_initializer()) @@ -6180,7 +5182,7 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('embedding_weights:0',), + self.assertCountEqual(('embedding_weights:0',), tuple([v.name for v in global_vars])) self.evaluate(variables_lib.global_variables_initializer()) @@ -6189,238 +5191,6 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(embedding_values, self.evaluate(global_vars[0])) self.assertAllEqual(expected_lookups, self.evaluate(embedding_lookup)) - @test_util.run_deprecated_v1 - def test_linear_model(self): - # Inputs. - batch_size = 4 - vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - # example 2, ids [] - # example 3, ids [1] - indices=((0, 0), (1, 0), (1, 4), (3, 0)), - values=(2, 0, 1, 1), - dense_shape=(batch_size, 5)) - - # Embedding variable. - embedding_dimension = 2 - embedding_shape = (vocabulary_size, embedding_dimension) - zeros_embedding_values = np.zeros(embedding_shape) - - def _initializer(shape, dtype, partition_info=None): - self.assertAllEqual(embedding_shape, shape) - self.assertEqual(dtypes.float32, dtype) - self.assertIsNone(partition_info) - return zeros_embedding_values - - # Build columns. - categorical_column = fc.categorical_column_with_identity( - key='aaa', num_buckets=vocabulary_size) - embedding_column = fc.embedding_column( - categorical_column, - dimension=embedding_dimension, - initializer=_initializer) - - with ops.Graph().as_default(): - model = fc.LinearModel((embedding_column,)) - predictions = model({categorical_column.name: sparse_input}) - expected_var_names = ( - 'linear_model/bias_weights:0', - 'linear_model/aaa_embedding/weights:0', - 'linear_model/aaa_embedding/embedding_weights:0', - ) - self.assertItemsEqual( - expected_var_names, - [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) - trainable_vars = { - v.name: v - for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - } - self.assertItemsEqual(expected_var_names, trainable_vars.keys()) - bias = trainable_vars['linear_model/bias_weights:0'] - embedding_weights = trainable_vars[ - 'linear_model/aaa_embedding/embedding_weights:0'] - linear_weights = trainable_vars['linear_model/aaa_embedding/weights:0'] - - self.evaluate(variables_lib.global_variables_initializer()) - self.evaluate(lookup_ops.tables_initializer()) - - # Predictions with all zero weights. - self.assertAllClose(np.zeros((1,)), self.evaluate(bias)) - self.assertAllClose(zeros_embedding_values, - self.evaluate(embedding_weights)) - self.assertAllClose( - np.zeros((embedding_dimension, 1)), self.evaluate(linear_weights)) - self.assertAllClose(np.zeros((batch_size, 1)), self.evaluate(predictions)) - - # Predictions with all non-zero weights. - self.evaluate( - embedding_weights.assign(( - (1., 2.), # id 0 - (3., 5.), # id 1 - (7., 11.) # id 2 - ))) - self.evaluate(linear_weights.assign(((4.,), (6.,)))) - # example 0, ids [2], embedding[0] = [7, 11] - # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5] - # example 2, ids [], embedding[2] = [0, 0] - # example 3, ids [1], embedding[3] = [3, 5] - # sum(embeddings * linear_weights) - # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42] - self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), - self.evaluate(predictions)) - - @parameterized.named_parameters( - { - 'testcase_name': 'use_safe_embedding_lookup', - 'use_safe_embedding_lookup': True - }, { - 'testcase_name': 'dont_use_safe_embedding_lookup', - 'use_safe_embedding_lookup': False - }) - @test_util.run_deprecated_v1 - def test_dense_features(self, use_safe_embedding_lookup): - # Inputs. - vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - # example 2, ids [] - # example 3, ids [1] - indices=((0, 0), (1, 0), (1, 4), (3, 0)), - values=(2, 0, 1, 1), - dense_shape=(4, 5)) - - # Embedding variable. - embedding_dimension = 2 - embedding_values = ( - (1., 2.), # id 0 - (3., 5.), # id 1 - (7., 11.) # id 2 - ) - - def _initializer(shape, dtype, partition_info=None): - self.assertAllEqual((vocabulary_size, embedding_dimension), shape) - self.assertEqual(dtypes.float32, dtype) - self.assertIsNone(partition_info) - return embedding_values - - # Expected lookup result, using combiner='mean'. - expected_lookups = ( - # example 0, ids [2], embedding = [7, 11] - (7., 11.), - # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] - (2., 3.5), - # example 2, ids [], embedding = [0, 0] - (0., 0.), - # example 3, ids [1], embedding = [3, 5] - (3., 5.), - ) - - # Build columns. - categorical_column = fc.categorical_column_with_identity( - key='aaa', num_buckets=vocabulary_size) - embedding_column = fc.embedding_column( - categorical_column, - dimension=embedding_dimension, - initializer=_initializer, - use_safe_embedding_lookup=use_safe_embedding_lookup) - - # Provide sparse input and get dense result. - l = df.DenseFeatures((embedding_column,)) - dense_features = l({'aaa': sparse_input}) - - # Assert expected embedding variable and lookups. - global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('dense_features/aaa_embedding/embedding_weights:0',), - tuple([v.name for v in global_vars])) - for v in global_vars: - self.assertIsInstance(v, variables_lib.Variable) - trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - self.assertItemsEqual(('dense_features/aaa_embedding/embedding_weights:0',), - tuple([v.name for v in trainable_vars])) - - self.evaluate(variables_lib.global_variables_initializer()) - self.evaluate(lookup_ops.tables_initializer()) - - self.assertAllEqual(embedding_values, self.evaluate(trainable_vars[0])) - self.assertAllEqual(expected_lookups, self.evaluate(dense_features)) - - if use_safe_embedding_lookup: - self.assertIn('SparseFillEmptyRows', - [x.type for x in ops.get_default_graph().get_operations()]) - else: - self.assertNotIn( - 'SparseFillEmptyRows', - [x.type for x in ops.get_default_graph().get_operations()]) - - @test_util.run_deprecated_v1 - def test_dense_features_not_trainable(self): - # Inputs. - vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - # example 2, ids [] - # example 3, ids [1] - indices=((0, 0), (1, 0), (1, 4), (3, 0)), - values=(2, 0, 1, 1), - dense_shape=(4, 5)) - - # Embedding variable. - embedding_dimension = 2 - embedding_values = ( - (1., 2.), # id 0 - (3., 5.), # id 1 - (7., 11.) # id 2 - ) - - def _initializer(shape, dtype, partition_info=None): - self.assertAllEqual((vocabulary_size, embedding_dimension), shape) - self.assertEqual(dtypes.float32, dtype) - self.assertIsNone(partition_info) - return embedding_values - - # Expected lookup result, using combiner='mean'. - expected_lookups = ( - # example 0, ids [2], embedding = [7, 11] - (7., 11.), - # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] - (2., 3.5), - # example 2, ids [], embedding = [0, 0] - (0., 0.), - # example 3, ids [1], embedding = [3, 5] - (3., 5.), - ) - - # Build columns. - categorical_column = fc.categorical_column_with_identity( - key='aaa', num_buckets=vocabulary_size) - embedding_column = fc.embedding_column( - categorical_column, - dimension=embedding_dimension, - initializer=_initializer, - trainable=False) - - # Provide sparse input and get dense result. - dense_features = df.DenseFeatures((embedding_column,))({ - 'aaa': sparse_input - }) - - # Assert expected embedding variable and lookups. - global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('dense_features/aaa_embedding/embedding_weights:0',), - tuple([v.name for v in global_vars])) - self.assertItemsEqual([], - ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) - - self.evaluate(variables_lib.global_variables_initializer()) - self.evaluate(lookup_ops.tables_initializer()) - - self.assertAllEqual(embedding_values, self.evaluate(global_vars[0])) - self.assertAllEqual(expected_lookups, self.evaluate(dense_features)) - @test_util.run_deprecated_v1 def test_input_layer(self): # Inputs. @@ -6475,10 +5245,10 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('input_layer/aaa_embedding/embedding_weights:0',), + self.assertCountEqual(('input_layer/aaa_embedding/embedding_weights:0',), tuple([v.name for v in global_vars])) trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - self.assertItemsEqual(('input_layer/aaa_embedding/embedding_weights:0',), + self.assertCountEqual(('input_layer/aaa_embedding/embedding_weights:0',), tuple([v.name for v in trainable_vars])) self.evaluate(variables_lib.global_variables_initializer()) @@ -6528,14 +5298,14 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): 'linear_model/aaa_embedding/weights:0', 'linear_model/aaa_embedding/embedding_weights:0', ) - self.assertItemsEqual( + self.assertCountEqual( expected_var_names, [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) trainable_vars = { v.name: v for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) } - self.assertItemsEqual(expected_var_names, trainable_vars.keys()) + self.assertCountEqual(expected_var_names, trainable_vars.keys()) bias = trainable_vars['linear_model/bias_weights:0'] embedding_weights = trainable_vars[ 'linear_model/aaa_embedding/embedding_weights:0'] @@ -6610,14 +5380,14 @@ class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): 'linear_model/aaa_embedding/weights:0', 'linear_model/aaa_embedding/embedding_weights:0', ) - self.assertItemsEqual( + self.assertCountEqual( expected_var_names, [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) trainable_vars = { v.name: v for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) } - self.assertItemsEqual(expected_var_names, trainable_vars.keys()) + self.assertCountEqual(expected_var_names, trainable_vars.keys()) bias = trainable_vars['linear_model/bias_weights:0'] embedding_weights = trainable_vars[ 'linear_model/aaa_embedding/embedding_weights:0'] @@ -6972,15 +5742,26 @@ class SharedEmbeddingColumnTest(test.TestCase, parameterized.TestCase): @parameterized.named_parameters( { 'testcase_name': 'use_safe_embedding_lookup', - 'use_safe_embedding_lookup': True + 'use_safe_embedding_lookup': True, + 'partition_variables': False, }, { 'testcase_name': 'dont_use_safe_embedding_lookup', - 'use_safe_embedding_lookup': False + 'use_safe_embedding_lookup': False, + 'partition_variables': False, + }, { + 'testcase_name': 'use_safe_embedding_lookup_partitioned', + 'use_safe_embedding_lookup': True, + 'partition_variables': True, + }, { + 'testcase_name': 'dont_use_safe_embedding_lookup_partitioned', + 'use_safe_embedding_lookup': False, + 'partition_variables': True, }) @test_util.run_deprecated_v1 - def test_get_dense_tensor(self, use_safe_embedding_lookup): + def test_get_dense_tensor(self, use_safe_embedding_lookup, + partition_variables): # Inputs. - vocabulary_size = 3 + vocabulary_size = 4 # -1 values are ignored. input_a = np.array([ [2, -1, -1], # example 0, ids [2] @@ -6997,13 +5778,20 @@ class SharedEmbeddingColumnTest(test.TestCase, parameterized.TestCase): embedding_values = ( (1., 2.), # id 0 (3., 5.), # id 1 - (7., 11.) # id 2 + (7., 11.), # id 2 + (9., 13.) # id 3 ) def _initializer(shape, dtype, partition_info=None): - self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + if partition_variables: + self.assertEqual([vocabulary_size, embedding_dimension], + partition_info.full_shape) + self.assertAllEqual((2, embedding_dimension), shape) + else: + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertIsNone(partition_info) + self.assertEqual(dtypes.float32, dtype) - self.assertIsNone(partition_info) return embedding_values # Expected lookup result, using combiner='mean'. @@ -7031,22 +5819,32 @@ class SharedEmbeddingColumnTest(test.TestCase, parameterized.TestCase): key='aaa', num_buckets=vocabulary_size) categorical_column_b = fc.categorical_column_with_identity( key='bbb', num_buckets=vocabulary_size) - embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2( - [categorical_column_a, categorical_column_b], - dimension=embedding_dimension, - initializer=_initializer, - use_safe_embedding_lookup=use_safe_embedding_lookup) - # Provide sparse input and get dense result. - embedding_lookup_a = embedding_column_a.get_dense_tensor( - fc.FeatureTransformationCache(input_features), None) - embedding_lookup_b = embedding_column_b.get_dense_tensor( - fc.FeatureTransformationCache(input_features), None) + partitioner = None + if partition_variables: + partitioner = partitioned_variables.fixed_size_partitioner(2, axis=0) + + with variable_scope.variable_scope('vars', partitioner=partitioner): + embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer, + use_safe_embedding_lookup=use_safe_embedding_lookup) + # Provide sparse input and get dense result. + embedding_lookup_a = embedding_column_a.get_dense_tensor( + fc.FeatureTransformationCache(input_features), None) + embedding_lookup_b = embedding_column_b.get_dense_tensor( + fc.FeatureTransformationCache(input_features), None) # Assert expected embedding variable and lookups. global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual(('aaa_bbb_shared_embedding:0',), - tuple([v.name for v in global_vars])) + if partition_variables: + self.assertCountEqual(('vars/aaa_bbb_shared_embedding/part_0:0', + 'vars/aaa_bbb_shared_embedding/part_1:0'), + tuple([v.name for v in global_vars])) + else: + self.assertCountEqual(('vars/aaa_bbb_shared_embedding:0',), + tuple([v.name for v in global_vars])) embedding_var = global_vars[0] self.evaluate(variables_lib.global_variables_initializer()) @@ -7228,227 +6026,6 @@ class SharedEmbeddingColumnTest(test.TestCase, parameterized.TestCase): with _initialized_session() as sess: sess.run([embedding_lookup_a, embedding_lookup_b], feed_dict=feed_dict) - @test_util.run_deprecated_v1 - def test_linear_model(self): - # Inputs. - batch_size = 2 - vocabulary_size = 3 - # -1 values are ignored. - input_a = np.array([ - [2, -1, -1], # example 0, ids [2] - [0, 1, -1] - ]) # example 1, ids [0, 1] - input_b = np.array([ - [0, -1, -1], # example 0, ids [0] - [-1, -1, -1] - ]) # example 1, ids [] - - # Embedding variable. - embedding_dimension = 2 - embedding_shape = (vocabulary_size, embedding_dimension) - zeros_embedding_values = np.zeros(embedding_shape) - - def _initializer(shape, dtype, partition_info=None): - self.assertAllEqual(embedding_shape, shape) - self.assertEqual(dtypes.float32, dtype) - self.assertIsNone(partition_info) - return zeros_embedding_values - - # Build columns. - categorical_column_a = fc.categorical_column_with_identity( - key='aaa', num_buckets=vocabulary_size) - categorical_column_b = fc.categorical_column_with_identity( - key='bbb', num_buckets=vocabulary_size) - embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2( - [categorical_column_a, categorical_column_b], - dimension=embedding_dimension, - initializer=_initializer) - - with ops.Graph().as_default(): - model = fc.LinearModel((embedding_column_a, embedding_column_b)) - predictions = model({ - categorical_column_a.name: input_a, - categorical_column_b.name: input_b - }) - - # Linear weights do not follow the column name. But this is a rare use - # case, and fixing it would add too much complexity to the code. - expected_var_names = ( - 'linear_model/bias_weights:0', - 'linear_model/aaa_shared_embedding/weights:0', - 'aaa_bbb_shared_embedding:0', - 'linear_model/bbb_shared_embedding/weights:0', - ) - self.assertItemsEqual( - expected_var_names, - [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) - trainable_vars = { - v.name: v - for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - } - self.assertItemsEqual(expected_var_names, trainable_vars.keys()) - bias = trainable_vars['linear_model/bias_weights:0'] - embedding_weights = trainable_vars['aaa_bbb_shared_embedding:0'] - linear_weights_a = trainable_vars[ - 'linear_model/aaa_shared_embedding/weights:0'] - linear_weights_b = trainable_vars[ - 'linear_model/bbb_shared_embedding/weights:0'] - - self.evaluate(variables_lib.global_variables_initializer()) - self.evaluate(lookup_ops.tables_initializer()) - - # Predictions with all zero weights. - self.assertAllClose(np.zeros((1,)), self.evaluate(bias)) - self.assertAllClose(zeros_embedding_values, - self.evaluate(embedding_weights)) - self.assertAllClose( - np.zeros((embedding_dimension, 1)), self.evaluate(linear_weights_a)) - self.assertAllClose( - np.zeros((embedding_dimension, 1)), self.evaluate(linear_weights_b)) - self.assertAllClose(np.zeros((batch_size, 1)), self.evaluate(predictions)) - - # Predictions with all non-zero weights. - self.evaluate( - embedding_weights.assign(( - (1., 2.), # id 0 - (3., 5.), # id 1 - (7., 11.) # id 2 - ))) - self.evaluate(linear_weights_a.assign(((4.,), (6.,)))) - # example 0, ids [2], embedding[0] = [7, 11] - # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5] - # sum(embeddings * linear_weights) - # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29] - self.evaluate(linear_weights_b.assign(((3.,), (5.,)))) - # example 0, ids [0], embedding[0] = [1, 2] - # example 1, ids [], embedding[1] = 0, 0] - # sum(embeddings * linear_weights) - # = [3*1 + 5*2, 3*0 +5*0] = [13, 0] - self.assertAllClose([[94. + 13.], [29.]], self.evaluate(predictions)) - - def _test_dense_features(self, trainable=True): - # Inputs. - vocabulary_size = 3 - sparse_input_a = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 4)), - values=(2, 0, 1), - dense_shape=(2, 5)) - sparse_input_b = sparse_tensor.SparseTensorValue( - # example 0, ids [0] - # example 1, ids [] - indices=((0, 0),), - values=(0,), - dense_shape=(2, 5)) - sparse_input_c = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 1), (1, 1), (1, 3)), - values=(2, 0, 1), - dense_shape=(2, 5)) - sparse_input_d = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [] - indices=((0, 1),), - values=(2,), - dense_shape=(2, 5)) - - # Embedding variable. - embedding_dimension = 2 - embedding_values = ( - (1., 2.), # id 0 - (3., 5.), # id 1 - (7., 11.) # id 2 - ) - - def _initializer(shape, dtype, partition_info=None): - self.assertAllEqual((vocabulary_size, embedding_dimension), shape) - self.assertEqual(dtypes.float32, dtype) - self.assertIsNone(partition_info) - return embedding_values - - # Expected lookup result, using combiner='mean'. - expected_lookups = ( - # example 0: - # A ids [2], embedding = [7, 11] - # B ids [0], embedding = [1, 2] - # C ids [2], embedding = [7, 11] - # D ids [2], embedding = [7, 11] - (7., 11., 1., 2., 7., 11., 7., 11.), - # example 1: - # A ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] - # B ids [], embedding = [0, 0] - # C ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] - # D ids [], embedding = [0, 0] - (2., 3.5, 0., 0., 2., 3.5, 0., 0.), - ) - - # Build columns. - categorical_column_a = fc.categorical_column_with_identity( - key='aaa', num_buckets=vocabulary_size) - categorical_column_b = fc.categorical_column_with_identity( - key='bbb', num_buckets=vocabulary_size) - categorical_column_c = fc.categorical_column_with_identity( - key='ccc', num_buckets=vocabulary_size) - categorical_column_d = fc.categorical_column_with_identity( - key='ddd', num_buckets=vocabulary_size) - - embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2( - [categorical_column_a, categorical_column_b], - dimension=embedding_dimension, - initializer=_initializer, - trainable=trainable) - embedding_column_c, embedding_column_d = fc.shared_embedding_columns_v2( - [categorical_column_c, categorical_column_d], - dimension=embedding_dimension, - initializer=_initializer, - trainable=trainable) - - features = { - 'aaa': sparse_input_a, - 'bbb': sparse_input_b, - 'ccc': sparse_input_c, - 'ddd': sparse_input_d - } - - # Provide sparse input and get dense result. - dense_features = df.DenseFeatures( - feature_columns=(embedding_column_b, embedding_column_a, - embedding_column_c, embedding_column_d))( - features) - - # Assert expected embedding variable and lookups. - global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertItemsEqual( - ['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'], - tuple([v.name for v in global_vars])) - for v in global_vars: - self.assertIsInstance(v, variables_lib.Variable) - trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - if trainable: - self.assertItemsEqual( - ['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'], - tuple([v.name for v in trainable_vars])) - else: - self.assertItemsEqual([], tuple([v.name for v in trainable_vars])) - shared_embedding_vars = global_vars - - self.evaluate(variables_lib.global_variables_initializer()) - self.evaluate(lookup_ops.tables_initializer()) - - self.assertAllEqual(embedding_values, - self.evaluate(shared_embedding_vars[0])) - self.assertAllEqual(expected_lookups, self.evaluate(dense_features)) - - @test_util.run_deprecated_v1 - def test_dense_features(self): - self._test_dense_features() - - @test_util.run_deprecated_v1 - def test_dense_features_no_trainable(self): - self._test_dense_features(trainable=False) - @test_util.run_deprecated_v1 def test_serialization(self): @@ -7687,115 +6264,6 @@ class WeightedCategoricalColumnTest(test.TestCase): values=np.array((.5, 1., .1), dtype=np.float32), dense_shape=(2, 2)), self.evaluate(weight_tensor)) - @test_util.run_deprecated_v1 - def test_linear_model(self): - column = fc.weighted_categorical_column( - categorical_column=fc.categorical_column_with_identity( - key='ids', num_buckets=3), - weight_feature_key='values') - with ops.Graph().as_default(): - model = fc.LinearModel((column,)) - predictions = model({ - 'ids': - sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=(0, 2, 1), - dense_shape=(2, 2)), - 'values': - sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=(.5, 1., .1), - dense_shape=(2, 2)) - }) - weight_var, bias = model.variables - - self.evaluate(variables_lib.global_variables_initializer()) - self.evaluate(lookup_ops.tables_initializer()) - - self.assertAllClose((0.,), self.evaluate(bias)) - self.assertAllClose(((0.,), (0.,), (0.,)), self.evaluate(weight_var)) - self.assertAllClose(((0.,), (0.,)), self.evaluate(predictions)) - self.evaluate(weight_var.assign(((1.,), (2.,), (3.,)))) - # weight_var[0] * weights[0, 0] = 1 * .5 = .5 - # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1] - # = 3*1 + 2*.1 = 3+.2 = 3.2 - self.assertAllClose(((.5,), (3.2,)), self.evaluate(predictions)) - - def test_linear_model_mismatched_shape(self): - column = fc.weighted_categorical_column( - categorical_column=fc.categorical_column_with_identity( - key='ids', num_buckets=3), - weight_feature_key='values') - with ops.Graph().as_default(): - with self.assertRaisesRegexp(ValueError, - r'Dimensions.*are not compatible'): - model = fc.LinearModel((column,)) - model({ - 'ids': - sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=(0, 2, 1), - dense_shape=(2, 2)), - 'values': - sparse_tensor.SparseTensorValue( - indices=((0, 0), (0, 1), (1, 0), (1, 1)), - values=(.5, 11., 1., .1), - dense_shape=(2, 2)) - }) - - def test_linear_model_mismatched_dense_values(self): - column = fc.weighted_categorical_column( - categorical_column=fc.categorical_column_with_identity( - key='ids', num_buckets=3), - weight_feature_key='values') - with ops.Graph().as_default(): - model = fc.LinearModel((column,), sparse_combiner='mean') - predictions = model({ - 'ids': - sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=(0, 2, 1), - dense_shape=(2, 2)), - 'values': ((.5,), (1.,)) - }) - # Disabling the constant folding optimizer here since it changes the - # error message differently on CPU and GPU. - config = config_pb2.ConfigProto() - config.graph_options.rewrite_options.constant_folding = ( - rewriter_config_pb2.RewriterConfig.OFF) - with _initialized_session(config): - with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'): - self.evaluate(predictions) - - def test_linear_model_mismatched_dense_shape(self): - column = fc.weighted_categorical_column( - categorical_column=fc.categorical_column_with_identity( - key='ids', num_buckets=3), - weight_feature_key='values') - with ops.Graph().as_default(): - model = fc.LinearModel((column,)) - predictions = model({ - 'ids': - sparse_tensor.SparseTensorValue( - indices=((0, 0), (1, 0), (1, 1)), - values=(0, 2, 1), - dense_shape=(2, 2)), - 'values': ((.5,), (1.,), (.1,)) - }) - weight_var, bias = model.variables - - self.evaluate(variables_lib.global_variables_initializer()) - self.evaluate(lookup_ops.tables_initializer()) - - self.assertAllClose((0.,), self.evaluate(bias)) - self.assertAllClose(((0.,), (0.,), (0.,)), self.evaluate(weight_var)) - self.assertAllClose(((0.,), (0.,)), self.evaluate(predictions)) - self.evaluate(weight_var.assign(((1.,), (2.,), (3.,)))) - # weight_var[0] * weights[0, 0] = 1 * .5 = .5 - # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1] - # = 3*1 + 2*.1 = 3+.2 = 3.2 - self.assertAllClose(((.5,), (3.2,)), self.evaluate(predictions)) - def test_old_linear_model(self): column = fc.weighted_categorical_column( categorical_column=fc.categorical_column_with_identity( diff --git a/tensorflow/python/feature_column/keras_integration_test.py b/tensorflow/python/feature_column/keras_integration_test.py index e0677e84e50..456c0204350 100644 --- a/tensorflow/python/feature_column/keras_integration_test.py +++ b/tensorflow/python/feature_column/keras_integration_test.py @@ -23,12 +23,12 @@ import numpy as np from tensorflow.python import keras from tensorflow.python import tf2 from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.feature_column import dense_features_v2 from tensorflow.python.feature_column import feature_column_lib as fc from tensorflow.python.feature_column import feature_column_v2 from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.feature_column import dense_features_v2 from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.keras.premade import linear from tensorflow.python.keras.premade import wide_deep diff --git a/tensorflow/python/feature_column/sequence_feature_column_test.py b/tensorflow/python/feature_column/sequence_feature_column_test.py index 3d5d24ec03a..d0cf5ee7670 100644 --- a/tensorflow/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/python/feature_column/sequence_feature_column_test.py @@ -24,7 +24,6 @@ from absl.testing import parameterized import numpy as np from tensorflow.python.client import session -from tensorflow.python.feature_column import dense_features from tensorflow.python.feature_column import feature_column_v2 as fc from tensorflow.python.feature_column import sequence_feature_column as sfc from tensorflow.python.feature_column import serialization @@ -111,54 +110,6 @@ class ConcatenateContextInputTest(test.TestCase, parameterized.TestCase): sfc.concatenate_context_input(context_input, seq_input) -@test_util.run_all_in_graph_and_eager_modes -class DenseFeaturesTest(test.TestCase): - """Tests DenseFeatures with sequence feature columns.""" - - def test_embedding_column(self): - """Tests that error is raised for sequence embedding column.""" - vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) - - categorical_column_a = sfc.sequence_categorical_column_with_identity( - key='aaa', num_buckets=vocabulary_size) - embedding_column_a = fc.embedding_column( - categorical_column_a, dimension=2) - - input_layer = dense_features.DenseFeatures([embedding_column_a]) - with self.assertRaisesRegexp( - ValueError, - r'In embedding_column: aaa_embedding\. categorical_column must not be ' - r'of type SequenceCategoricalColumn\.'): - _ = input_layer({'aaa': sparse_input}) - - def test_indicator_column(self): - """Tests that error is raised for sequence indicator column.""" - vocabulary_size = 3 - sparse_input = sparse_tensor.SparseTensorValue( - # example 0, ids [2] - # example 1, ids [0, 1] - indices=((0, 0), (1, 0), (1, 1)), - values=(2, 0, 1), - dense_shape=(2, 2)) - - categorical_column_a = sfc.sequence_categorical_column_with_identity( - key='aaa', num_buckets=vocabulary_size) - indicator_column_a = fc.indicator_column(categorical_column_a) - - input_layer = dense_features.DenseFeatures([indicator_column_a]) - with self.assertRaisesRegexp( - ValueError, - r'In indicator_column: aaa_indicator\. categorical_column must not be ' - r'of type SequenceCategoricalColumn\.'): - _ = input_layer({'aaa': sparse_input}) - - def _assert_sparse_tensor_value(test_case, expected, actual): _assert_sparse_tensor_indices_shape(test_case, expected, actual) diff --git a/tensorflow/python/feature_column/serialization_test.py b/tensorflow/python/feature_column/serialization_test.py index 78b72746ac9..69b954022af 100644 --- a/tensorflow/python/feature_column/serialization_test.py +++ b/tensorflow/python/feature_column/serialization_test.py @@ -18,12 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from absl.testing import parameterized -from tensorflow.python.feature_column import dense_features from tensorflow.python.feature_column import feature_column_v2 as fc from tensorflow.python.feature_column import serialization -from tensorflow.python.framework import test_util from tensorflow.python.platform import test @@ -114,123 +111,5 @@ class FeatureColumnSerializationTest(test.TestCase): self.assertIs(new_price.normalizer_fn, _custom_fn) -@test_util.run_all_in_graph_and_eager_modes -class DenseFeaturesSerializationTest(test.TestCase, parameterized.TestCase): - - @parameterized.named_parameters( - ('default', None, None), - ('trainable', True, 'trainable'), - ('not_trainable', False, 'frozen')) - def test_get_config(self, trainable, name): - cols = [fc.numeric_column('a'), - fc.embedding_column(fc.categorical_column_with_identity( - key='b', num_buckets=3), dimension=2)] - orig_layer = dense_features.DenseFeatures( - cols, trainable=trainable, name=name) - config = orig_layer.get_config() - - self.assertEqual(config['name'], orig_layer.name) - self.assertEqual(config['trainable'], trainable) - self.assertLen(config['feature_columns'], 2) - self.assertEqual( - config['feature_columns'][0]['class_name'], 'NumericColumn') - self.assertEqual(config['feature_columns'][0]['config']['shape'], (1,)) - self.assertEqual( - config['feature_columns'][1]['class_name'], 'EmbeddingColumn') - - @parameterized.named_parameters( - ('default', None, None), - ('trainable', True, 'trainable'), - ('not_trainable', False, 'frozen')) - def test_from_config(self, trainable, name): - cols = [fc.numeric_column('a'), - fc.embedding_column(fc.categorical_column_with_vocabulary_list( - 'b', vocabulary_list=['1', '2', '3']), dimension=2), - fc.indicator_column(fc.categorical_column_with_hash_bucket( - key='c', hash_bucket_size=3))] - orig_layer = dense_features.DenseFeatures( - cols, trainable=trainable, name=name) - config = orig_layer.get_config() - - new_layer = dense_features.DenseFeatures.from_config(config) - - self.assertEqual(new_layer.name, orig_layer.name) - self.assertEqual(new_layer.trainable, trainable) - self.assertLen(new_layer._feature_columns, 3) - self.assertEqual(new_layer._feature_columns[0].name, 'a') - self.assertEqual(new_layer._feature_columns[1].initializer.mean, 0.0) - self.assertEqual(new_layer._feature_columns[1].categorical_column.name, 'b') - self.assertIsInstance(new_layer._feature_columns[2], fc.IndicatorColumn) - - def test_crossed_column(self): - a = fc.categorical_column_with_vocabulary_list( - 'a', vocabulary_list=['1', '2', '3']) - b = fc.categorical_column_with_vocabulary_list( - 'b', vocabulary_list=['1', '2', '3']) - ab = fc.crossed_column([a, b], hash_bucket_size=2) - cols = [fc.indicator_column(ab)] - - orig_layer = dense_features.DenseFeatures(cols) - config = orig_layer.get_config() - - new_layer = dense_features.DenseFeatures.from_config(config) - - self.assertLen(new_layer._feature_columns, 1) - self.assertEqual(new_layer._feature_columns[0].name, 'a_X_b_indicator') - - -@test_util.run_all_in_graph_and_eager_modes -class LinearModelLayerSerializationTest(test.TestCase, parameterized.TestCase): - - @parameterized.named_parameters( - ('default', 1, 'sum', None, None), - ('trainable', 6, 'mean', True, 'trainable'), - ('not_trainable', 10, 'sum', False, 'frozen')) - def test_get_config(self, units, sparse_combiner, trainable, name): - cols = [fc.numeric_column('a'), - fc.categorical_column_with_identity(key='b', num_buckets=3)] - layer = fc._LinearModelLayer( - cols, units=units, sparse_combiner=sparse_combiner, - trainable=trainable, name=name) - config = layer.get_config() - - self.assertEqual(config['name'], layer.name) - self.assertEqual(config['trainable'], trainable) - self.assertEqual(config['units'], units) - self.assertEqual(config['sparse_combiner'], sparse_combiner) - self.assertLen(config['feature_columns'], 2) - self.assertEqual( - config['feature_columns'][0]['class_name'], 'NumericColumn') - self.assertEqual( - config['feature_columns'][1]['class_name'], 'IdentityCategoricalColumn') - - @parameterized.named_parameters( - ('default', 1, 'sum', None, None), - ('trainable', 6, 'mean', True, 'trainable'), - ('not_trainable', 10, 'sum', False, 'frozen')) - def test_from_config(self, units, sparse_combiner, trainable, name): - cols = [fc.numeric_column('a'), - fc.categorical_column_with_vocabulary_list( - 'b', vocabulary_list=('1', '2', '3')), - fc.categorical_column_with_hash_bucket( - key='c', hash_bucket_size=3)] - orig_layer = fc._LinearModelLayer( - cols, units=units, sparse_combiner=sparse_combiner, - trainable=trainable, name=name) - config = orig_layer.get_config() - - new_layer = fc._LinearModelLayer.from_config(config) - - self.assertEqual(new_layer.name, orig_layer.name) - self.assertEqual(new_layer._units, units) - self.assertEqual(new_layer._sparse_combiner, sparse_combiner) - self.assertEqual(new_layer.trainable, trainable) - self.assertLen(new_layer._feature_columns, 3) - self.assertEqual(new_layer._feature_columns[0].name, 'a') - self.assertEqual( - new_layer._feature_columns[1].vocabulary_list, ('1', '2', '3')) - self.assertEqual(new_layer._feature_columns[2].num_buckets, 3) - - if __name__ == '__main__': test.main() diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index 73fb034f061..994a7eea494 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -640,5 +640,8 @@ def as_dtype(type_value): except (KeyError, TypeError): pass + if isinstance(type_value, _dtypes.DType): + return _INTERN_TABLE[type_value.as_datatype_enum] + raise TypeError("Cannot convert value %r to a TensorFlow DType." % (type_value,)) diff --git a/tensorflow/python/framework/dtypes_test.py b/tensorflow/python/framework/dtypes_test.py index dd2ea446b78..041cc5280cd 100644 --- a/tensorflow/python/framework/dtypes_test.py +++ b/tensorflow/python/framework/dtypes_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.core.framework import types_pb2 +from tensorflow.python import _dtypes from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.platform import googletest @@ -64,6 +65,13 @@ class TypesTest(test_util.TensorFlowTestCase): dtypes.as_dtype(datatype_enum).base_dtype, dtypes.as_dtype(numpy_dtype)) + def testAllPybind11DTypeConvertibleToDType(self): + for datatype_enum in types_pb2.DataType.values(): + if datatype_enum == types_pb2.DT_INVALID: + continue + dtype = _dtypes.DType(datatype_enum) + self.assertEqual(dtypes.as_dtype(datatype_enum), dtype) + def testInvalid(self): with self.assertRaises(TypeError): dtypes.DType(types_pb2.DT_INVALID) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 43652d51eae..5b6dac5be34 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1394,6 +1394,65 @@ def _error_prefix(name): return "" if name is None else "%s: " % name +def pack_eager_tensors(tensors, ctx=None): + """Pack multiple `EagerTensor`s of the same dtype and shape. + + Args: + tensors: a list of EagerTensors to pack. + ctx: context.context(). + + Returns: + A packed EagerTensor. + """ + if not isinstance(tensors, list): + raise TypeError("tensors must be a list or a tuple: %s" % tensors) + + if not tensors: + raise ValueError("Empty tensors is unexpected for packing.") + + dtype = tensors[0].dtype + shape = tensors[0].shape + handle_data = tensors[0]._handle_data # pylint: disable=protected-access + is_resource = dtype == dtypes.resource + for i in range(len(tensors)): + t = tensors[i] + if not isinstance(t, EagerTensor): + raise TypeError("tensors must be a list of EagerTensors: %s" % t) + + if t.dtype != dtype: + raise ValueError( + "All tensors being packed should have the same dtype %s, " + "but the %d-th tensor is of dtype %s" % (dtype, i, t.dtype)) + if t.shape != shape: + raise ValueError( + "All tensors being packed should have the same shape %s, " + "but the %d-th tensor is of shape %s" % (shape, i, t.shape)) + # pylint: disable=protected-access + if is_resource and t._handle_data != handle_data: + raise ValueError( + "All tensors being packed should have the same handle data %s, " + "but the %d-th tensor is of handle data %s" % + (handle_data, i, t._handle_data)) + # pylint: enable=protected-access + + if ctx is None: + ctx = context.context() + + # Propogate handle data for resource variables + packed_tensor = ctx.pack_eager_tensors(tensors) + if handle_data is not None: + packed_tensor._handle_data = handle_data # pylint: disable=protected-access + + def grad_fun(_): + raise ValueError( + "Gradients through pack_eager_tensors are not supported yet.") + + tape.record_operation("pack_eager_tensors", [packed_tensor], tensors, + grad_fun) + + return packed_tensor + + def convert_to_tensor(value, dtype=None, name=None, diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 322df8ffac8..7626bd780bb 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -34,6 +34,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import function as eager_function from tensorflow.python.eager import wrap_function +from tensorflow.python.framework import config from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as pydev @@ -90,7 +91,7 @@ class ResourceTest(test_util.TensorFlowTestCase): resources.shared_resources()).eval()), 0) -@test_util.disable_tfrt("Graph is not supported yet.") +@test_util.disable_tfrt("Graph is not supported yet. b/156187905") class TensorAndShapeTest(test_util.TensorFlowTestCase): def testShape(self): @@ -310,7 +311,8 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): del x self.assertIsNotNone(x_ref.deref()) -@test_util.disable_tfrt("Graph mode is not supported yet.") + +@test_util.disable_tfrt("Graph is not supported yet. b/156187905") @test_util.run_all_in_graph_and_eager_modes class IndexedSlicesTest(test_util.TensorFlowTestCase): @@ -355,7 +357,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase): self.assertAllEqual(x.indices, [0, 2]) -@test_util.disable_tfrt("Graph mode is not supported yet.") +@test_util.disable_tfrt("Graph is not supported yet. b/156187905") @test_util.run_all_in_graph_and_eager_modes class IndexedSlicesSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase): @@ -501,7 +503,7 @@ def _apply_op(g, *args, **kwargs): return op.outputs -@test_util.disable_tfrt("Graph is not supported yet.") +@test_util.disable_tfrt("Graph is not supported yet. b/156187905") class OperationTest(test_util.TensorFlowTestCase): @test_util.run_deprecated_v1 @@ -1444,7 +1446,7 @@ class NameTest(test_util.TensorFlowTestCase): g.create_op("FloatOutput", [], [dtypes.float32]).name) -@test_util.disable_tfrt("Device API are not supported yet.") +@test_util.disable_tfrt("Device API are not supported yet. b/156188344") class DeviceTest(test_util.TensorFlowTestCase): def testNoDevice(self): @@ -2025,7 +2027,7 @@ class CollectionTest(test_util.TensorFlowTestCase): # Collections are ordered. self.assertEqual([90, 100], ops.get_collection("key")) - @test_util.disable_tfrt("Functions are not supported yet.") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def test_defun(self): with context.eager_mode(): @@ -2132,7 +2134,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): # e should be dominated by c. self.assertEqual(e.op.control_inputs, []) - @test_util.disable_tfrt("Graph is not supported yet.") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") @test_util.run_in_graph_and_eager_modes def testEager(self): def future(): @@ -2453,7 +2455,7 @@ class OpScopeTest(test_util.TensorFlowTestCase): self._testGraphElements([a, variable, b]) -@test_util.disable_tfrt("Graphs are not supported yet.") +@test_util.disable_tfrt("Graph is not supported yet. b/156187905") class InitScopeTest(test_util.TensorFlowTestCase): def testClearsControlDependencies(self): @@ -2756,7 +2758,7 @@ class InitScopeTest(test_util.TensorFlowTestCase): self.assertFalse(self.evaluate(f())) -@test_util.disable_tfrt("Graphs are not supported yet.") +@test_util.disable_tfrt("Graph is not supported yet. b/156187905") class GraphTest(test_util.TensorFlowTestCase): def setUp(self): @@ -3234,7 +3236,7 @@ class ColocationGroupTest(test_util.TensorFlowTestCase): b = variables.Variable([3.0], name="b") self.assertEqual([b"loc:@a"], b.op.colocation_groups()) - @test_util.disable_tfrt("Functions are not supported yet.") + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def testColocateWithVariableInFunction(self): v = variables.Variable(1.) @@ -3408,5 +3410,51 @@ class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase): self.assertAllEqual(x_, tensor_util.constant_value(y_)) +@test_util.disable_tfrt("Packing EagerTensors is not supported yet.") +class PackEagerTensorTest(test_util.TensorFlowTestCase): + + def setUp(self): + super(PackEagerTensorTest, self).setUp() + context._reset_context() + cpus = config.list_physical_devices("CPU") + # Set 2 virtual CPUs + config.set_logical_device_configuration(cpus[0], [ + context.LogicalDeviceConfiguration(), + context.LogicalDeviceConfiguration(), + ]) + + def testPack(self): + with context.eager_mode(): + with ops.device("CPU:0"): + var0 = resource_variable_ops.ResourceVariable(1.0) + c0 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + with ops.device("CPU:1"): + var1 = resource_variable_ops.ResourceVariable(2.0) + var2 = resource_variable_ops.ResourceVariable([3.0]) + c1 = constant_op.constant([9.0]) + + packed_var0 = ops.pack_eager_tensors([var0.handle, var1.handle]) + self.assertTrue(packed_var0.is_packed) + self.assertEqual(packed_var0.dtype, var0.handle.dtype) + self.assertEqual(packed_var0.shape, var0.handle.shape) + self.assertEqual(packed_var0._handle_data, var0.handle._handle_data) + self.assertIn("COMPOSITE:0", packed_var0.device) + self.assertIn("COMPOSITE:0", packed_var0.backing_device) + with self.assertRaises(errors.InvalidArgumentError): + packed_var0.numpy() + + # Different dtypes + with self.assertRaises(ValueError): + ops.pack_eager_tensors([var0.handle, c1]) + + # Different shapes + with self.assertRaises(ValueError): + ops.pack_eager_tensors([c0, c1]) + + # Different handle data + with self.assertRaises(ValueError): + ops.pack_eager_tensors([var0.handle, var2.handle]) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 857cc7b6638..ca0c5d9ef1a 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -959,7 +959,10 @@ void GenEagerPythonOp::AddDispatch(const string& prefix) { strings::StrAppend(&result_, prefix, "except (TypeError, ValueError):\n"); strings::StrAppend(&result_, prefix, " result = _dispatch.dispatch(\n"); - AddBodyNoReturn(strings::StrCat(prefix, " ", function_name_, ", ")); + AddBodyNoReturn(strings::StrCat(prefix, " ", function_name_, + ", " + "(), dict(")); + strings::StrAppend(&result_, prefix, " )\n"); strings::StrAppend(&result_, prefix, " if result is not " "_dispatch.OpDispatcher.NOT_SUPPORTED:\n"); diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index d5bbd889166..4981e1b68fd 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -2686,7 +2686,7 @@ class TensorFlowTestCase(googletest.TestCase): if (b.ndim <= 3 or b.size < 500): self.assertEqual( a.shape, b.shape, "Shape mismatch: expected %s, got %s." - " Contents: %s. \n%s." % (a.shape, b.shape, b, msg)) + " Contents: %r. \n%s." % (a.shape, b.shape, b, msg)) else: self.assertEqual( a.shape, b.shape, "Shape mismatch: expected %s, got %s." @@ -2709,8 +2709,8 @@ class TensorFlowTestCase(googletest.TestCase): else: # np.where is broken for scalars x, y = a, b - msgs.append("not equal lhs = {}".format(x)) - msgs.append("not equal rhs = {}".format(y)) + msgs.append("not equal lhs = %r" % x) + msgs.append("not equal rhs = %r" % y) # With Python 3, we need to make sure the dtype matches between a and b. b = b.astype(a.dtype) np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs)) diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 4cd0af07c74..78e360c8354 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -584,6 +584,7 @@ tf_py_test( deps = [ ":backend", ":combinations", + ":engine", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py index 34d04d68c6c..0ee4a91f417 100644 --- a/tensorflow/python/keras/activations.py +++ b/tensorflow/python/keras/activations.py @@ -24,6 +24,7 @@ from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object from tensorflow.python.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import keras_export # b/123041942 @@ -41,6 +42,7 @@ _TF_ACTIVATIONS_V2 = { @keras_export('keras.activations.softmax') +@dispatch.add_dispatch_support def softmax(x, axis=-1): """Softmax converts a real vector to a vector of categorical probabilities. @@ -82,6 +84,7 @@ def softmax(x, axis=-1): @keras_export('keras.activations.elu') +@dispatch.add_dispatch_support def elu(x, alpha=1.0): """Exponential linear unit. @@ -100,6 +103,7 @@ def elu(x, alpha=1.0): @keras_export('keras.activations.selu') +@dispatch.add_dispatch_support def selu(x): """Scaled Exponential Linear Unit (SELU). @@ -153,6 +157,7 @@ def selu(x): @keras_export('keras.activations.softplus') +@dispatch.add_dispatch_support def softplus(x): """Softplus activation function, `softplus(x) = log(exp(x) + 1)`. @@ -174,6 +179,7 @@ def softplus(x): @keras_export('keras.activations.softsign') +@dispatch.add_dispatch_support def softsign(x): """Softsign activation function, `softsign(x) = x / (abs(x) + 1)`. @@ -194,6 +200,7 @@ def softsign(x): @keras_export('keras.activations.swish') +@dispatch.add_dispatch_support def swish(x): """Swish activation function, `swish(x) = x * sigmoid(x)`. @@ -224,6 +231,7 @@ def swish(x): @keras_export('keras.activations.relu') +@dispatch.add_dispatch_support def relu(x, alpha=0., max_value=None, threshold=0): """Applies the rectified linear unit activation function. @@ -264,6 +272,7 @@ def relu(x, alpha=0., max_value=None, threshold=0): @keras_export('keras.activations.tanh') +@dispatch.add_dispatch_support def tanh(x): """Hyperbolic tangent activation function. @@ -285,6 +294,7 @@ def tanh(x): @keras_export('keras.activations.sigmoid') +@dispatch.add_dispatch_support def sigmoid(x): """Sigmoid activation function, `sigmoid(x) = 1 / (1 + exp(-x))`. @@ -314,6 +324,7 @@ def sigmoid(x): @keras_export('keras.activations.exponential') +@dispatch.add_dispatch_support def exponential(x): """Exponential activation function. @@ -334,6 +345,7 @@ def exponential(x): @keras_export('keras.activations.hard_sigmoid') +@dispatch.add_dispatch_support def hard_sigmoid(x): """Hard sigmoid activation function. @@ -360,6 +372,7 @@ def hard_sigmoid(x): @keras_export('keras.activations.linear') +@dispatch.add_dispatch_support def linear(x): """Linear activation function (pass-through). @@ -380,6 +393,7 @@ def linear(x): @keras_export('keras.activations.serialize') +@dispatch.add_dispatch_support def serialize(activation): """Returns the string identifier of an activation function. @@ -410,6 +424,7 @@ def serialize(activation): @keras_export('keras.activations.deserialize') +@dispatch.add_dispatch_support def deserialize(name, custom_objects=None): """Returns activation function given a string identifier. @@ -447,6 +462,7 @@ def deserialize(name, custom_objects=None): @keras_export('keras.activations.get') +@dispatch.add_dispatch_support def get(identifier): """Returns function. diff --git a/tensorflow/python/keras/applications/BUILD b/tensorflow/python/keras/applications/BUILD index 1eaed45c714..0c566c6e6d5 100644 --- a/tensorflow/python/keras/applications/BUILD +++ b/tensorflow/python/keras/applications/BUILD @@ -35,10 +35,16 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/python:util", + "//tensorflow/python:lib", + "//tensorflow/python:platform", + "//tensorflow/python:tf_export", + "//tensorflow/python/keras:activations", "//tensorflow/python/keras:backend", - "//tensorflow/python/keras:engine", + "//tensorflow/python/keras/engine", "//tensorflow/python/keras/layers", + "//tensorflow/python/keras/utils:data_utils", + "//tensorflow/python/keras/utils:layer_utils", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/python/keras/applications/densenet.py b/tensorflow/python/keras/applications/densenet.py index 39004be622f..620a0b21607 100644 --- a/tensorflow/python/keras/applications/densenet.py +++ b/tensorflow/python/keras/applications/densenet.py @@ -23,14 +23,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - from tensorflow.python.keras import backend from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.lib.io import file_io from tensorflow.python.util.tf_export import keras_export @@ -193,7 +192,7 @@ def DenseNet( ValueError: if `classifier_activation` is not `softmax` or `None` when using a pretrained top layer. """ - if not (weights in {'imagenet', None} or os.path.exists(weights)): + if not (weights in {'imagenet', None} or file_io.file_exists(weights)): raise ValueError('The `weights` argument should be either ' '`None` (random initialization), `imagenet` ' '(pre-training on ImageNet), ' diff --git a/tensorflow/python/keras/applications/efficientnet.py b/tensorflow/python/keras/applications/efficientnet.py index ece9f7f7e5b..e1413b08533 100644 --- a/tensorflow/python/keras/applications/efficientnet.py +++ b/tensorflow/python/keras/applications/efficientnet.py @@ -26,7 +26,6 @@ from __future__ import print_function import copy import math -import os from tensorflow.python.keras import backend from tensorflow.python.keras.applications import imagenet_utils @@ -34,6 +33,7 @@ from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.lib.io import file_io from tensorflow.python.util.tf_export import keras_export @@ -269,7 +269,7 @@ def EfficientNet( if blocks_args == 'default': blocks_args = DEFAULT_BLOCKS_ARGS - if not (weights in {'imagenet', None} or os.path.exists(weights)): + if not (weights in {'imagenet', None} or file_io.file_exists(weights)): raise ValueError('The `weights` argument should be either ' '`None` (random initialization), `imagenet` ' '(pre-training on ImageNet), ' diff --git a/tensorflow/python/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/applications/inception_resnet_v2.py index 15cbfa5033c..31f342b4d5a 100644 --- a/tensorflow/python/keras/applications/inception_resnet_v2.py +++ b/tensorflow/python/keras/applications/inception_resnet_v2.py @@ -25,14 +25,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - from tensorflow.python.keras import backend from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.lib.io import file_io from tensorflow.python.util.tf_export import keras_export @@ -113,7 +112,7 @@ def InceptionResNetV2(include_top=True, layers = VersionAwareLayers() if kwargs: raise ValueError('Unknown argument(s): %s' % (kwargs,)) - if not (weights in {'imagenet', None} or os.path.exists(weights)): + if not (weights in {'imagenet', None} or file_io.file_exists(weights)): raise ValueError('The `weights` argument should be either ' '`None` (random initialization), `imagenet` ' '(pre-training on ImageNet), ' diff --git a/tensorflow/python/keras/applications/inception_v3.py b/tensorflow/python/keras/applications/inception_v3.py index 3f528fc131a..9fb1dad6b03 100644 --- a/tensorflow/python/keras/applications/inception_v3.py +++ b/tensorflow/python/keras/applications/inception_v3.py @@ -23,14 +23,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - from tensorflow.python.keras import backend from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.lib.io import file_io from tensorflow.python.util.tf_export import keras_export @@ -109,7 +108,7 @@ def InceptionV3( ValueError: if `classifier_activation` is not `softmax` or `None` when using a pretrained top layer. """ - if not (weights in {'imagenet', None} or os.path.exists(weights)): + if not (weights in {'imagenet', None} or file_io.file_exists(weights)): raise ValueError('The `weights` argument should be either ' '`None` (random initialization), `imagenet` ' '(pre-training on ImageNet), ' diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py index f531d8d124c..3f29f01da2d 100644 --- a/tensorflow/python/keras/applications/mobilenet.py +++ b/tensorflow/python/keras/applications/mobilenet.py @@ -64,14 +64,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - from tensorflow.python.keras import backend from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.lib.io import file_io from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import keras_export @@ -164,7 +163,7 @@ def MobileNet(input_shape=None, layers = VersionAwareLayers() if kwargs: raise ValueError('Unknown argument(s): %s' % (kwargs,)) - if not (weights in {'imagenet', None} or os.path.exists(weights)): + if not (weights in {'imagenet', None} or file_io.file_exists(weights)): raise ValueError('The `weights` argument should be either ' '`None` (random initialization), `imagenet` ' '(pre-training on ImageNet), ' diff --git a/tensorflow/python/keras/applications/mobilenet_v2.py b/tensorflow/python/keras/applications/mobilenet_v2.py index b1138b7ae26..86fd864ab02 100644 --- a/tensorflow/python/keras/applications/mobilenet_v2.py +++ b/tensorflow/python/keras/applications/mobilenet_v2.py @@ -77,14 +77,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - from tensorflow.python.keras import backend from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.lib.io import file_io from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import keras_export @@ -181,7 +180,7 @@ def MobileNetV2(input_shape=None, layers = VersionAwareLayers() if kwargs: raise ValueError('Unknown argument(s): %s' % (kwargs,)) - if not (weights in {'imagenet', None} or os.path.exists(weights)): + if not (weights in {'imagenet', None} or file_io.file_exists(weights)): raise ValueError('The `weights` argument should be either ' '`None` (random initialization), `imagenet` ' '(pre-training on ImageNet), ' diff --git a/tensorflow/python/keras/applications/nasnet.py b/tensorflow/python/keras/applications/nasnet.py index f4e5f74e77d..20f1df91048 100644 --- a/tensorflow/python/keras/applications/nasnet.py +++ b/tensorflow/python/keras/applications/nasnet.py @@ -41,14 +41,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - from tensorflow.python.keras import backend from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.lib.io import file_io from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import keras_export @@ -151,7 +150,7 @@ def NASNet( ValueError: if `classifier_activation` is not `softmax` or `None` when using a pretrained top layer. """ - if not (weights in {'imagenet', None} or os.path.exists(weights)): + if not (weights in {'imagenet', None} or file_io.file_exists(weights)): raise ValueError('The `weights` argument should be either ' '`None` (random initialization), `imagenet` ' '(pre-training on ImageNet), ' diff --git a/tensorflow/python/keras/applications/resnet.py b/tensorflow/python/keras/applications/resnet.py index e72f06ce3d1..5bc47f89460 100644 --- a/tensorflow/python/keras/applications/resnet.py +++ b/tensorflow/python/keras/applications/resnet.py @@ -23,14 +23,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - from tensorflow.python.keras import backend from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.lib.io import file_io from tensorflow.python.util.tf_export import keras_export @@ -138,7 +137,7 @@ def ResNet(stack_fn, layers = VersionAwareLayers() if kwargs: raise ValueError('Unknown argument(s): %s' % (kwargs,)) - if not (weights in {'imagenet', None} or os.path.exists(weights)): + if not (weights in {'imagenet', None} or file_io.file_exists(weights)): raise ValueError('The `weights` argument should be either ' '`None` (random initialization), `imagenet` ' '(pre-training on ImageNet), ' diff --git a/tensorflow/python/keras/applications/vgg16.py b/tensorflow/python/keras/applications/vgg16.py index 3a523dc5dc3..b160c920347 100644 --- a/tensorflow/python/keras/applications/vgg16.py +++ b/tensorflow/python/keras/applications/vgg16.py @@ -23,14 +23,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - from tensorflow.python.keras import backend from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.lib.io import file_io from tensorflow.python.util.tf_export import keras_export @@ -114,7 +113,7 @@ def VGG16( ValueError: if `classifier_activation` is not `softmax` or `None` when using a pretrained top layer. """ - if not (weights in {'imagenet', None} or os.path.exists(weights)): + if not (weights in {'imagenet', None} or file_io.file_exists(weights)): raise ValueError('The `weights` argument should be either ' '`None` (random initialization), `imagenet` ' '(pre-training on ImageNet), ' diff --git a/tensorflow/python/keras/applications/vgg19.py b/tensorflow/python/keras/applications/vgg19.py index e4385cc8f6a..11f1a252c64 100644 --- a/tensorflow/python/keras/applications/vgg19.py +++ b/tensorflow/python/keras/applications/vgg19.py @@ -23,14 +23,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - from tensorflow.python.keras import backend from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.lib.io import file_io from tensorflow.python.util.tf_export import keras_export @@ -114,7 +113,7 @@ def VGG19( ValueError: if `classifier_activation` is not `softmax` or `None` when using a pretrained top layer. """ - if not (weights in {'imagenet', None} or os.path.exists(weights)): + if not (weights in {'imagenet', None} or file_io.file_exists(weights)): raise ValueError('The `weights` argument should be either ' '`None` (random initialization), `imagenet` ' '(pre-training on ImageNet), ' diff --git a/tensorflow/python/keras/applications/xception.py b/tensorflow/python/keras/applications/xception.py index 7139764b15b..f414ded6e18 100644 --- a/tensorflow/python/keras/applications/xception.py +++ b/tensorflow/python/keras/applications/xception.py @@ -27,14 +27,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - from tensorflow.python.keras import backend from tensorflow.python.keras.applications import imagenet_utils from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import VersionAwareLayers from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.lib.io import file_io from tensorflow.python.util.tf_export import keras_export @@ -114,7 +113,7 @@ def Xception( ValueError: if `classifier_activation` is not `softmax` or `None` when using a pretrained top layer. """ - if not (weights in {'imagenet', None} or os.path.exists(weights)): + if not (weights in {'imagenet', None} or file_io.file_exists(weights)): raise ValueError('The `weights` argument should be either ' '`None` (random initialization), `imagenet` ' '(pre-training on ImageNet), ' diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 11e53e032ae..d0c3eb03342 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -76,6 +76,7 @@ from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import moving_averages from tensorflow.python.training.tracking import util as tracking_util +from tensorflow.python.util import dispatch from tensorflow.python.util import nest from tensorflow.python.util import object_identity from tensorflow.python.util import tf_contextlib @@ -173,6 +174,7 @@ def backend(): @keras_export('keras.backend.cast_to_floatx') +@dispatch.add_dispatch_support def cast_to_floatx(x): """Cast a Numpy array to the default Keras float type. @@ -292,7 +294,6 @@ def clear_session(): global _GRAPH_VARIABLES # pylint: disable=global-variable-not-assigned global _GRAPH_TF_OPTIMIZERS # pylint: disable=global-variable-not-assigned global _GRAPH - global _FREEZABLE_VARS _GRAPH.graph = None ops.reset_default_graph() reset_uids() @@ -305,7 +306,6 @@ def clear_session(): _GRAPH_LEARNING_PHASES.setdefault(graph) _GRAPH_VARIABLES.pop(graph, None) _GRAPH_TF_OPTIMIZERS.pop(graph, None) - _FREEZABLE_VARS.pop(graph, None) @keras_export('keras.backend.manual_variable_initialization') @@ -799,6 +799,7 @@ def is_sparse(tensor): @keras_export('keras.backend.to_dense') +@dispatch.add_dispatch_support def to_dense(tensor): """Converts a sparse tensor into a dense tensor and returns it. @@ -1007,6 +1008,7 @@ def _initialize_variables(session): @keras_export('keras.backend.constant') +@dispatch.add_dispatch_support def constant(value, dtype=None, shape=None, name=None): """Creates a constant tensor. @@ -1055,9 +1057,9 @@ def is_keras_tensor(x): >>> tf.keras.backend.is_keras_tensor(keras_var) False >>> keras_placeholder = tf.keras.backend.placeholder(shape=(2, 4, 5)) - >>> # A placeholder is not a Keras tensor. + >>> # A placeholder is a Keras tensor. >>> tf.keras.backend.is_keras_tensor(keras_placeholder) - False + True >>> keras_input = tf.keras.layers.Input([10]) >>> # An Input is a Keras tensor. >>> tf.keras.backend.is_keras_tensor(keras_input) @@ -1140,6 +1142,14 @@ def placeholder(shape=None, expand_composites=True) else: x = array_ops.placeholder(dtype, shape=shape, name=name) + + if context.executing_eagerly(): + # Add keras_history connectivity information to the placeholder + # when the placeholder is built in a top-level eager context + # (intended to be used with keras.backend.function) + from tensorflow.python.keras.engine import input_layer # pylint: disable=g-import-not-at-top + return input_layer.Input(tensor=x) + return x @@ -1163,6 +1173,7 @@ def is_placeholder(x): @keras_export('keras.backend.shape') +@dispatch.add_dispatch_support def shape(x): """Returns the symbolic shape of a tensor or variable. @@ -1245,6 +1256,7 @@ def ndim(x): @keras_export('keras.backend.dtype') +@dispatch.add_dispatch_support def dtype(x): """Returns the dtype of a Keras tensor or variable, as a string. @@ -1343,6 +1355,7 @@ def zeros(shape, dtype=None, name=None): @keras_export('keras.backend.ones') +@dispatch.add_dispatch_support def ones(shape, dtype=None, name=None): """Instantiates an all-ones variable and returns it. @@ -1377,6 +1390,7 @@ def ones(shape, dtype=None, name=None): @keras_export('keras.backend.eye') +@dispatch.add_dispatch_support def eye(size, dtype=None, name=None): """Instantiate an identity matrix and returns it. @@ -1433,6 +1447,7 @@ def zeros_like(x, dtype=None, name=None): @keras_export('keras.backend.ones_like') +@dispatch.add_dispatch_support def ones_like(x, dtype=None, name=None): """Instantiates an all-ones variable of the same shape as another tensor. @@ -1563,6 +1578,7 @@ def count_params(x): @keras_export('keras.backend.cast') +@dispatch.add_dispatch_support def cast(x, dtype): """Casts a tensor to a different dtype and returns it. @@ -1647,6 +1663,7 @@ def moving_average_update(x, value, momentum): @keras_export('keras.backend.dot') +@dispatch.add_dispatch_support def dot(x, y): """Multiplies 2 tensors (and/or variables) and returns a tensor. @@ -1707,6 +1724,7 @@ def dot(x, y): @keras_export('keras.backend.batch_dot') +@dispatch.add_dispatch_support def batch_dot(x, y, axes=None): """Batchwise dot product. @@ -1895,6 +1913,7 @@ def batch_dot(x, y, axes=None): @keras_export('keras.backend.transpose') +@dispatch.add_dispatch_support def transpose(x): """Transposes a tensor and returns it. @@ -1926,6 +1945,7 @@ def transpose(x): @keras_export('keras.backend.gather') +@dispatch.add_dispatch_support def gather(reference, indices): """Retrieves the elements of indices `indices` in the tensor `reference`. @@ -1961,6 +1981,7 @@ def gather(reference, indices): @keras_export('keras.backend.max') +@dispatch.add_dispatch_support def max(x, axis=None, keepdims=False): """Maximum value in a tensor. @@ -1979,6 +2000,7 @@ def max(x, axis=None, keepdims=False): @keras_export('keras.backend.min') +@dispatch.add_dispatch_support def min(x, axis=None, keepdims=False): """Minimum value in a tensor. @@ -1997,6 +2019,7 @@ def min(x, axis=None, keepdims=False): @keras_export('keras.backend.sum') +@dispatch.add_dispatch_support def sum(x, axis=None, keepdims=False): """Sum of the values in a tensor, alongside the specified axis. @@ -2015,6 +2038,7 @@ def sum(x, axis=None, keepdims=False): @keras_export('keras.backend.prod') +@dispatch.add_dispatch_support def prod(x, axis=None, keepdims=False): """Multiplies the values in a tensor, alongside the specified axis. @@ -2033,6 +2057,7 @@ def prod(x, axis=None, keepdims=False): @keras_export('keras.backend.cumsum') +@dispatch.add_dispatch_support def cumsum(x, axis=0): """Cumulative sum of the values in a tensor, alongside the specified axis. @@ -2047,6 +2072,7 @@ def cumsum(x, axis=0): @keras_export('keras.backend.cumprod') +@dispatch.add_dispatch_support def cumprod(x, axis=0): """Cumulative product of the values in a tensor, alongside the specified axis. @@ -2081,6 +2107,7 @@ def var(x, axis=None, keepdims=False): @keras_export('keras.backend.std') +@dispatch.add_dispatch_support def std(x, axis=None, keepdims=False): """Standard deviation of a tensor, alongside the specified axis. @@ -2107,6 +2134,7 @@ def std(x, axis=None, keepdims=False): @keras_export('keras.backend.mean') +@dispatch.add_dispatch_support def mean(x, axis=None, keepdims=False): """Mean of a tensor, alongside the specified axis. @@ -2127,6 +2155,7 @@ def mean(x, axis=None, keepdims=False): @keras_export('keras.backend.any') +@dispatch.add_dispatch_support def any(x, axis=None, keepdims=False): """Bitwise reduction (logical OR). @@ -2143,6 +2172,7 @@ def any(x, axis=None, keepdims=False): @keras_export('keras.backend.all') +@dispatch.add_dispatch_support def all(x, axis=None, keepdims=False): """Bitwise reduction (logical AND). @@ -2159,6 +2189,7 @@ def all(x, axis=None, keepdims=False): @keras_export('keras.backend.argmax') +@dispatch.add_dispatch_support def argmax(x, axis=-1): """Returns the index of the maximum value along an axis. @@ -2173,6 +2204,7 @@ def argmax(x, axis=-1): @keras_export('keras.backend.argmin') +@dispatch.add_dispatch_support def argmin(x, axis=-1): """Returns the index of the minimum value along an axis. @@ -2187,6 +2219,7 @@ def argmin(x, axis=-1): @keras_export('keras.backend.square') +@dispatch.add_dispatch_support def square(x): """Element-wise square. @@ -2200,6 +2233,7 @@ def square(x): @keras_export('keras.backend.abs') +@dispatch.add_dispatch_support def abs(x): """Element-wise absolute value. @@ -2213,6 +2247,7 @@ def abs(x): @keras_export('keras.backend.sqrt') +@dispatch.add_dispatch_support def sqrt(x): """Element-wise square root. @@ -2229,6 +2264,7 @@ def sqrt(x): @keras_export('keras.backend.exp') +@dispatch.add_dispatch_support def exp(x): """Element-wise exponential. @@ -2242,6 +2278,7 @@ def exp(x): @keras_export('keras.backend.log') +@dispatch.add_dispatch_support def log(x): """Element-wise log. @@ -2276,6 +2313,7 @@ def logsumexp(x, axis=None, keepdims=False): @keras_export('keras.backend.round') +@dispatch.add_dispatch_support def round(x): """Element-wise rounding to the closest integer. @@ -2291,6 +2329,7 @@ def round(x): @keras_export('keras.backend.sign') +@dispatch.add_dispatch_support def sign(x): """Element-wise sign. @@ -2304,6 +2343,7 @@ def sign(x): @keras_export('keras.backend.pow') +@dispatch.add_dispatch_support def pow(x, a): """Element-wise exponentiation. @@ -2318,6 +2358,7 @@ def pow(x, a): @keras_export('keras.backend.clip') +@dispatch.add_dispatch_support def clip(x, min_value, max_value): """Element-wise value clipping. @@ -2341,6 +2382,7 @@ def clip(x, min_value, max_value): @keras_export('keras.backend.equal') +@dispatch.add_dispatch_support def equal(x, y): """Element-wise equality between two tensors. @@ -2355,6 +2397,7 @@ def equal(x, y): @keras_export('keras.backend.not_equal') +@dispatch.add_dispatch_support def not_equal(x, y): """Element-wise inequality between two tensors. @@ -2369,6 +2412,7 @@ def not_equal(x, y): @keras_export('keras.backend.greater') +@dispatch.add_dispatch_support def greater(x, y): """Element-wise truth value of (x > y). @@ -2383,6 +2427,7 @@ def greater(x, y): @keras_export('keras.backend.greater_equal') +@dispatch.add_dispatch_support def greater_equal(x, y): """Element-wise truth value of (x >= y). @@ -2397,6 +2442,7 @@ def greater_equal(x, y): @keras_export('keras.backend.less') +@dispatch.add_dispatch_support def less(x, y): """Element-wise truth value of (x < y). @@ -2411,6 +2457,7 @@ def less(x, y): @keras_export('keras.backend.less_equal') +@dispatch.add_dispatch_support def less_equal(x, y): """Element-wise truth value of (x <= y). @@ -2425,6 +2472,7 @@ def less_equal(x, y): @keras_export('keras.backend.maximum') +@dispatch.add_dispatch_support def maximum(x, y): """Element-wise maximum of two tensors. @@ -2449,6 +2497,7 @@ def maximum(x, y): @keras_export('keras.backend.minimum') +@dispatch.add_dispatch_support def minimum(x, y): """Element-wise minimum of two tensors. @@ -2463,6 +2512,7 @@ def minimum(x, y): @keras_export('keras.backend.sin') +@dispatch.add_dispatch_support def sin(x): """Computes sin of x element-wise. @@ -2476,6 +2526,7 @@ def sin(x): @keras_export('keras.backend.cos') +@dispatch.add_dispatch_support def cos(x): """Computes cos of x element-wise. @@ -2621,6 +2672,7 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3): @keras_export('keras.backend.batch_normalization') +@dispatch.add_dispatch_support def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3): """Applies batch normalization on x given mean, var, beta and gamma. @@ -2683,6 +2735,7 @@ def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3): @keras_export('keras.backend.concatenate') +@dispatch.add_dispatch_support def concatenate(tensors, axis=-1): """Concatenates a list of tensors alongside the specified axis. @@ -2720,6 +2773,7 @@ def concatenate(tensors, axis=-1): @keras_export('keras.backend.reshape') +@dispatch.add_dispatch_support def reshape(x, shape): """Reshapes a tensor to the specified shape. @@ -2749,6 +2803,7 @@ def reshape(x, shape): @keras_export('keras.backend.permute_dimensions') +@dispatch.add_dispatch_support def permute_dimensions(x, pattern): """Permutes axes in a tensor. @@ -2780,6 +2835,7 @@ def permute_dimensions(x, pattern): @keras_export('keras.backend.resize_images') +@dispatch.add_dispatch_support def resize_images(x, height_factor, width_factor, data_format, interpolation='nearest'): """Resizes the images contained in a 4D tensor. @@ -2843,6 +2899,7 @@ def resize_images(x, height_factor, width_factor, data_format, @keras_export('keras.backend.resize_volumes') +@dispatch.add_dispatch_support def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): """Resizes the volume contained in a 5D tensor. @@ -2875,6 +2932,7 @@ def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): @keras_export('keras.backend.repeat_elements') +@dispatch.add_dispatch_support def repeat_elements(x, rep, axis): """Repeats the elements of a tensor along an axis, like `np.repeat`. @@ -2936,6 +2994,7 @@ def repeat_elements(x, rep, axis): @keras_export('keras.backend.repeat') +@dispatch.add_dispatch_support def repeat(x, n): """Repeats a 2D tensor. @@ -2971,6 +3030,7 @@ def repeat(x, n): @keras_export('keras.backend.arange') +@dispatch.add_dispatch_support def arange(start, stop=None, step=1, dtype='int32'): """Creates a 1D tensor containing a sequence of integers. @@ -3009,6 +3069,7 @@ def arange(start, stop=None, step=1, dtype='int32'): @keras_export('keras.backend.tile') +@dispatch.add_dispatch_support def tile(x, n): """Creates a tensor by tiling `x` by `n`. @@ -3026,6 +3087,7 @@ def tile(x, n): @keras_export('keras.backend.flatten') +@dispatch.add_dispatch_support def flatten(x): """Flatten a tensor. @@ -3051,6 +3113,7 @@ def flatten(x): @keras_export('keras.backend.batch_flatten') +@dispatch.add_dispatch_support def batch_flatten(x): """Turn a nD tensor into a 2D tensor with same 0th dimension. @@ -3076,6 +3139,7 @@ def batch_flatten(x): @keras_export('keras.backend.expand_dims') +@dispatch.add_dispatch_support def expand_dims(x, axis=-1): """Adds a 1-sized dimension at index "axis". @@ -3090,6 +3154,7 @@ def expand_dims(x, axis=-1): @keras_export('keras.backend.squeeze') +@dispatch.add_dispatch_support def squeeze(x, axis): """Removes a 1-dimension from the tensor at index "axis". @@ -3104,6 +3169,7 @@ def squeeze(x, axis): @keras_export('keras.backend.temporal_padding') +@dispatch.add_dispatch_support def temporal_padding(x, padding=(1, 1)): """Pads the middle dimension of a 3D tensor. @@ -3121,6 +3187,7 @@ def temporal_padding(x, padding=(1, 1)): @keras_export('keras.backend.spatial_2d_padding') +@dispatch.add_dispatch_support def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None): """Pads the 2nd and 3rd dimensions of a 4D tensor. @@ -3152,6 +3219,7 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None): @keras_export('keras.backend.spatial_3d_padding') +@dispatch.add_dispatch_support def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None): """Pads 5D tensor with zeros along the depth, height, width dimensions. @@ -3196,6 +3264,7 @@ def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None): @keras_export('keras.backend.stack') +@dispatch.add_dispatch_support def stack(x, axis=0): """Stacks a list of rank `R` tensors into a rank `R+1` tensor. @@ -3222,6 +3291,7 @@ def stack(x, axis=0): @keras_export('keras.backend.one_hot') +@dispatch.add_dispatch_support def one_hot(indices, num_classes): """Computes the one-hot representation of an integer tensor. @@ -3241,6 +3311,7 @@ def one_hot(indices, num_classes): @keras_export('keras.backend.reverse') +@dispatch.add_dispatch_support def reverse(x, axes): """Reverse a tensor along the specified axes. @@ -3314,13 +3385,14 @@ def get_value(x): if ops.executing_eagerly_outside_functions(): # This method of evaluating works inside the Keras FuncGraph. - return function([], x)(x) + return eval_in_eager_or_function(x) with x.graph.as_default(): return x.eval(session=get_session((x,))) @keras_export('keras.backend.batch_get_value') +@dispatch.add_dispatch_support def batch_get_value(tensors): """Returns the value of more than one tensor variable. @@ -3382,6 +3454,7 @@ def set_value(x, value): @keras_export('keras.backend.batch_set_value') +@dispatch.add_dispatch_support def batch_set_value(tuples): """Sets the values of many tensor variables at once. @@ -3424,6 +3497,7 @@ set_value.__doc__ = set_value.__doc__.format(snippet=_VALUE_SET_CODE_STRING) @keras_export('keras.backend.print_tensor') +@dispatch.add_dispatch_support def print_tensor(x, message=''): """Prints `message` and the tensor value when evaluated. @@ -3654,161 +3728,74 @@ class GraphExecutionFunction(object): return nest.map_structure(self._eval_if_composite, output_structure) -class EagerExecutionFunction(object): - """Helper class for constructing a TF graph function from the Keras graph. +def eval_in_eager_or_function(outputs): + """Method to evaluate a tensor in eager or in a tf.function. + + In the case of a tf.function, it will lift the tensor out of the function + and try to evaluate that piece of the graph. + + Warning: Do not add new usages of this function. + TODO(b/150169018): delete this function once _keras_history_helper is no + longer needed, after Keras switches to KerasTensors and op layers + work via dispatch. Arguments: - inputs: Feed placeholders to the computation graph. - outputs: Output tensors to fetch. - updates: Additional update ops to be run at function call. - name: A name to help users identify what this function does. - session_kwargs: Unsupported. + outputs: tensors to fetch. + Returns: + The value of the tensors (as numpy arrays). """ + outputs_structure = outputs + outputs = nest.flatten(outputs, expand_composites=True) - def __init__(self, inputs, outputs, updates=None, name=None): - self.name = name - self._inputs_structure = inputs - inputs = nest.flatten(inputs, expand_composites=True) - self._outputs_structure = outputs - outputs = nest.flatten(outputs, expand_composites=True) + graphs = { + i.graph + for i in nest.flatten([outputs]) + if hasattr(i, 'graph') + } + if len(graphs) > 1: + raise ValueError('Cannot create an execution function which is comprised ' + 'of elements from multiple graphs.') - updates = updates or [] - if not isinstance(updates, (list, tuple)): - raise TypeError('`updates` in a Keras backend function ' - 'should be a list or tuple.') + source_graph = graphs.pop() - if updates and not outputs: - # Edge case; never happens in practice - raise ValueError('Cannot create a Keras backend function with updates' - ' but no outputs during eager execution.') - graphs = { - i.graph - for i in nest.flatten([inputs, outputs, updates]) - if hasattr(i, 'graph') - } - if len(graphs) > 1: - raise ValueError('Cannot create an execution function which is comprised ' - 'of elements from multiple graphs.') - - source_graph = graphs.pop() + with _scratch_graph() as exec_graph: global_graph = get_graph() + if source_graph not in (exec_graph, global_graph): + raise ValueError('Unknown graph. Aborting.') - updates_ops = [] - legacy_update_ops = [] - for update in updates: - # For legacy reasons it is allowed to pass an update as a tuple - # `(variable, new_value)` (this maps to an assign op). Otherwise it - # is assumed to already be an op -- we cannot control its execution - # order. - if isinstance(update, tuple): - legacy_update_ops.append(update) - else: - if hasattr(update, 'op'): - update = update.op - if update is not None: - # `update.op` may have been None in certain cases. - updates_ops.append(update) + if source_graph is global_graph and exec_graph is not global_graph: + init_tensors = outputs + lifted_map = lift_to_graph.lift_to_graph( + tensors=init_tensors, + graph=exec_graph, + sources=[], + add_sources=True, + handle_captures=True, + base_graph=source_graph) - self._freezable_vars_to_feed = [] - self._freezable_vars_values = [] - freezable_vars_from_keras_graph = object_identity.ObjectIdentitySet( - _FREEZABLE_VARS.get(global_graph, {})) - with _scratch_graph() as exec_graph: - global_graph = get_graph() - if source_graph not in (exec_graph, global_graph): - raise ValueError('Unknown graph. Aborting.') + outputs = [lifted_map[i] for i in outputs] - if source_graph is global_graph and exec_graph is not global_graph: - init_tensors = ( - outputs + updates_ops + [p for [p, _] in legacy_update_ops] + - [p_new for [_, p_new] in legacy_update_ops - if isinstance(p_new, ops.Tensor)]) - lifted_map = lift_to_graph.lift_to_graph( - tensors=init_tensors, - graph=exec_graph, - sources=inputs, - add_sources=True, - handle_captures=True, - base_graph=source_graph) + # Consolidate updates + with exec_graph.as_default(): + outputs = cast_variables_to_tensor(outputs) - inputs = [lifted_map[i] for i in inputs] - outputs = [lifted_map[i] for i in outputs] - updates_ops = [lifted_map[i] for i in updates_ops] - legacy_update_ops = [(lifted_map[p], lifted_map.get(p_new, p_new)) - for p, p_new in legacy_update_ops] + exec_graph.inputs = exec_graph.internal_captures + exec_graph.outputs = outputs + graph_fn = eager_function.ConcreteFunction(exec_graph) - # Keep track of the value to feed to any "freezable variables" - # created in this graph. - for old_op, new_op in lifted_map.items(): - if old_op in freezable_vars_from_keras_graph: - frozen_var = old_op - if frozen_var._initial_value != frozen_var._current_value: - # We only feed a frozen_variable if its value has changed; - # otherwise it can rely on the default value of the - # underlying placeholder_with_default. - self._freezable_vars_to_feed.append(new_op) - self._freezable_vars_values.append(frozen_var._current_value) + graph_fn._num_positional_args = 0 + graph_fn._arg_keywords = [] - # Consolidate updates - with exec_graph.as_default(): - outputs = cast_variables_to_tensor(outputs) - with ops.control_dependencies(outputs): - for p, p_new in legacy_update_ops: - updates_ops.append(state_ops.assign(p, p_new)) + outputs = graph_fn() - self.inputs, self.outputs = inputs, outputs - self._input_references = self.inputs + self._freezable_vars_to_feed - with ops.control_dependencies(updates_ops): - self.outputs[0] = array_ops.identity(self.outputs[0]) - - exec_graph.inputs = self._input_references + exec_graph.internal_captures - exec_graph.outputs = self.outputs - graph_fn = eager_function.ConcreteFunction(exec_graph) - - graph_fn._num_positional_args = len(self._input_references) - graph_fn._arg_keywords = [] - self._graph_fn = graph_fn - - # Handle placeholders with default - # (treated as required placeholder by graph functions) - self._placeholder_default_values = {} - with exec_graph.as_default(): - for x in self.inputs: - if x.op.type == 'PlaceholderWithDefault': - self._placeholder_default_values[ops.tensor_id( - x)] = tensor_util.constant_value(x.op.inputs[0]) - - def __call__(self, inputs): - input_values = nest.flatten(inputs, expand_composites=True) - - if self._freezable_vars_values: - input_values = input_values + self._freezable_vars_values - converted_inputs = [] - for tensor, value in zip(self._input_references, input_values): - if value is None: - # Assume `value` is a placeholder with default - value = self._placeholder_default_values.get( - ops.tensor_id(tensor), None) - if value is None: - raise ValueError( - 'You must feed a value for placeholder %s' % (tensor,)) - if not isinstance(value, ops.Tensor): - value = ops.convert_to_tensor_v2(value, dtype=tensor.dtype) - if value.dtype != tensor.dtype: - # Temporary workaround due to `convert_to_tensor` not casting floats. - # See b/119637405 - value = math_ops.cast(value, tensor.dtype) - converted_inputs.append(value) - outputs = self._graph_fn(*converted_inputs) - - # EagerTensor.numpy() will often make a copy to ensure memory safety. - # However in this case `outputs` is not directly returned, so it is always - # safe to reuse the underlying buffer without checking. In such a case the - # private numpy conversion method is preferred to guarantee performance. - return nest.pack_sequence_as( - self._outputs_structure, - [x._numpy() for x in outputs], # pylint: disable=protected-access - expand_composites=True) + # EagerTensor.numpy() will often make a copy to ensure memory safety. + # However in this case `outputs` is not directly returned, so it is always + # safe to reuse the underlying buffer without checking. In such a case the + # private numpy conversion method is preferred to guarantee performance. + return nest.pack_sequence_as( + outputs_structure, + [x._numpy() for x in outputs], # pylint: disable=protected-access + expand_composites=True) @keras_export('keras.backend.function') @@ -3832,7 +3819,20 @@ def function(inputs, outputs, updates=None, name=None, **kwargs): if kwargs: raise ValueError('Session keyword arguments are not support during ' 'eager execution. You passed: %s' % (kwargs,)) - return EagerExecutionFunction(inputs, outputs, updates=updates, name=name) + if updates: + raise ValueError('`updates` argument is not support during ' + 'eager execution. You passed: %s' % (updates,)) + from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top + from tensorflow.python.keras.utils import tf_utils # pylint: disable=g-import-not-at-top + model = models.Model(inputs=inputs, outputs=outputs) + + wrap_outputs = isinstance(outputs, list) and len(outputs) == 1 + def func(model_inputs): + outs = model(model_inputs) + if wrap_outputs: + outs = [outs] + return tf_utils.to_numpy_or_python_type(outs) + return func if kwargs: for key in kwargs: @@ -3861,6 +3861,7 @@ def gradients(loss, variables): @keras_export('keras.backend.stop_gradient') +@dispatch.add_dispatch_support def stop_gradient(variables): """Returns `variables` but with zero gradient w.r.t. every other variable. @@ -3882,6 +3883,7 @@ def stop_gradient(variables): @keras_export('keras.backend.rnn') +@dispatch.add_dispatch_support def rnn(step_function, inputs, initial_states, @@ -4276,6 +4278,7 @@ def rnn(step_function, @keras_export('keras.backend.switch') +@dispatch.add_dispatch_support def switch(condition, then_expression, else_expression): """Switches between two operations depending on a scalar value. @@ -4409,6 +4412,7 @@ def in_test_phase(x, alt, training=None): @keras_export('keras.backend.relu') +@dispatch.add_dispatch_support def relu(x, alpha=0., max_value=None, threshold=0): """Rectified linear unit. @@ -4462,6 +4466,7 @@ def relu(x, alpha=0., max_value=None, threshold=0): @keras_export('keras.backend.elu') +@dispatch.add_dispatch_support def elu(x, alpha=1.): """Exponential linear unit. @@ -4480,6 +4485,7 @@ def elu(x, alpha=1.): @keras_export('keras.backend.softmax') +@dispatch.add_dispatch_support def softmax(x, axis=-1): """Softmax of a tensor. @@ -4495,6 +4501,7 @@ def softmax(x, axis=-1): @keras_export('keras.backend.softplus') +@dispatch.add_dispatch_support def softplus(x): """Softplus of a tensor. @@ -4508,6 +4515,7 @@ def softplus(x): @keras_export('keras.backend.softsign') +@dispatch.add_dispatch_support def softsign(x): """Softsign of a tensor. @@ -4527,6 +4535,7 @@ def _backtrack_identity(tensor): @keras_export('keras.backend.categorical_crossentropy') +@dispatch.add_dispatch_support def categorical_crossentropy(target, output, from_logits=False, axis=-1): """Categorical crossentropy between an output tensor and a target tensor. @@ -4595,6 +4604,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): @keras_export('keras.backend.sparse_categorical_crossentropy') +@dispatch.add_dispatch_support def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): """Categorical crossentropy with integer targets. @@ -4676,6 +4686,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): @keras_export('keras.backend.binary_crossentropy') +@dispatch.add_dispatch_support def binary_crossentropy(target, output, from_logits=False): """Binary crossentropy between an output tensor and a target tensor. @@ -4712,6 +4723,7 @@ def binary_crossentropy(target, output, from_logits=False): @keras_export('keras.backend.sigmoid') +@dispatch.add_dispatch_support def sigmoid(x): """Element-wise sigmoid. @@ -4725,6 +4737,7 @@ def sigmoid(x): @keras_export('keras.backend.hard_sigmoid') +@dispatch.add_dispatch_support def hard_sigmoid(x): """Segment-wise linear approximation of sigmoid. @@ -4747,6 +4760,7 @@ def hard_sigmoid(x): @keras_export('keras.backend.tanh') +@dispatch.add_dispatch_support def tanh(x): """Element-wise tanh. @@ -4760,6 +4774,7 @@ def tanh(x): @keras_export('keras.backend.dropout') +@dispatch.add_dispatch_support def dropout(x, level, noise_shape=None, seed=None): """Sets entries in `x` to zero at random, while scaling the entire tensor. @@ -4780,6 +4795,7 @@ def dropout(x, level, noise_shape=None, seed=None): @keras_export('keras.backend.l2_normalize') +@dispatch.add_dispatch_support def l2_normalize(x, axis=None): """Normalizes a tensor wrt the L2 norm alongside the specified axis. @@ -4794,6 +4810,7 @@ def l2_normalize(x, axis=None): @keras_export('keras.backend.in_top_k') +@dispatch.add_dispatch_support def in_top_k(predictions, targets, k): """Returns whether the `targets` are in the top `k` `predictions`. @@ -4896,6 +4913,7 @@ def _preprocess_padding(padding): @keras_export('keras.backend.conv1d') +@dispatch.add_dispatch_support def conv1d(x, kernel, strides=1, @@ -4946,6 +4964,7 @@ def conv1d(x, @keras_export('keras.backend.conv2d') +@dispatch.add_dispatch_support def conv2d(x, kernel, strides=(1, 1), @@ -4989,6 +5008,7 @@ def conv2d(x, @keras_export('keras.backend.conv2d_transpose') +@dispatch.add_dispatch_support def conv2d_transpose(x, kernel, output_shape, @@ -5129,6 +5149,7 @@ def separable_conv1d(x, @keras_export('keras.backend.separable_conv2d') +@dispatch.add_dispatch_support def separable_conv2d(x, depthwise_kernel, pointwise_kernel, @@ -5186,6 +5207,7 @@ def separable_conv2d(x, @keras_export('keras.backend.depthwise_conv2d') +@dispatch.add_dispatch_support def depthwise_conv2d(x, depthwise_kernel, strides=(1, 1), @@ -5235,6 +5257,7 @@ def depthwise_conv2d(x, @keras_export('keras.backend.conv3d') +@dispatch.add_dispatch_support def conv3d(x, kernel, strides=(1, 1, 1), @@ -5337,6 +5360,7 @@ def conv3d_transpose(x, @keras_export('keras.backend.pool2d') +@dispatch.add_dispatch_support def pool2d(x, pool_size, strides=(1, 1), @@ -5396,6 +5420,7 @@ def pool2d(x, @keras_export('keras.backend.pool3d') +@dispatch.add_dispatch_support def pool3d(x, pool_size, strides=(1, 1, 1), @@ -5526,6 +5551,7 @@ def local_conv(inputs, @keras_export('keras.backend.local_conv1d') +@dispatch.add_dispatch_support def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None): """Apply 1D conv with un-shared weights. @@ -5561,6 +5587,7 @@ def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None): @keras_export('keras.backend.local_conv2d') +@dispatch.add_dispatch_support def local_conv2d(inputs, kernel, kernel_size, @@ -5602,6 +5629,7 @@ def local_conv2d(inputs, @keras_export('keras.backend.bias_add') +@dispatch.add_dispatch_support def bias_add(x, bias, data_format=None): """Adds a bias vector to a tensor. @@ -5646,6 +5674,7 @@ def bias_add(x, bias, data_format=None): @keras_export('keras.backend.random_normal') +@dispatch.add_dispatch_support def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): """Returns a tensor with normal distribution of values. @@ -5682,6 +5711,7 @@ def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): @keras_export('keras.backend.random_uniform') +@dispatch.add_dispatch_support def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): """Returns a tensor with uniform distribution of values. @@ -5715,6 +5745,7 @@ def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): @deprecated(None, 'Use `tf.keras.backend.random_bernoulli` instead.') @keras_export('keras.backend.random_binomial') +@dispatch.add_dispatch_support def random_binomial(shape, p=0.0, dtype=None, seed=None): """Returns a tensor with random binomial distribution of values. @@ -5751,6 +5782,7 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None): @keras_export('keras.backend.random_bernoulli') +@dispatch.add_dispatch_support def random_bernoulli(shape, p=0.0, dtype=None, seed=None): """Returns a tensor with random bernoulli distribution of values. @@ -5767,6 +5799,7 @@ def random_bernoulli(shape, p=0.0, dtype=None, seed=None): @keras_export('keras.backend.truncated_normal') +@dispatch.add_dispatch_support def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): """Returns a tensor with truncated random normal distribution of values. @@ -5801,6 +5834,7 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): @keras_export('keras.backend.ctc_label_dense_to_sparse') +@dispatch.add_dispatch_support def ctc_label_dense_to_sparse(labels, label_lengths): """Converts CTC labels from dense to sparse. @@ -5847,6 +5881,7 @@ def ctc_label_dense_to_sparse(labels, label_lengths): @keras_export('keras.backend.ctc_batch_cost') +@dispatch.add_dispatch_support def ctc_batch_cost(y_true, y_pred, input_length, label_length): """Runs CTC loss algorithm on each batch element. @@ -5879,6 +5914,7 @@ def ctc_batch_cost(y_true, y_pred, input_length, label_length): @keras_export('keras.backend.ctc_decode') +@dispatch.add_dispatch_support def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1): """Decodes the output of a softmax. @@ -6240,10 +6276,6 @@ class ContextValueCache(weakref.WeakKeyDictionary): # either train mode (learning_phase == 1) or test mode (learning_phase == 0). _GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase) -# This dictionary holds a mapping {graph: set_of_freezable_variables}. -# Each set tracks objects created via `freezable_variable` in the graph. -_FREEZABLE_VARS = ContextValueCache(object_identity.ObjectIdentityWeakSet) - # This dictionary holds a mapping between a graph and variables to initialize # in the graph. _GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet) diff --git a/tensorflow/python/keras/backend_config.py b/tensorflow/python/keras/backend_config.py index c1bf163c444..cd1f1e4b423 100644 --- a/tensorflow/python/keras/backend_config.py +++ b/tensorflow/python/keras/backend_config.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import keras_export # The type of float to use throughout a session. @@ -30,6 +31,7 @@ _IMAGE_DATA_FORMAT = 'channels_last' @keras_export('keras.backend.epsilon') +@dispatch.add_dispatch_support def epsilon(): """Returns the value of the fuzz factor used in numeric expressions. @@ -110,6 +112,7 @@ def set_floatx(value): @keras_export('keras.backend.image_data_format') +@dispatch.add_dispatch_support def image_data_format(): """Returns the default image data format convention. diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py index 1adc20652b2..20547c570c7 100644 --- a/tensorflow/python/keras/backend_test.py +++ b/tensorflow/python/keras/backend_test.py @@ -1677,8 +1677,10 @@ class BackendCrossEntropyLossesTest(test.TestCase, parameterized.TestCase): t, p, from_logits=True, axis=0), self.assertArrayNear(self.evaluate(result)[0], [.002, 0, .17], 1e-3) - @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + @combinations.generate(combinations.combine(mode=['graph'])) def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self): + # This test only runs in graph because the TF op layer is not supported yet + # for sparse ops. t = backend.placeholder() p = backend.placeholder() o = backend.sparse_categorical_crossentropy(t, p) @@ -1870,6 +1872,8 @@ class TestRandomOps(test.TestCase): class FunctionTest(test.TestCase): def test_function_basics(self): + if context.executing_eagerly(): + self.skipTest('eager backend.function does not support updates') x1 = backend.placeholder(shape=(), dtype='float32') x2 = backend.placeholder(shape=(), dtype='int32') v = backend.variable(10.) @@ -1916,6 +1920,9 @@ class FunctionTest(test.TestCase): self.assertEqual(result, 4.) def test_tuple_updates(self): + if context.executing_eagerly(): + self.skipTest('eager backend.function does not support updates') + x_ph = backend.placeholder(ndim=2) v = backend.variable(np.ones((4, 2))) output = x_ph ** 2 + v @@ -1929,7 +1936,7 @@ class FunctionTest(test.TestCase): class BackendGraphTests(test.TestCase, parameterized.TestCase): - @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + @combinations.generate(combinations.combine(mode=['graph'])) def test_function_placeholder_with_default(self): with backend.get_graph().as_default(): x1 = array_ops.placeholder_with_default( diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index 87625446e2f..6a39ebc5007 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -128,7 +128,6 @@ distribute_py_test( "multi_and_single_gpu", "no_rocm", # times out on ROCm "no_windows_gpu", - "notpu", # TODO(b/155867206) flaky segfault "notsan", ], tpu_tags = [ @@ -431,10 +430,10 @@ py_test( python_version = "PY3", shard_count = 5, tags = [ - "noasan", - "nomsan", - "notsan", - ], # TODO(b/156029134) + "noasan", # TODO(b/156029134) + "nomsan", # TODO(b/156029134) + "notsan", # TODO(b/156029134) + ], deps = [ "//tensorflow/python:platform", "//tensorflow/python/data/ops:dataset_ops", diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py index f6a83c499fe..eac1e2feb8b 100644 --- a/tensorflow/python/keras/distribute/distribute_strategy_test.py +++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py @@ -575,8 +575,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, @combinations.generate( combinations.combine( - distribution=[strategy_combinations.one_device_strategy] + - tpu_strategies, + distribution=[strategy_combinations.one_device_strategy], mode=['graph', 'eager'])) def test_optimizer_in_cross_replica_context_raises_error(self, distribution): @@ -1070,6 +1069,11 @@ class TestDistributionStrategyWithDatasets(test.TestCase, @combinations.generate(all_strategy_combinations()) def test_on_dataset_with_unknown_cardinality_without_steps( self, distribution, mode): + # TODO(b/155867206): Investigate why this test occasionally segfaults on TPU + # in eager mode. + if mode == 'eager' and isinstance( + distribution, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)): + self.skipTest('caused segfault with TPU in eager mode.') if mode == 'graph' and isinstance( distribution, (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)): diff --git a/tensorflow/python/keras/distribute/keras_utils_test.py b/tensorflow/python/keras/distribute/keras_utils_test.py index 702d89d95f8..0f65bbbf917 100644 --- a/tensorflow/python/keras/distribute/keras_utils_test.py +++ b/tensorflow/python/keras/distribute/keras_utils_test.py @@ -26,7 +26,6 @@ import numpy as np from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import combinations -from tensorflow.python.distribute import parameter_server_strategy from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import tpu_strategy from tensorflow.python.distribute import values @@ -398,9 +397,6 @@ class TestDistributionStrategyWithNormalizationLayer(test.TestCase, optimizer=strategy_combinations .gradient_descent_optimizer_keras_v2_fn))) def test_batchnorm_correctness(self, distribution, fused, optimizer): - if isinstance(distribution.extended, - parameter_server_strategy.ParameterServerStrategyExtended): - self.skipTest('b/152353796') with self.cached_session(): with distribution.scope(): model = keras.models.Sequential() diff --git a/tensorflow/python/keras/distribute/multi_worker_tutorial_test.py b/tensorflow/python/keras/distribute/multi_worker_tutorial_test.py index 1a46bcd7499..3f9ab18f89c 100644 --- a/tensorflow/python/keras/distribute/multi_worker_tutorial_test.py +++ b/tensorflow/python/keras/distribute/multi_worker_tutorial_test.py @@ -120,8 +120,8 @@ class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase): multi_worker_model.fit( multi_worker_dataset, - epochs=3, - steps_per_epoch=70, + epochs=2, + steps_per_epoch=20, callbacks=callbacks) with test_util.skip_if_error(self, errors_impl.UnavailableError): diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index 1ff15d7e2e1..231ab7661f0 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -118,6 +118,7 @@ py_library( "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/distribute:input_lib", "//tensorflow/python/distribute:reduce_util", + "//tensorflow/python/distribute:sharded_variable", "//tensorflow/python/eager:monitoring", "//tensorflow/python/keras:backend", "//tensorflow/python/keras:constraints", diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 94b696d842b..b986f9a405e 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -34,6 +34,7 @@ from tensorflow.python import tf2 from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.impl import api as autograph from tensorflow.python.distribute import distribution_strategy_context as ds_context +from tensorflow.python.distribute import sharded_variable from tensorflow.python.eager import context from tensorflow.python.eager import execute from tensorflow.python.eager import function @@ -385,6 +386,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # might want to turn it off, like Sequential model. self._auto_track_sub_layers = True + # Will compute masking if `compute_mask` is overridden or `supports_masking` + # is set. + self._compute_mask_overridden = (not getattr(self.compute_mask, + '_is_default', False)) + @trackable.no_automatic_dependency_tracking @generic_utils.default def build(self, input_shape): @@ -590,7 +596,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector): self._handle_weight_regularization(name_in_scope, variable, regularizer) - if isinstance(variable, tf_variables.PartitionedVariable): + if isinstance( + variable, + (tf_variables.PartitionedVariable, sharded_variable.ShardedVariable)): for v in variable: backend.track_variable(v) if trainable: @@ -814,6 +822,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): inputs, args, kwargs = self._split_out_first_arg(args, kwargs) call_context = base_layer_utils.call_context() + in_call = call_context.in_call input_list = nest.flatten(inputs) # We will attempt to build a TF graph if & only if all inputs are symbolic. @@ -841,7 +850,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # setting the `_keras_mask` attribute on the inputs to a Layer. Masks passed # explicitly take priority. mask_arg_passed_by_framework = False - input_masks = self._collect_input_masks(inputs, args, kwargs) + input_masks = self._collect_input_masks(inputs, input_list, args, kwargs) if (self._expects_mask_arg and input_masks is not None and not self._call_arg_was_passed('mask', args, kwargs)): mask_arg_passed_by_framework = True @@ -888,16 +897,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector): if build_graph and base_layer_utils.needs_keras_history(inputs): base_layer_utils.create_keras_history(inputs) - # Clear eager losses on top level model call. - # We are clearing the losses only on the top level model call and not on - # every layer/model call because layer/model may be reused. - if (base_layer_utils.is_in_eager_or_tf_function() and - not call_context.in_call): - self._clear_losses() - with call_context.enter(self, inputs, build_graph, training_value): # Check input assumptions set after layer building, e.g. input shape. if build_graph: + # Losses are cleared for all Layers when the outermost layer is called. + # Losses are not cleared each time an inner layer is called, bc inner + # Layers can be reused in a Model. + if not in_call and base_layer_utils.is_in_tf_function(): + self._clear_losses() + # Symbolic execution on symbolic tensors. We will attempt to build # the corresponding TF subgraph inside `backend.get_graph()` # TODO(reedwm): We should assert input compatibility after the inputs @@ -905,11 +913,12 @@ class Layer(module.Module, version_utils.LayerVersionSelector): input_spec.assert_input_compatibility(self.input_spec, inputs, self.name) graph = backend.get_graph() + # Use `self._name_scope()` to avoid auto-incrementing the name. with graph.as_default(), backend.name_scope(self._name_scope()): # Build layer if applicable (if the `build` method has been # overridden). self._maybe_build(inputs) - cast_inputs = self._maybe_cast_inputs(inputs) + cast_inputs = self._maybe_cast_inputs(inputs, input_list) if not self.dynamic: # Wrapping `call` function in autograph to allow for dynamic control @@ -970,21 +979,29 @@ class Layer(module.Module, version_utils.LayerVersionSelector): outputs = self._set_connectivity_metadata((inputs,) + args, kwargs, outputs) self._handle_activity_regularization(inputs, outputs) - self._set_mask_metadata(inputs, outputs, input_masks) + self._set_mask_metadata(inputs, outputs, input_masks, build_graph) if hasattr(self, '_set_inputs') and not self.inputs: # Subclassed network: explicitly set metadata normally set by # a call to self._set_inputs(). self._set_inputs(cast_inputs, outputs) else: # Eager execution on data tensors. - with backend.name_scope(self._name_scope()): + + # Losses are cleared for all Layers when the outermost layer is called. + # Losses are not cleared each time an inner layer is called, bc inner + # Layers can be reused in a Model. + if not in_call: + self._clear_losses() + + # In Eager mode, `ops.name_scope_v2` does not autoincrement the name. + with ops.name_scope_v2(self.name): self._maybe_build(inputs) - cast_inputs = self._maybe_cast_inputs(inputs) + cast_inputs = self._maybe_cast_inputs(inputs, input_list) with base_layer_utils.autocast_context_manager( self._compute_dtype): outputs = self.call(cast_inputs, *args, **kwargs) self._handle_activity_regularization(inputs, outputs) - self._set_mask_metadata(inputs, outputs, input_masks) + self._set_mask_metadata(inputs, outputs, input_masks, build_graph) if hasattr(self, '_set_save_spec'): self._set_save_spec(cast_inputs) @@ -1336,13 +1353,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # Possible a loss was added in a Layer's `build`. self._losses.append(symbolic_loss) - @trackable.no_automatic_dependency_tracking def _clear_losses(self): """Used every step in eager to reset losses.""" - self._eager_losses = [] - if hasattr(self, '_layers'): - for layer in trackable_layer_utils.filter_empty_layer_containers( - self._layers): + # Set to thread local directly to avoid Layer.__setattr__ overhead. + self._thread_local._eager_losses = [] + sublayers = getattr(self, '_layers', []) + if sublayers: + sublayers = trackable_layer_utils.filter_empty_layer_containers(sublayers) + for layer in sublayers: layer._clear_losses() @property @@ -2114,7 +2132,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): """ return self._dtype_policy.compute_dtype - def _maybe_cast_inputs(self, inputs): + def _maybe_cast_inputs(self, inputs, input_list): """Maybe casts the inputs to the compute dtype. If self._compute_dtype is floating-point, and self_autocast is True, @@ -2122,32 +2140,38 @@ class Layer(module.Module, version_utils.LayerVersionSelector): Args: inputs: Input tensor, or structure of input tensors. + input_list: Flat list of input tensors. Returns: `inputs`, but tensors may have been casted to self._compute_dtype """ compute_dtype = self._compute_dtype - if (self._autocast and compute_dtype and - dtypes.as_dtype(compute_dtype).is_floating): - def f(x): - """Cast a single Tensor or TensorSpec to the compute dtype.""" - cast_types = (ops.Tensor, sparse_tensor.SparseTensor, - ragged_tensor.RaggedTensor) - if (isinstance(x, cast_types) and x.dtype.is_floating and - x.dtype.base_dtype.name != compute_dtype): - if self._dtype_defaulted_to_floatx: - self._warn_about_input_casting(x.dtype.base_dtype) - return math_ops.cast(x, compute_dtype) - elif isinstance(x, tensor_spec.TensorSpec) and x.dtype.is_floating: - # Inputs may be TensorSpecs when this function is called from - # model._set_inputs. - return tensor_spec.TensorSpec(x.shape, compute_dtype, x.name) - else: - return x - return nest.map_structure(f, inputs) + should_autocast = ( + self._autocast and compute_dtype and + dtypes.as_dtype(compute_dtype).is_floating) + + if (should_autocast and + any(self._should_cast_single_input(x) for x in input_list)): + # Only perform expensive `nest` operation when needed. + return nest.map_structure(self._cast_single_input, inputs) else: return inputs + def _should_cast_single_input(self, x): + cast_types = (ops.Tensor, sparse_tensor.SparseTensor, + ragged_tensor.RaggedTensor) + return (isinstance(x, cast_types) and x.dtype.is_floating and + x.dtype.base_dtype.name != self._compute_dtype) + + def _cast_single_input(self, x): + """Cast a single Tensor or TensorSpec to the compute dtype.""" + if self._should_cast_single_input(x): + if self._dtype_defaulted_to_floatx: + self._warn_about_input_casting(x.dtype.base_dtype) + return math_ops.cast(x, self._compute_dtype) + else: + return x + def _warn_about_input_casting(self, input_dtype): # self._already_warned_about_input_casting is only retrieved or set in this # function. @@ -2250,47 +2274,45 @@ class Layer(module.Module, version_utils.LayerVersionSelector): mean_activity_loss = activity_loss / batch_size self.add_loss(mean_activity_loss) - def _set_mask_metadata(self, inputs, outputs, previous_mask): + def _set_mask_metadata(self, inputs, outputs, previous_mask, build_graph): + # Many `Layer`s don't need to call `compute_mask`. + # This method is optimized to do as little work as needed for the common + # case. + if not self.supports_masking and not self._compute_mask_overridden: + return + flat_outputs = nest.flatten(outputs) mask_already_computed = ( getattr(self, '_compute_output_and_mask_jointly', False) or all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs)) - - # Only compute the mask if the Layer explicitly supports masking or has - # overridden `compute_mask`. - should_compute_mask = ( - hasattr(self, 'compute_mask') and - (self.supports_masking or - not getattr(self.compute_mask, '_is_default', False))) - if mask_already_computed: - flat_masks = [getattr(x, '_keras_mask', None) for x in flat_outputs] - elif not should_compute_mask: - flat_masks = [None for _ in flat_outputs] - else: - output_masks = self.compute_mask(inputs, previous_mask) - # `compute_mask` can return a single `None` even when a Layer - # has multiple outputs. - if output_masks is None: - flat_masks = [None for _ in flat_outputs] - else: - flat_masks = nest.flatten(output_masks) + if build_graph: + self._set_mask_keras_history_checked(flat_outputs) + return - for output, mask in zip(flat_outputs, flat_masks): + output_masks = self.compute_mask(inputs, previous_mask) + if output_masks is None: + return + + flat_masks = nest.flatten(output_masks) + for tensor, mask in zip(flat_outputs, flat_masks): try: - output._keras_mask = mask + tensor._keras_mask = mask except AttributeError: # C Type such as np.ndarray. pass - if tf_utils.are_all_symbolic_tensors(flat_outputs): - for output in flat_outputs: - if getattr(output, '_keras_mask', None) is not None: - # Do not track masks for `TensorFlowOpLayer` construction. - output._keras_mask._keras_history_checked = True + if build_graph: + self._set_mask_keras_history_checked(flat_outputs) - def _collect_input_masks(self, inputs, args, kwargs): + def _set_mask_keras_history_checked(self, flat_outputs): + for output in flat_outputs: + if getattr(output, '_keras_mask', None) is not None: + # Do not track masks for `TensorFlowOpLayer` construction. + output._keras_mask._keras_history_checked = True + + def _collect_input_masks(self, inputs, input_list, args, kwargs): """Checks if `mask` argument was passed, else gathers mask from inputs.""" if self._call_arg_was_passed('mask', args, kwargs): return self._get_call_arg_value('mask', args, kwargs) @@ -2298,22 +2320,25 @@ class Layer(module.Module, version_utils.LayerVersionSelector): if not self._should_compute_mask: return None - input_masks = nest.map_structure(lambda t: getattr(t, '_keras_mask', None), - inputs) - if generic_utils.is_all_none(input_masks): + input_masks = [getattr(t, '_keras_mask', None) for t in input_list] + if all(mask is None for mask in input_masks): return None - return input_masks + + # Only do expensive `nest` operation when masking is actually being used. + return nest.pack_sequence_as(inputs, input_masks) def _call_arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False): + # Performance optimization: do no work in most common case. + if not args and not kwargs: + return False + if arg_name in kwargs: return True call_fn_args = self._call_fn_args if not inputs_in_args: # Ignore `inputs` arg. call_fn_args = call_fn_args[1:] - if arg_name in dict(zip(call_fn_args, args)): - return True - return False + return arg_name in dict(zip(call_fn_args, args)) def _get_call_arg_value(self, arg_name, args, kwargs, inputs_in_args=False): if arg_name in kwargs: @@ -2588,7 +2613,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # Keep track of metric instance created in subclassed layer. from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top for val in nest.flatten(value): - if isinstance(val, metrics_module.Metric): + if isinstance(val, metrics_module.Metric) and hasattr(self, '_metrics'): self._metrics.append(val) # TODO(scottzhu): Need to track Module object as well for weight tracking. diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index c5e00d8e38e..7e4e0e5da4a 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import threading from tensorflow.python import tf2 @@ -121,7 +122,7 @@ def make_variable(name, initializer, (type(init_ops.Initializer), type(init_ops_v2.Initializer))): initializer = initializer() - init_val = lambda: initializer(shape, dtype=dtype) + init_val = functools.partial(initializer, shape, dtype=dtype) variable_dtype = dtype.base_dtype if use_resource is None: use_resource = True @@ -247,7 +248,10 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers): constants[i] = op_input else: with ops.init_scope(): - constants[i] = backend.function([], op_input)([]) + if ops.executing_eagerly_outside_functions(): + constants[i] = backend.eval_in_eager_or_function(op_input) + else: + constants[i] = backend.function([], op_input)([]) layer_inputs = unnest_if_single_tensor(layer_inputs) processed_ops, created_layers = _create_keras_history_helper( layer_inputs, processed_ops, created_layers) diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index 4a277ec3a3e..80e0b4be2f1 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -2226,7 +2226,7 @@ class Layer(base_layer.Layer): # Keep track of metric instance created in subclassed layer. from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top for val in nest.flatten(value): - if isinstance(val, metrics_module.Metric): + if isinstance(val, metrics_module.Metric) and hasattr(self, '_metrics'): self._metrics.append(val) # TODO(scottzhu): Need to track Module object as well for weight tracking. diff --git a/tensorflow/python/keras/engine/base_preprocessing_layer.py b/tensorflow/python/keras/engine/base_preprocessing_layer.py index 84138dd0a00..efd8a0e621f 100644 --- a/tensorflow/python/keras/engine/base_preprocessing_layer.py +++ b/tensorflow/python/keras/engine/base_preprocessing_layer.py @@ -143,9 +143,12 @@ class CombinerPreprocessingLayer(PreprocessingLayer): accumulator = self._combiner.restore(self._restore_updates()) if not isinstance(data, - (dataset_ops.DatasetV2, np.ndarray, ops.EagerTensor)): + (dataset_ops.DatasetV2, + np.ndarray, + ops.Tensor, + ragged_tensor.RaggedTensor)): raise ValueError( - '`adapt()` requires a batched Dataset, an EagerTensor, ' + '`adapt()` requires a batched Dataset, a Tensor, ' 'or a Numpy array as input, ' 'got {}'.format(type(data))) @@ -158,9 +161,14 @@ class CombinerPreprocessingLayer(PreprocessingLayer): 'elements. Please use `dataset.take(...)` to make the number ' 'of elements finite.') next_data = self._get_dataset_iterator(data) + # TODO(fchollet): consider checking if the dataset is already batched + # and otherwise batching it. + elif isinstance(data, (ops.Tensor, ragged_tensor.RaggedTensor)): + next_data = self._get_dataset_iterator( + dataset_ops.Dataset.from_tensor_slices(data).batch(512)) else: generator, _ = training_generator.convert_to_generator_like( - data, batch_size=len(data)) + data, batch_size=512) # If the data is not a dataset, we can iterate over it using next(foo); # here, we wrap that into a callable. next_data = lambda: next(generator) diff --git a/tensorflow/python/keras/engine/base_preprocessing_layer_v1.py b/tensorflow/python/keras/engine/base_preprocessing_layer_v1.py index fb77b696f68..f603fac25c3 100644 --- a/tensorflow/python/keras/engine/base_preprocessing_layer_v1.py +++ b/tensorflow/python/keras/engine/base_preprocessing_layer_v1.py @@ -55,8 +55,9 @@ class CombinerPreprocessingLayer( def _get_dataset_iterator(self, dataset): """Gets an iterator from a tf.data.Dataset.""" - iterator = dataset_ops.make_one_shot_iterator(dataset) + iterator = dataset_ops.make_initializable_iterator(dataset) session = K.get_session() + session.run(iterator.initializer) next_element = iterator.get_next() return lambda: session.run(next_element) diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py index c79e2849c4f..4958990ad66 100644 --- a/tensorflow/python/keras/engine/functional.py +++ b/tensorflow/python/keras/engine/functional.py @@ -25,6 +25,7 @@ import itertools from six.moves import zip # pylint: disable=redefined-builtin +from tensorflow.python.eager import context from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import ops from tensorflow.python.keras import backend @@ -358,7 +359,8 @@ class Functional(training_lib.Model): # by itself because it will duplicate any updates and losses in graph # mode by `call`ing the Layers again. output_tensors = self._run_internal_graph(inputs, mask=mask) - return nest.map_structure(lambda t: t._keras_mask, output_tensors) + return nest.map_structure(lambda t: getattr(t, '_keras_mask', None), + output_tensors) def call(self, inputs, training=None, mask=None): """Calls the model on new inputs. @@ -469,11 +471,11 @@ class Functional(training_lib.Model): mask: (Optional) Tensor or nested structure of Tensors. Returns: - Two lists: output_tensors, output_masks + output_tensors """ inputs = self._flatten_to_reference_inputs(inputs) if mask is None: - masks = [None for _ in range(len(inputs))] + masks = [None] * len(inputs) else: masks = self._flatten_to_reference_inputs(mask) for input_t, mask in zip(inputs, masks): @@ -481,55 +483,39 @@ class Functional(training_lib.Model): # Dictionary mapping reference tensors to computed tensors. tensor_dict = {} + tensor_usage_count = self._tensor_usage_count for x, y in zip(self.inputs, inputs): y = self._conform_to_reference_input(y, ref_input=x) x_id = str(id(x)) - tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id] + tensor_dict[x_id] = [y] * tensor_usage_count[x_id] - depth_keys = list(self._nodes_by_depth.keys()) + nodes_by_depth = self._nodes_by_depth + depth_keys = list(nodes_by_depth.keys()) depth_keys.sort(reverse=True) for depth in depth_keys: - nodes = self._nodes_by_depth[depth] + nodes = nodes_by_depth[depth] for node in nodes: if node.is_input: continue # Input tensors already exist. - if not all( - str(id(tensor)) in tensor_dict - for tensor in nest.flatten(node.keras_inputs)): + if any(t_id not in tensor_dict for t_id in node.flat_input_ids): continue # Node is not computable, try skipping. - layer = node.layer args, kwargs = node.map_arguments(tensor_dict) - outputs = layer(*args, **kwargs) + outputs = node.layer(*args, **kwargs) # Update tensor_dict. - for x, y in zip(nest.flatten(node.outputs), nest.flatten(outputs)): - x_id = str(id(x)) - tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id] + for x_id, y in zip(node.flat_output_ids, nest.flatten(outputs)): + tensor_dict[x_id] = [y] * tensor_usage_count[x_id] output_tensors = [] - output_shapes = [] for x in self.outputs: - assert str(id(x)) in tensor_dict, 'Could not compute output ' + str(x) - tensor = tensor_dict[str(id(x))].pop() - output_shapes.append(x.shape) - output_tensors.append(tensor) + x_id = str(id(x)) + assert x_id in tensor_dict, 'Could not compute output ' + str(x) + output_tensors.append(tensor_dict[x_id].pop()) - if output_shapes is not None: - input_shapes = [x.shape for x in inputs] - try: - cache_key = tuple(tf_utils.convert_shapes(input_shapes, to_tuples=True)) - self._output_shape_cache[cache_key] = nest.pack_sequence_as( - self._nested_outputs, output_shapes) - except ValueError: - # In case there are unknown TensorShape, eg for sparse tensor input, - # We skip the caching since the shape is unknown. - pass - - output_tensors = nest.pack_sequence_as(self._nested_outputs, output_tensors) - return output_tensors + return nest.pack_sequence_as(self._nested_outputs, output_tensors) def _flatten_to_reference_inputs(self, tensors): """Maps `tensors` to their respective `keras.Input`.""" @@ -550,34 +536,38 @@ class Functional(training_lib.Model): def _conform_to_reference_input(self, tensor, ref_input): """Set shape and dtype based on `keras.Input`s.""" - # Shape handling (only for non-CompositeTensors). - if isinstance(tensor, ops.Tensor) and isinstance(ref_input, ops.Tensor): + if isinstance(tensor, ops.Tensor): # Allow (None,) and (None, 1) Tensors to be passed interchangably. Use the # shape specified by the `keras.Input`. - if tensor.shape.rank is not None and ref_input.shape.rank is not None: - should_squeeze_last_dim = ( - tensor.shape.rank == ref_input.shape.rank + 1 and - tensor.shape[-1] == 1) - should_expand_last_dim = ( - tensor.shape.rank == ref_input.shape.rank - 1 and - ref_input.shape[-1] == 1) - if should_squeeze_last_dim: + t_shape = tensor.shape + t_rank = t_shape.rank + ref_shape = ref_input.shape + ref_rank = ref_shape.rank + if t_rank is not None and ref_rank is not None: + # Should squeeze last dimension. + # True if tensor is (BATCH, ..., 1) and reference is (BATCH, ...). + if (t_rank == ref_rank + 1 and t_shape[-1] == 1): tensor = array_ops.squeeze_v2(tensor, axis=-1) - elif should_expand_last_dim: + # Should expand last_dimension. + # True if tensor is (BATCH, ...) and reference is (BATCH, ..., 1). + elif (t_rank == ref_rank - 1 and ref_shape[-1] == 1): tensor = array_ops.expand_dims_v2(tensor, axis=-1) - # Add shape hints to Tensors that might have None shape dims but have - # shapes defined by the `keras.Input`. - try: - tensor.set_shape(tensor.shape.merge_with(ref_input.shape)) - except ValueError: - logging.warning( - 'Model was constructed with shape {} for input {}, but it was ' - 'called on an input with incompatible shape {}.'.format( - ref_input.shape, ref_input, tensor.shape)) + # Add shape hints to Tensors that may have None shape dims but have shapes + # defined by the `keras.Input` (not applicable in eager mode). + if not context.executing_eagerly(): + try: + tensor.set_shape(tensor.shape.merge_with(ref_input.shape)) + except ValueError: + logging.warning( + 'Model was constructed with shape {} for input {}, but it was ' + 'called on an input with incompatible shape {}.'.format( + ref_input.shape, ref_input, tensor.shape)) - # Dtype handling. - if isinstance(ref_input, (ops.Tensor, composite_tensor.CompositeTensor)): + # Dtype casting. + tensor = math_ops.cast(tensor, dtype=ref_input.dtype) + elif isinstance(tensor, composite_tensor.CompositeTensor): + # Dtype casting. tensor = math_ops.cast(tensor, dtype=ref_input.dtype) return tensor diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py index ed715f61897..02e43110697 100644 --- a/tensorflow/python/keras/engine/input_layer.py +++ b/tensorflow/python/keras/engine/input_layer.py @@ -161,8 +161,11 @@ class InputLayer(base_layer.Layer): 'InputLayer, you should instantiate your model and ' 'directly call it on your input.') self.is_placeholder = False - self._batch_input_shape = tuple(input_tensor.shape.as_list()) - + try: + self._batch_input_shape = tuple(input_tensor.shape.as_list()) + except ValueError: + # If the shape cannot be represented as a tuple (e.g. unknown rank) + self._batch_input_shape = None # Create an input node. input_tensor._keras_mask = None node_module.Node(layer=self, outputs=input_tensor) @@ -215,7 +218,9 @@ def Input( # pylint: disable=invalid-name dtype: The data type expected by the input, as a string (`float32`, `float64`, `int32`...) sparse: A boolean specifying whether the placeholder to be created is - sparse. Only one of 'ragged' and 'sparse' can be True. + sparse. Only one of 'ragged' and 'sparse' can be True. Note that, + if `sparse` is False, sparse tensors can still be passed into the + input - they will be densified with a default value of 0. tensor: Optional existing tensor to wrap into the `Input` layer. If set, the layer will not create a placeholder tensor. ragged: A boolean specifying whether the placeholder to be created is diff --git a/tensorflow/python/keras/engine/node.py b/tensorflow/python/keras/engine/node.py index 945cf1c64bd..708904853b2 100644 --- a/tensorflow/python/keras/engine/node.py +++ b/tensorflow/python/keras/engine/node.py @@ -24,6 +24,7 @@ import json import numpy as np from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.utils import tf_utils @@ -73,6 +74,9 @@ class Node(object): # Cached for performance. self._flat_arguments = nest.flatten((self.call_args, self.call_kwargs)) + # Used to avoid expensive `nest` operations in the most common case. + self._single_positional_tensor_passed = (not self.call_kwargs and len( + self.call_args) == 1 and tensor_util.is_tensor(self.call_args[0])) # Create TensorFlowOpLayers if needed. for obj in self._flat_arguments: @@ -102,6 +106,10 @@ class Node(object): tensor._keras_history = KerasHistory( layer=layer, node_index=node_index, tensor_index=i) + # Cached for performance. + self.flat_input_ids = [str(id(t)) for t in self._keras_inputs] + self.flat_output_ids = [str(id(t)) for t in nest.flatten(self.outputs)] + @property def keras_inputs(self): """Tensors input to this node that can be traced back to a `keras.Input`.""" @@ -133,13 +141,18 @@ class Node(object): def map_arguments(self, tensor_dict): """Maps Keras Tensors to computed Tensors using `tensor_dict`.""" - flat_arguments = copy.copy(self._flat_arguments) - for kt_id, kt_index in self._keras_inputs_ids_and_indices: - flat_arguments[kt_index] = tensor_dict[kt_id].pop() + if self._single_positional_tensor_passed: + # Performance optimization for most common case. + kt_id, _ = self._keras_inputs_ids_and_indices[0] + return (tensor_dict[kt_id].pop(),), {} + else: + flat_arguments = copy.copy(self._flat_arguments) + for kt_id, kt_index in self._keras_inputs_ids_and_indices: + flat_arguments[kt_index] = tensor_dict[kt_id].pop() - args, kwargs = nest.pack_sequence_as( - (self.call_args, self.call_kwargs), flat_arguments) - return args, kwargs + args, kwargs = nest.pack_sequence_as((self.call_args, self.call_kwargs), + flat_arguments) + return args, kwargs def serialize(self, make_node_key, node_conversion_map): """Serializes `Node` for Functional API's `get_config`.""" diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py index d07ed477ba9..d8325b98504 100644 --- a/tensorflow/python/keras/engine/sequential.py +++ b/tensorflow/python/keras/engine/sequential.py @@ -397,7 +397,7 @@ class Sequential(functional.Functional): raise ValueError(SINGLE_LAYER_OUTPUT_ERROR_MSG) # `outputs` will be the inputs to the next layer. inputs = outputs - mask = outputs._keras_mask + mask = getattr(outputs, '_keras_mask', None) return outputs def compute_output_shape(self, input_shape): @@ -411,7 +411,7 @@ class Sequential(functional.Functional): # by itself because it will duplicate any updates and losses in graph # mode by `call`ing the Layers again. outputs = self.call(inputs, mask=mask) - return outputs._keras_mask + return getattr(outputs, '_keras_mask', None) @deprecated('2021-01-01', 'Please use `model.predict()` instead.') def predict_proba(self, x, batch_size=32, verbose=0): diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index 680f33f75a5..0d7637cb98c 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -1935,7 +1935,7 @@ def get_input_shape_and_dtype(layer): raise ValueError('An empty Model cannot be used as a Layer.') layer = layer.layers[0] - if hasattr(layer, '_batch_input_shape'): + if getattr(layer, '_batch_input_shape', None): return layer._batch_input_shape, layer.dtype return None, None diff --git a/tensorflow/python/keras/feature_column/BUILD b/tensorflow/python/keras/feature_column/BUILD index 650efcceb52..6af53646d2f 100644 --- a/tensorflow/python/keras/feature_column/BUILD +++ b/tensorflow/python/keras/feature_column/BUILD @@ -12,15 +12,108 @@ exports_files(["LICENSE"]) py_library( name = "feature_column", + srcs = ["__init__.py"], deps = [ + ":base_feature_layer", + ":dense_features", + ":dense_features_v2", ":sequence_feature_column", ], ) +py_library( + name = "base_feature_layer", + srcs = ["base_feature_layer.py"], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/feature_column:feature_column_v2", + "//tensorflow/python/keras/engine:base_layer", + "//tensorflow/python/keras/utils:generic_utils", + ], +) + +py_library( + name = "dense_features", + srcs = [ + "dense_features.py", + ], + deps = [ + ":base_feature_layer", + "//tensorflow/python:framework_ops", + "//tensorflow/python:tf_export", + "//tensorflow/python:util", + "//tensorflow/python/feature_column:feature_column_v2", + "//tensorflow/python/keras:backend", + ], +) + +py_library( + name = "dense_features_v2", + srcs = [ + "dense_features_v2.py", + ], + deps = [ + ":base_feature_layer", + ":dense_features", + "//tensorflow/python:framework_ops", + "//tensorflow/python:tf_export", + "//tensorflow/python/feature_column:feature_column_v2", + ], +) + +tf_py_test( + name = "dense_features_test", + srcs = ["dense_features_test.py"], + tags = ["no_pip"], + deps = [ + ":dense_features", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:session", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variables", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/feature_column:feature_column_v2", + ], +) + +tf_py_test( + name = "dense_features_v2_test", + srcs = ["dense_features_v2_test.py"], + tags = ["no_pip"], + deps = [ + ":dense_features_v2", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:session", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variables", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/feature_column:feature_column_v2", + ], +) + py_library( name = "sequence_feature_column", srcs = ["sequence_feature_column.py"], deps = [ + ":base_feature_layer", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:framework_ops", @@ -59,6 +152,7 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ + ":dense_features", ":sequence_feature_column", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", diff --git a/tensorflow/python/keras/feature_column/__init__.py b/tensorflow/python/keras/feature_column/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorflow/python/keras/feature_column/base_feature_layer.py b/tensorflow/python/keras/feature_column/base_feature_layer.py new file mode 100644 index 00000000000..12f507efe83 --- /dev/null +++ b/tensorflow/python/keras/feature_column/base_feature_layer.py @@ -0,0 +1,145 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""This API defines FeatureColumn abstraction.""" + +# This file was originally under tf/python/feature_column, and was moved to +# Keras package in order to remove the reverse dependency from TF to Keras. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.python.feature_column import feature_column_v2 +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope + + +class _BaseFeaturesLayer(Layer): + """Base class for DenseFeatures and SequenceFeatures. + + Defines common methods and helpers. + + Args: + feature_columns: An iterable containing the FeatureColumns to use as + inputs to your model. + expected_column_type: Expected class for provided feature columns. + trainable: Boolean, whether the layer's variables will be updated via + gradient descent during training. + name: Name to give to the DenseFeatures. + **kwargs: Keyword arguments to construct a layer. + + Raises: + ValueError: if an item in `feature_columns` doesn't match + `expected_column_type`. + """ + + def __init__(self, + feature_columns, + expected_column_type, + trainable, + name, + partitioner=None, + **kwargs): + super(_BaseFeaturesLayer, self).__init__( + name=name, trainable=trainable, **kwargs) + self._feature_columns = feature_column_v2._normalize_feature_columns( # pylint: disable=protected-access + feature_columns) + self._state_manager = feature_column_v2._StateManagerImpl( # pylint: disable=protected-access + self, self.trainable) + self._partitioner = partitioner + for column in self._feature_columns: + if not isinstance(column, expected_column_type): + raise ValueError( + 'Items of feature_columns must be a {}. ' + 'You can wrap a categorical column with an ' + 'embedding_column or indicator_column. Given: {}'.format( + expected_column_type, column)) + + def build(self, _): + for column in self._feature_columns: + with variable_scope._pure_variable_scope( # pylint: disable=protected-access + self.name, + partitioner=self._partitioner): + with variable_scope._pure_variable_scope( # pylint: disable=protected-access + feature_column_v2._sanitize_column_name_for_variable_scope( # pylint: disable=protected-access + column.name)): + column.create_state(self._state_manager) + super(_BaseFeaturesLayer, self).build(None) + + def _output_shape(self, input_shape, num_elements): + """Computes expected output shape of the layer or a column's dense tensor. + + Args: + input_shape: Tensor or array with batch shape. + num_elements: Size of the last dimension of the output. + + Returns: + Tuple with output shape. + """ + raise NotImplementedError('Calling an abstract method.') + + def compute_output_shape(self, input_shape): + total_elements = 0 + for column in self._feature_columns: + total_elements += column.variable_shape.num_elements() + return self._target_shape(input_shape, total_elements) + + def _process_dense_tensor(self, column, tensor): + """Reshapes the dense tensor output of a column based on expected shape. + + Args: + column: A DenseColumn or SequenceDenseColumn object. + tensor: A dense tensor obtained from the same column. + + Returns: + Reshaped dense tensor. + """ + num_elements = column.variable_shape.num_elements() + target_shape = self._target_shape(array_ops.shape(tensor), num_elements) + return array_ops.reshape(tensor, shape=target_shape) + + def _verify_and_concat_tensors(self, output_tensors): + """Verifies and concatenates the dense output of several columns.""" + feature_column_v2._verify_static_batch_size_equality( # pylint: disable=protected-access + output_tensors, self._feature_columns) + return array_ops.concat(output_tensors, -1) + + def get_config(self): + # Import here to avoid circular imports. + from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top + column_configs = serialization.serialize_feature_columns( + self._feature_columns) + config = {'feature_columns': column_configs} + config['partitioner'] = generic_utils.serialize_keras_object( + self._partitioner) + + base_config = super( # pylint: disable=bad-super-call + _BaseFeaturesLayer, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + # Import here to avoid circular imports. + from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top + config_cp = config.copy() + config_cp['feature_columns'] = serialization.deserialize_feature_columns( + config['feature_columns'], custom_objects=custom_objects) + config_cp['partitioner'] = generic_utils.deserialize_keras_object( + config['partitioner'], custom_objects) + + return cls(**config_cp) diff --git a/tensorflow/python/feature_column/dense_features.py b/tensorflow/python/keras/feature_column/dense_features.py similarity index 96% rename from tensorflow/python/feature_column/dense_features.py rename to tensorflow/python/keras/feature_column/dense_features.py index 6feef185815..ef533b71fe7 100644 --- a/tensorflow/python/feature_column/dense_features.py +++ b/tensorflow/python/keras/feature_column/dense_features.py @@ -23,13 +23,13 @@ import json from tensorflow.python.feature_column import feature_column_v2 as fc from tensorflow.python.framework import ops from tensorflow.python.keras import backend -from tensorflow.python.keras.layers import serialization as layer_serialization +from tensorflow.python.keras.feature_column import base_feature_layer as kfc from tensorflow.python.util import serialization from tensorflow.python.util.tf_export import keras_export @keras_export(v1=['keras.layers.DenseFeatures']) -class DenseFeatures(fc._BaseFeaturesLayer): # pylint: disable=protected-access +class DenseFeatures(kfc._BaseFeaturesLayer): # pylint: disable=protected-access """A layer that produces a dense `Tensor` based on given `feature_columns`. Generally a single example in training data is described with FeatureColumns. @@ -173,7 +173,3 @@ class DenseFeatures(fc._BaseFeaturesLayer): # pylint: disable=protected-access cols_to_output_tensors[column] = processed_tensors output_tensors.append(processed_tensors) return self._verify_and_concat_tensors(output_tensors) - - -layer_serialization.inject_feature_column_v1_objects( - 'DenseFeatures', DenseFeatures) diff --git a/tensorflow/python/feature_column/dense_features_test.py b/tensorflow/python/keras/feature_column/dense_features_test.py similarity index 60% rename from tensorflow/python/feature_column/dense_features_test.py rename to tensorflow/python/keras/feature_column/dense_features_test.py index 7cd523dcc14..76b91dd605f 100644 --- a/tensorflow/python/feature_column/dense_features_test.py +++ b/tensorflow/python/keras/feature_column/dense_features_test.py @@ -18,22 +18,25 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.feature_column import dense_features as df from tensorflow.python.feature_column import feature_column_v2 as fc +from tensorflow.python.feature_column import sequence_feature_column as sfc from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util +from tensorflow.python.keras.feature_column import dense_features as df from tensorflow.python.ops import array_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test @@ -676,5 +679,452 @@ class DenseFeaturesTest(test.TestCase): sess.run(net, feed_dict={features['price']: np.array(1)}) +class IndicatorColumnTest(test.TestCase): + + @test_util.run_deprecated_v1 + def test_dense_features(self): + animal = fc.indicator_column( + fc.categorical_column_with_identity('animal', num_buckets=4)) + with ops.Graph().as_default(): + features = { + 'animal': + sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2]) + } + net = df.DenseFeatures([animal])(features) + + self.evaluate(variables_lib.global_variables_initializer()) + self.evaluate(lookup_ops.tables_initializer()) + + self.assertAllClose([[0., 1., 1., 0.]], self.evaluate(net)) + + +class EmbeddingColumnTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + { + 'testcase_name': 'use_safe_embedding_lookup', + 'use_safe_embedding_lookup': True, + 'partition_variables': False, + }, { + 'testcase_name': 'dont_use_safe_embedding_lookup', + 'use_safe_embedding_lookup': False, + 'partition_variables': False, + }, { + 'testcase_name': 'use_safe_embedding_lookup_partitioned', + 'use_safe_embedding_lookup': True, + 'partition_variables': True, + }, { + 'testcase_name': 'dont_use_safe_embedding_lookup_partitioned', + 'use_safe_embedding_lookup': False, + 'partition_variables': True, + }) + @test_util.run_deprecated_v1 + def test_dense_features(self, use_safe_embedding_lookup, partition_variables): + # Inputs. + vocabulary_size = 4 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.), # id 2 + (9., 13.) # id 3 + ) + + def _initializer(shape, dtype, partition_info=None): + if partition_variables: + self.assertEqual([vocabulary_size, embedding_dimension], + partition_info.full_shape) + self.assertAllEqual((2, embedding_dimension), shape) + else: + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertIsNone(partition_info) + + self.assertEqual(dtypes.float32, dtype) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0, ids [2], embedding = [7, 11] + (7., 11.), + # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + (2., 3.5), + # example 2, ids [], embedding = [0, 0] + (0., 0.), + # example 3, ids [1], embedding = [3, 5] + (3., 5.), + ) + + # Build columns. + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + partitioner = None + if partition_variables: + partitioner = partitioned_variables.fixed_size_partitioner(2, axis=0) + with variable_scope.variable_scope('vars', partitioner=partitioner): + embedding_column = fc.embedding_column( + categorical_column, + dimension=embedding_dimension, + initializer=_initializer, + use_safe_embedding_lookup=use_safe_embedding_lookup) + + # Provide sparse input and get dense result. + l = df.DenseFeatures((embedding_column,)) + dense_features = l({'aaa': sparse_input}) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + if partition_variables: + self.assertCountEqual( + ('vars/dense_features/aaa_embedding/embedding_weights/part_0:0', + 'vars/dense_features/aaa_embedding/embedding_weights/part_1:0'), + tuple([v.name for v in global_vars])) + else: + self.assertCountEqual( + ('vars/dense_features/aaa_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) + for v in global_vars: + self.assertIsInstance(v, variables_lib.Variable) + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + if partition_variables: + self.assertCountEqual( + ('vars/dense_features/aaa_embedding/embedding_weights/part_0:0', + 'vars/dense_features/aaa_embedding/embedding_weights/part_1:0'), + tuple([v.name for v in trainable_vars])) + else: + self.assertCountEqual( + ('vars/dense_features/aaa_embedding/embedding_weights:0',), + tuple([v.name for v in trainable_vars])) + + self.evaluate(variables_lib.global_variables_initializer()) + self.evaluate(lookup_ops.tables_initializer()) + + self.assertAllEqual(embedding_values, self.evaluate(trainable_vars[0])) + self.assertAllEqual(expected_lookups, self.evaluate(dense_features)) + + if use_safe_embedding_lookup: + self.assertIn('SparseFillEmptyRows', + [x.type for x in ops.get_default_graph().get_operations()]) + else: + self.assertNotIn( + 'SparseFillEmptyRows', + [x.type for x in ops.get_default_graph().get_operations()]) + + @test_util.run_deprecated_v1 + def test_dense_features_not_trainable(self): + # Inputs. + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 4), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info=None): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0, ids [2], embedding = [7, 11] + (7., 11.), + # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + (2., 3.5), + # example 2, ids [], embedding = [0, 0] + (0., 0.), + # example 3, ids [1], embedding = [3, 5] + (3., 5.), + ) + + # Build columns. + categorical_column = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, + dimension=embedding_dimension, + initializer=_initializer, + trainable=False) + + # Provide sparse input and get dense result. + dense_features = df.DenseFeatures((embedding_column,))({ + 'aaa': sparse_input + }) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertCountEqual(('dense_features/aaa_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) + self.assertCountEqual([], + ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) + + self.evaluate(variables_lib.global_variables_initializer()) + self.evaluate(lookup_ops.tables_initializer()) + + self.assertAllEqual(embedding_values, self.evaluate(global_vars[0])) + self.assertAllEqual(expected_lookups, self.evaluate(dense_features)) + + +class SharedEmbeddingColumnTest(test.TestCase, parameterized.TestCase): + + def _test_dense_features(self, trainable=True): + # Inputs. + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 4)), + values=(2, 0, 1), + dense_shape=(2, 5)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [0] + # example 1, ids [] + indices=((0, 0),), + values=(0,), + dense_shape=(2, 5)) + sparse_input_c = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 1), (1, 1), (1, 3)), + values=(2, 0, 1), + dense_shape=(2, 5)) + sparse_input_d = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [] + indices=((0, 1),), + values=(2,), + dense_shape=(2, 5)) + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info=None): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Expected lookup result, using combiner='mean'. + expected_lookups = ( + # example 0: + # A ids [2], embedding = [7, 11] + # B ids [0], embedding = [1, 2] + # C ids [2], embedding = [7, 11] + # D ids [2], embedding = [7, 11] + (7., 11., 1., 2., 7., 11., 7., 11.), + # example 1: + # A ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + # B ids [], embedding = [0, 0] + # C ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] + # D ids [], embedding = [0, 0] + (2., 3.5, 0., 0., 2., 3.5, 0., 0.), + ) + + # Build columns. + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + categorical_column_c = fc.categorical_column_with_identity( + key='ccc', num_buckets=vocabulary_size) + categorical_column_d = fc.categorical_column_with_identity( + key='ddd', num_buckets=vocabulary_size) + + embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer, + trainable=trainable) + embedding_column_c, embedding_column_d = fc.shared_embedding_columns_v2( + [categorical_column_c, categorical_column_d], + dimension=embedding_dimension, + initializer=_initializer, + trainable=trainable) + + features = { + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + 'ccc': sparse_input_c, + 'ddd': sparse_input_d + } + + # Provide sparse input and get dense result. + dense_features = df.DenseFeatures( + feature_columns=(embedding_column_b, embedding_column_a, + embedding_column_c, embedding_column_d))( + features) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertCountEqual( + ['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'], + tuple([v.name for v in global_vars])) + for v in global_vars: + self.assertIsInstance(v, variables_lib.Variable) + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + if trainable: + self.assertCountEqual( + ['aaa_bbb_shared_embedding:0', 'ccc_ddd_shared_embedding:0'], + tuple([v.name for v in trainable_vars])) + else: + self.assertCountEqual([], tuple([v.name for v in trainable_vars])) + shared_embedding_vars = global_vars + + self.evaluate(variables_lib.global_variables_initializer()) + self.evaluate(lookup_ops.tables_initializer()) + + self.assertAllEqual(embedding_values, + self.evaluate(shared_embedding_vars[0])) + self.assertAllEqual(expected_lookups, self.evaluate(dense_features)) + + @test_util.run_deprecated_v1 + def test_dense_features(self): + self._test_dense_features() + + @test_util.run_deprecated_v1 + def test_dense_features_no_trainable(self): + self._test_dense_features(trainable=False) + + +@test_util.run_all_in_graph_and_eager_modes +class DenseFeaturesSerializationTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + ('default', None, None), + ('trainable', True, 'trainable'), + ('not_trainable', False, 'frozen')) + def test_get_config(self, trainable, name): + cols = [fc.numeric_column('a'), + fc.embedding_column(fc.categorical_column_with_identity( + key='b', num_buckets=3), dimension=2)] + orig_layer = df.DenseFeatures( + cols, trainable=trainable, name=name) + config = orig_layer.get_config() + + self.assertEqual(config['name'], orig_layer.name) + self.assertEqual(config['trainable'], trainable) + self.assertLen(config['feature_columns'], 2) + self.assertEqual( + config['feature_columns'][0]['class_name'], 'NumericColumn') + self.assertEqual(config['feature_columns'][0]['config']['shape'], (1,)) + self.assertEqual( + config['feature_columns'][1]['class_name'], 'EmbeddingColumn') + + @parameterized.named_parameters( + ('default', None, None), + ('trainable', True, 'trainable'), + ('not_trainable', False, 'frozen')) + def test_from_config(self, trainable, name): + cols = [fc.numeric_column('a'), + fc.embedding_column(fc.categorical_column_with_vocabulary_list( + 'b', vocabulary_list=['1', '2', '3']), dimension=2), + fc.indicator_column(fc.categorical_column_with_hash_bucket( + key='c', hash_bucket_size=3))] + orig_layer = df.DenseFeatures( + cols, trainable=trainable, name=name) + config = orig_layer.get_config() + + new_layer = df.DenseFeatures.from_config(config) + + self.assertEqual(new_layer.name, orig_layer.name) + self.assertEqual(new_layer.trainable, trainable) + self.assertLen(new_layer._feature_columns, 3) + self.assertEqual(new_layer._feature_columns[0].name, 'a') + self.assertEqual(new_layer._feature_columns[1].initializer.mean, 0.0) + self.assertEqual(new_layer._feature_columns[1].categorical_column.name, 'b') + self.assertIsInstance(new_layer._feature_columns[2], fc.IndicatorColumn) + + def test_crossed_column(self): + a = fc.categorical_column_with_vocabulary_list( + 'a', vocabulary_list=['1', '2', '3']) + b = fc.categorical_column_with_vocabulary_list( + 'b', vocabulary_list=['1', '2', '3']) + ab = fc.crossed_column([a, b], hash_bucket_size=2) + cols = [fc.indicator_column(ab)] + + orig_layer = df.DenseFeatures(cols) + config = orig_layer.get_config() + + new_layer = df.DenseFeatures.from_config(config) + + self.assertLen(new_layer._feature_columns, 1) + self.assertEqual(new_layer._feature_columns[0].name, 'a_X_b_indicator') + + +@test_util.run_all_in_graph_and_eager_modes +class SequenceFeatureColumnsTest(test.TestCase): + """Tests DenseFeatures with sequence feature columns.""" + + def test_embedding_column(self): + """Tests that error is raised for sequence embedding column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc.embedding_column( + categorical_column_a, dimension=2) + + input_layer = df.DenseFeatures([embedding_column_a]) + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_embedding\. categorical_column must not be ' + r'of type SequenceCategoricalColumn\.'): + _ = input_layer({'aaa': sparse_input}) + + def test_indicator_column(self): + """Tests that error is raised for sequence indicator column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column_a = fc.indicator_column(categorical_column_a) + + input_layer = df.DenseFeatures([indicator_column_a]) + with self.assertRaisesRegexp( + ValueError, + r'In indicator_column: aaa_indicator\. categorical_column must not be ' + r'of type SequenceCategoricalColumn\.'): + _ = input_layer({'aaa': sparse_input}) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/feature_column/dense_features_v2.py b/tensorflow/python/keras/feature_column/dense_features_v2.py similarity index 92% rename from tensorflow/python/feature_column/dense_features_v2.py rename to tensorflow/python/keras/feature_column/dense_features_v2.py index 405c5d63249..40c71ce7bd6 100644 --- a/tensorflow/python/feature_column/dense_features_v2.py +++ b/tensorflow/python/keras/feature_column/dense_features_v2.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.feature_column import dense_features from tensorflow.python.feature_column import feature_column_v2 as fc from tensorflow.python.framework import ops -from tensorflow.python.keras.layers import serialization as layer_serialization +from tensorflow.python.keras.feature_column import base_feature_layer as kfc +from tensorflow.python.keras.feature_column import dense_features from tensorflow.python.util.tf_export import keras_export @@ -93,8 +93,4 @@ class DenseFeatures(dense_features.DenseFeatures): column.create_state(self._state_manager) # We would like to call Layer.build and not _DenseFeaturesHelper.build. # pylint: disable=protected-access - super(fc._BaseFeaturesLayer, self).build(None) # pylint: disable=bad-super-call - - -layer_serialization.inject_feature_column_v2_objects( - 'DenseFeatures', DenseFeatures) + super(kfc._BaseFeaturesLayer, self).build(None) # pylint: disable=bad-super-call diff --git a/tensorflow/python/feature_column/dense_features_v2_test.py b/tensorflow/python/keras/feature_column/dense_features_v2_test.py similarity index 99% rename from tensorflow/python/feature_column/dense_features_v2_test.py rename to tensorflow/python/keras/feature_column/dense_features_v2_test.py index 71cb163a7d9..95fc8b7ac1e 100644 --- a/tensorflow/python/feature_column/dense_features_v2_test.py +++ b/tensorflow/python/keras/feature_column/dense_features_v2_test.py @@ -23,7 +23,6 @@ import numpy as np from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.feature_column import dense_features_v2 as df from tensorflow.python.feature_column import feature_column_v2 as fc from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -31,6 +30,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util +from tensorflow.python.keras.feature_column import dense_features_v2 as df from tensorflow.python.ops import array_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables as variables_lib diff --git a/tensorflow/python/keras/feature_column/sequence_feature_column.py b/tensorflow/python/keras/feature_column/sequence_feature_column.py index 856e385c8fa..5f64ca9642e 100644 --- a/tensorflow/python/keras/feature_column/sequence_feature_column.py +++ b/tensorflow/python/keras/feature_column/sequence_feature_column.py @@ -24,6 +24,7 @@ from __future__ import print_function from tensorflow.python.feature_column import feature_column_v2 as fc from tensorflow.python.framework import ops from tensorflow.python.keras import backend +from tensorflow.python.keras.feature_column import base_feature_layer as kfc from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.util.tf_export import keras_export @@ -32,7 +33,7 @@ from tensorflow.python.util.tf_export import keras_export @keras_export('keras.experimental.SequenceFeatures') -class SequenceFeatures(fc._BaseFeaturesLayer): +class SequenceFeatures(kfc._BaseFeaturesLayer): """A layer for sequence input. All `feature_columns` must be sequence dense columns with the same diff --git a/tensorflow/python/keras/feature_column/sequence_feature_column_integration_test.py b/tensorflow/python/keras/feature_column/sequence_feature_column_integration_test.py index 8784182e23b..b1100bf7b07 100644 --- a/tensorflow/python/keras/feature_column/sequence_feature_column_integration_test.py +++ b/tensorflow/python/keras/feature_column/sequence_feature_column_integration_test.py @@ -24,11 +24,11 @@ from google.protobuf import text_format from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.feature_column import dense_features from tensorflow.python.feature_column import feature_column_v2 as fc from tensorflow.python.feature_column import sequence_feature_column as sfc from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util +from tensorflow.python.keras.feature_column import dense_features from tensorflow.python.keras.feature_column import sequence_feature_column as ksfc from tensorflow.python.keras.layers import recurrent from tensorflow.python.ops import init_ops_v2 diff --git a/tensorflow/python/keras/integration_test/BUILD b/tensorflow/python/keras/integration_test/BUILD index 01c405a86ae..80d8fb86345 100644 --- a/tensorflow/python/keras/integration_test/BUILD +++ b/tensorflow/python/keras/integration_test/BUILD @@ -1,7 +1,7 @@ # Description: # Contains Keras integration tests that verify with other TF high level APIs. -load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test", "tf_py_test") package( default_visibility = [ @@ -70,3 +70,13 @@ tf_py_test( "//tensorflow/python:extra_py_tests_deps", ], ) + +cuda_py_test( + name = "gradient_checkpoint_test", + srcs = ["gradient_checkpoint_test.py"], + python_version = "PY3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python:extra_py_tests_deps", + ], +) diff --git a/tensorflow/python/keras/integration_test/gradient_checkpoint_test.py b/tensorflow/python/keras/integration_test/gradient_checkpoint_test.py new file mode 100644 index 00000000000..9d9e0a062b3 --- /dev/null +++ b/tensorflow/python/keras/integration_test/gradient_checkpoint_test.py @@ -0,0 +1,158 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +layers = tf.keras.layers +optimizers = tf.keras.optimizers + + +def _get_big_cnn_model(img_dim, n_channels, num_partitions, + blocks_per_partition): + """Creates a test model whose activations are significantly larger than model size.""" + model = tf.keras.Sequential() + model.add(layers.Input(shape=(img_dim, img_dim, n_channels))) + for _ in range(num_partitions): + for _ in range(blocks_per_partition): + model.add(layers.Conv2D(10, 5, padding='same', activation=tf.nn.relu)) + model.add(layers.MaxPooling2D((1, 1), padding='same')) + model.add(layers.Conv2D(40, 5, padding='same', activation=tf.nn.relu)) + model.add(layers.MaxPooling2D((1, 1), padding='same')) + model.add(layers.Conv2D(20, 5, padding='same', activation=tf.nn.relu)) + model.add(layers.MaxPooling2D((1, 1), padding='same')) + model.add(layers.Flatten()) + model.add(layers.Dense(32, activation=tf.nn.relu)) + model.add(layers.Dense(10)) + return model + + +def _get_split_cnn_model(img_dim, n_channels, num_partitions, + blocks_per_partition): + """Creates a test model that is split into `num_partitions` smaller models""" + models = [tf.keras.Sequential() for _ in range(num_partitions)] + models[0].add(layers.Input(shape=(img_dim, img_dim, n_channels))) + for i in range(num_partitions): + model = models[i] + if i > 0: + last_shape = models[i - 1].layers[-1].output_shape + model.add(layers.Input(shape=last_shape[1:])) + for _ in range(blocks_per_partition): + model.add(layers.Conv2D(10, 5, padding='same', activation=tf.nn.relu)) + model.add(layers.MaxPooling2D((1, 1), padding='same')) + model.add(layers.Conv2D(40, 5, padding='same', activation=tf.nn.relu)) + model.add(layers.MaxPooling2D((1, 1), padding='same')) + model.add(layers.Conv2D(20, 5, padding='same', activation=tf.nn.relu)) + model.add(layers.MaxPooling2D((1, 1), padding='same')) + models[-1].add(layers.Flatten()) + models[-1].add(layers.Dense(32, activation=tf.nn.relu)) + models[-1].add(layers.Dense(10)) + return models + + +def _compute_loss(logits, labels): + return tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels)) + + +def _limit_gpu_memory(): + """Helper function to limit GPU memory for testing """ + gpus = tf.config.experimental.list_physical_devices('GPU') + if gpus: + tf.config.experimental.set_virtual_device_configuration( + gpus[0], + [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)]) + return True + return False + + +def _get_dummy_data(img_dim, n_channels, batch_size): + inputs = tf.ones([batch_size, img_dim, img_dim, n_channels]) + labels = tf.ones([batch_size], dtype=tf.int64) + return inputs, labels + + +def _train_no_recompute(n_steps): + """Trains a single large model without gradient checkpointing.""" + img_dim, n_channels, batch_size = 256, 1, 4 + x, y = _get_dummy_data(img_dim, n_channels, batch_size) + model = _get_big_cnn_model( + img_dim, n_channels, num_partitions=3, blocks_per_partition=2) + optimizer = optimizers.SGD() + losses = [] + tr_vars = model.trainable_variables + for _ in range(n_steps): + with tf.GradientTape() as tape: + logits = model(x) + loss = _compute_loss(logits, y) + losses.append(loss) + grads = tape.gradient(loss, tr_vars) # tr_vars + optimizer.apply_gradients(zip(grads, tr_vars)) + del grads + return losses + + +def _train_with_recompute(n_steps): + """Trains a single large model with gradient checkpointing using tf.recompute_grad.""" + img_dim, n_channels, batch_size = 256, 1, 4 + x, y = _get_dummy_data(img_dim, n_channels, batch_size) + # This model is the same model as _get_big_cnn_model but split into 3 parts. + models = _get_split_cnn_model( + img_dim, n_channels, num_partitions=3, blocks_per_partition=2) + model1, model2, model3 = models + # Apply gradient checkpointing to the submodels using tf.recompute_grad. + model1_re = tf.recompute_grad(model1) + model2_re = tf.recompute_grad(model2) + model3_re = tf.recompute_grad(model3) + optimizer = optimizers.SGD() + tr_vars = ( + model1.trainable_variables + model2.trainable_variables + + model3.trainable_variables) + losses = [] + for _ in range(n_steps): + with tf.GradientTape() as tape: + logits1 = model1_re(x) + logits2 = model2_re(logits1) + logits3 = model3_re(logits2) + loss = _compute_loss(logits3, y) + losses.append(loss) + grads = tape.gradient(loss, tr_vars) # tr_vars + optimizer.apply_gradients(zip(grads, tr_vars)) + del grads + return losses + + +class GradientCheckpointTest(tf.test.TestCase): + + def test_raises_oom_exception(self): + if not _limit_gpu_memory(): + self.skipTest('No virtual GPUs found') + with self.assertRaises(Exception) as context: + _train_no_recompute(1) + self.assertTrue( + context.exception.__class__.__name__ == 'ResourceExhaustedError') + + def test_does_not_raise_oom_exception(self): + if not _limit_gpu_memory(): + self.skipTest('No virtual GPUs found') + n_step = 2 + losses = _train_with_recompute(n_step) + self.assertTrue(len(losses) == n_step) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD index 46ac88754a8..10a9fe088ab 100644 --- a/tensorflow/python/keras/layers/BUILD +++ b/tensorflow/python/keras/layers/BUILD @@ -213,12 +213,13 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:util", + "//tensorflow/python/distribute:sharded_variable", "//tensorflow/python/eager:context", "//tensorflow/python/keras:backend", - "//tensorflow/python/keras:base_layer", "//tensorflow/python/keras:constraints", "//tensorflow/python/keras:initializers", "//tensorflow/python/keras:regularizers", + "//tensorflow/python/keras/engine:base_layer", "//tensorflow/python/keras/utils:tf_utils", ], ) @@ -593,9 +594,15 @@ cuda_py_test( python_version = "PY3", deps = [ "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:training_lib", + "//tensorflow/python:variables", + "//tensorflow/python/eager:backprop", "//tensorflow/python/keras", "//tensorflow/python/keras:combinations", - "@absl_py//absl/testing:parameterized", + "//tensorflow/python/keras:testing_utils", + "//tensorflow/python/ops/ragged:ragged_factory_ops", ], ) diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index ede199a9169..e0f087b2453 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -44,6 +44,9 @@ from tensorflow.python.keras.layers.preprocessing.image_preprocessing import Res # Preprocessing layers. if tf2.enabled(): + from tensorflow.python.keras.layers.preprocessing.category_encoding import CategoryEncoding + from tensorflow.python.keras.layers.preprocessing.category_encoding_v1 import CategoryEncoding as CategoryEncodingV1 + CategoryEncodingV2 = CategoryEncoding from tensorflow.python.keras.layers.preprocessing.normalization import Normalization from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization as NormalizationV1 NormalizationV2 = Normalization @@ -51,13 +54,17 @@ if tf2.enabled(): from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization as TextVectorizationV1 TextVectorizationV2 = TextVectorization else: + from tensorflow.python.keras.layers.preprocessing.category_encoding_v1 import CategoryEncoding + from tensorflow.python.keras.layers.preprocessing.category_encoding import CategoryEncoding as CategoryEncodingV2 + CategoryEncodingV1 = CategoryEncoding from tensorflow.python.keras.layers.preprocessing.normalization_v1 import Normalization from tensorflow.python.keras.layers.preprocessing.normalization import Normalization as NormalizationV2 NormalizationV1 = Normalization from tensorflow.python.keras.layers.preprocessing.text_vectorization_v1 import TextVectorization from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization as TextVectorizationV2 TextVectorizationV1 = TextVectorization -from tensorflow.python.keras.layers.preprocessing.categorical_crossing import CategoryCrossing +from tensorflow.python.keras.layers.preprocessing.category_crossing import CategoryCrossing +from tensorflow.python.keras.layers.preprocessing.hashing import Hashing # Advanced activations. from tensorflow.python.keras.layers.advanced_activations import LeakyReLU diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index db9c47eca17..60834fad30b 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -460,7 +460,7 @@ class Reshape(Layer): >>> # also supports shape inference using `-1` as dimension >>> model.add(tf.keras.layers.Reshape((-1, 2, 2))) >>> model.output_shape - (None, None, 2, 2) + (None, 3, 2, 2) """ def __init__(self, target_shape, **kwargs): @@ -495,7 +495,9 @@ class Reshape(Layer): is specified. """ output_shape = list(output_shape) - msg = 'total size of new array must be unchanged' + msg = ('total size of new array must be unchanged, ' + 'input_shape = {}, output_shape = {}' + .format(input_shape, output_shape)) known, unknown = 1, None for index, dim in enumerate(output_shape): @@ -529,8 +531,13 @@ class Reshape(Layer): return tensor_shape.TensorShape(output_shape) def call(self, inputs): - return array_ops.reshape(inputs, - (array_ops.shape(inputs)[0],) + self.target_shape) + result = array_ops.reshape( + inputs, (array_ops.shape(inputs)[0],) + self.target_shape) + if not context.executing_eagerly(): + # Set the static shape for the result since it might lost during array_ops + # reshape, eg, some `None` dim in the result could be inferred. + result.set_shape(self.compute_output_shape(inputs.shape)) + return result def get_config(self): config = {'target_shape': self.target_shape} diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py index 3daa187f1ce..70ad63c17eb 100644 --- a/tensorflow/python/keras/layers/core_test.py +++ b/tensorflow/python/keras/layers/core_test.py @@ -430,6 +430,12 @@ class CoreLayersTest(keras_parameterized.TestCase): kwargs={'target_shape': (-1, 1)}, input_shape=(None, None, 2)) + def test_reshape_set_static_shape(self): + input_layer = keras.Input(batch_shape=(1, None)) + reshaped = keras.layers.Reshape((1, 100))(input_layer) + # Make sure the batch dim is not lost after array_ops.reshape. + self.assertEqual(reshaped.shape, [1, 1, 100]) + def test_permute(self): testing_utils.layer_test( keras.layers.Permute, kwargs={'dims': (2, 1)}, input_shape=(3, 2, 4)) diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py index 3f57fd6cb63..3444b3a7665 100644 --- a/tensorflow/python/keras/layers/embeddings.py +++ b/tensorflow/python/keras/layers/embeddings.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.distribute import sharded_variable from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.keras import backend as K @@ -129,8 +130,10 @@ class Embedding(Layer): # since it knows all kernels using the variable only exist on CPU. # When eager execution is enabled, the placement decision has to be made # right now. Checking for the presence of GPUs to avoid complicating the - # TPU codepaths which can handle sparse optimizers. - if context.executing_eagerly() and context.context().num_gpus(): + # TPU codepaths which can handle sparse optimizers. But if we are within + # a tf.function, we go back the graph mode logic and rely on the placer. + if (context.executing_eagerly() and context.context().num_gpus() and + not ops.inside_function()): with ops.device('cpu:0'): self.embeddings = self.add_weight( shape=(self.input_dim, self.output_dim), @@ -181,7 +184,10 @@ class Embedding(Layer): dtype = K.dtype(inputs) if dtype != 'int32' and dtype != 'int64': inputs = math_ops.cast(inputs, 'int32') - out = embedding_ops.embedding_lookup(self.embeddings, inputs) + if isinstance(self.embeddings, sharded_variable.ShardedVariable): + out = embedding_ops.embedding_lookup_v2(self.embeddings.variables, inputs) + else: + out = embedding_ops.embedding_lookup_v2(self.embeddings, inputs) return out def get_config(self): diff --git a/tensorflow/python/keras/layers/embeddings_test.py b/tensorflow/python/keras/layers/embeddings_test.py index 661b29cd7bf..6aa873b2bd7 100644 --- a/tensorflow/python/keras/layers/embeddings_test.py +++ b/tensorflow/python/keras/layers/embeddings_test.py @@ -21,12 +21,14 @@ from __future__ import print_function import numpy as np from tensorflow.python import keras +from tensorflow.python.distribute import sharded_variable from tensorflow.python.eager import backprop from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils +from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test from tensorflow.python.training import adagrad @@ -130,6 +132,20 @@ class EmbeddingTest(keras_parameterized.TestCase): [[[1., 1.], [2., 2.], [2., 2.]], [[0., 0.]], [[1., 1.], [2., 2.]]], ragged_rank=1)) + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_embedding_with_sharded_variable(self): + layer = keras.layers.Embedding(input_dim=5, output_dim=2) + v = [ + variables.Variable([[1., 2.], [3., 4.]]), + variables.Variable([[5., 6.], [7., 8.]]), + variables.Variable([[9., 10.]]) + ] + model = keras.models.Sequential([layer]) + layer.embeddings = sharded_variable.ShardedVariable(v) + model.run_eagerly = testing_utils.should_run_eagerly() + outputs = model.predict(np.array([[0, 2, 4]], dtype='int32')) + self.assertAllClose(outputs, [[[1., 2.], [5., 6.], [9., 10.]]]) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/layers/kernelized.py b/tensorflow/python/keras/layers/kernelized.py index ce53334ebc7..5f401899bec 100644 --- a/tensorflow/python/keras/layers/kernelized.py +++ b/tensorflow/python/keras/layers/kernelized.py @@ -191,15 +191,15 @@ class RandomFourierFeatures(base_layer.Layer): kernel_initializer = _get_random_features_initializer( self.kernel_initializer, shape=(input_dim, self.output_dim)) - unscaled_kernel = self.add_weight( - name='unscaled_random_features', + self.unscaled_kernel = self.add_weight( + name='unscaled_kernel', shape=(input_dim, self.output_dim), dtype=dtypes.float32, initializer=kernel_initializer, trainable=False) self.bias = self.add_weight( - name='random_features_bias', + name='bias', shape=(self.output_dim,), dtype=dtypes.float32, initializer=init_ops.random_uniform_initializer( @@ -208,20 +208,20 @@ class RandomFourierFeatures(base_layer.Layer): if self.scale is None: self.scale = _get_default_scale(self.kernel_initializer, input_dim) - scale = self.add_weight( - name='random_features_scale', + self.kernel_scale = self.add_weight( + name='kernel_scale', shape=(1,), dtype=dtypes.float32, initializer=init_ops.constant_initializer(self.scale), trainable=True, constraint='NonNeg') - self.kernel = (1.0 / scale) * unscaled_kernel super(RandomFourierFeatures, self).build(input_shape) def call(self, inputs): inputs = ops.convert_to_tensor_v2(inputs, dtype=self.dtype) inputs = gen_math_ops.cast(inputs, dtypes.float32) - outputs = gen_math_ops.mat_mul(inputs, self.kernel) + kernel = (1.0 / self.kernel_scale) * self.unscaled_kernel + outputs = gen_math_ops.mat_mul(inputs, kernel) outputs = nn.bias_add(outputs, self.bias) return gen_math_ops.cos(outputs) diff --git a/tensorflow/python/keras/layers/kernelized_test.py b/tensorflow/python/keras/layers/kernelized_test.py index edb58f77868..a6a9d88423f 100644 --- a/tensorflow/python/keras/layers/kernelized_test.py +++ b/tensorflow/python/keras/layers/kernelized_test.py @@ -20,6 +20,8 @@ from __future__ import print_function import functools import math +import os +import shutil from absl.testing import parameterized import numpy as np @@ -35,7 +37,10 @@ from tensorflow.python.keras import backend as keras_backend from tensorflow.python.keras import combinations from tensorflow.python.keras import initializers from tensorflow.python.keras.engine import base_layer_utils +from tensorflow.python.keras.engine import input_layer +from tensorflow.python.keras.engine import training from tensorflow.python.keras.layers import kernelized as kernel_layers +from tensorflow.python.keras.saving import save from tensorflow.python.keras.utils import kernelized_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops @@ -65,6 +70,22 @@ class RandomFourierFeaturesTest(test.TestCase, parameterized.TestCase): else: self.assertAllClose(expected, actual, atol=atol) + @test_util.run_v2_only + def test_state_saving_and_loading(self): + input_data = np.random.random((1, 2)) + rff_layer = kernel_layers.RandomFourierFeatures(output_dim=10, scale=3.0) + inputs = input_layer.Input((2,)) + outputs = rff_layer(inputs) + model = training.Model(inputs, outputs) + output_data = model.predict(input_data) + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + saved_model_dir = os.path.join(temp_dir, 'rff_model') + model.save(saved_model_dir) + new_model = save.load_model(saved_model_dir) + new_output_data = new_model.predict(input_data) + self.assertAllClose(output_data, new_output_data, atol=1e-4) + def test_invalid_output_dim(self): with self.assertRaisesRegexp( ValueError, r'`output_dim` should be a positive integer. Given: -3.'): @@ -246,7 +267,7 @@ class RandomFourierFeaturesTest(test.TestCase, parameterized.TestCase): num_trainable_vars = 1 if trainable else 0 self.assertLen(rff_layer.trainable_variables, num_trainable_vars) if trainable: - self.assertEqual('random_fourier_features/random_features_scale:0', + self.assertEqual('random_fourier_features/kernel_scale:0', rff_layer.trainable_variables[0].name) self.assertLen(rff_layer.non_trainable_variables, 3 - num_trainable_vars) diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index a6d3c3c3e1c..213aadeb606 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Normalization layers. -""" +"""Normalization layers.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -43,7 +42,7 @@ from tensorflow.python.util.tf_export import keras_export class BatchNormalizationBase(Layer): - r"""Normalize and scale inputs or activations. (Ioffe and Szegedy, 2014). + r"""Normalize and scale inputs or activations. Normalize the activations of the previous layer at each batch, i.e. applies a transformation that maintains the mean activation @@ -65,20 +64,16 @@ class BatchNormalizationBase(Layer): `training=False` when calling the model, or using `model.predict`. Arguments: - axis: Integer, the axis that should be normalized - (typically the features axis). - For instance, after a `Conv2D` layer with - `data_format="channels_first"`, - set `axis=1` in `BatchNormalization`. + axis: Integer, the axis that should be normalized (typically the features + axis). For instance, after a `Conv2D` layer with + `data_format="channels_first"`, set `axis=1` in `BatchNormalization`. momentum: Momentum for the moving average. epsilon: Small float added to variance to avoid dividing by zero. - center: If True, add offset of `beta` to normalized tensor. - If False, `beta` is ignored. - scale: If True, multiply by `gamma`. - If False, `gamma` is not used. - When the next layer is linear (also e.g. `nn.relu`), - this can be disabled since the scaling - will be done by the next layer. + center: If True, add offset of `beta` to normalized tensor. If False, `beta` + is ignored. + scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the + next layer is linear (also e.g. `nn.relu`), this can be disabled since the + scaling will be done by the next layer. beta_initializer: Initializer for the beta weight. gamma_initializer: Initializer for the gamma weight. moving_mean_initializer: Initializer for the moving mean. @@ -89,17 +84,17 @@ class BatchNormalizationBase(Layer): gamma_constraint: Optional constraint for the gamma weight. renorm: Whether to use [Batch Renormalization]( https://arxiv.org/abs/1702.03275). This adds extra variables during - training. The inference is the same for either value of this parameter. + training. The inference is the same for either value of this parameter. renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to - scalar `Tensors` used to clip the renorm correction. The correction - `(r, d)` is used as `corrected_value = normalized_value * r + d`, with - `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, + scalar `Tensors` used to clip the renorm correction. The correction `(r, + d)` is used as `corrected_value = normalized_value * r + d`, with `r` + clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, dmax are set to inf, 0, inf, respectively. renorm_momentum: Momentum used to update the moving means and standard - deviations with renorm. Unlike `momentum`, this affects training - and should be neither too small (which would add noise) nor too large - (which would give stale estimates). Note that `momentum` is still applied - to get the means and variances for inference. + deviations with renorm. Unlike `momentum`, this affects training and + should be neither too small (which would add noise) nor too large (which + would give stale estimates). Note that `momentum` is still applied to get + the means and variances for inference. fused: if `True`, use a faster, fused implementation, or raise a ValueError if the fused implementation cannot be used. If `None`, use the faster implementation if possible. If False, do not used the fused @@ -117,54 +112,36 @@ class BatchNormalizationBase(Layer): example, if axis==-1, `adjustment = lambda shape: ( tf.random.uniform(shape[-1:], 0.93, 1.07), - tf.random.uniform(shape[-1:], -0.1, 0.1))` - will scale the normalized value by up to 7% up or down, then shift the - result by up to 0.1 (with independent scaling and bias for each feature - but shared across all examples), and finally apply gamma and/or beta. If - `None`, no adjustment is applied. Cannot be specified if - virtual_batch_size is specified. - + tf.random.uniform(shape[-1:], -0.1, 0.1))` will scale the normalized + value by up to 7% up or down, then shift the result by up to 0.1 + (with independent scaling and bias for each feature but shared + across all examples), and finally apply gamma and/or beta. If + `None`, no adjustment is applied. Cannot be specified if + virtual_batch_size is specified. Call arguments: inputs: Input tensor (of any rank). training: Python boolean indicating whether the layer should behave in training mode or in inference mode. - - `training=True`: The layer will normalize its inputs using the - mean and variance of the current batch of inputs. - - `training=False`: The layer will normalize its inputs using the - mean and variance of its moving statistics, learned during training. - - Input shape: - Arbitrary. Use the keyword argument `input_shape` - (tuple of integers, does not include the samples axis) - when using this layer as the first layer in a model. - - Output shape: - Same shape as input. - - {{TRAINABLE_ATTRIBUTE_NOTE}} - - Normalization equations: - Consider the intermediate activations \(x\) of a mini-batch of size - \\(m\\): - - We can compute the mean and variance of the batch - - \\({\mu_B} = \frac{1}{m} \sum_{i=1}^{m} {x_i}\\) - - \\({\sigma_B^2} = \frac{1}{m} \sum_{i=1}^{m} ({x_i} - {\mu_B})^2\\) - - and then compute a normalized \\(x\\), including a small factor - \\({\epsilon}\\) for numerical stability. - - \\(\hat{x_i} = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}\\) - - And finally \\(\hat{x}\) is linearly transformed by \({\gamma}\\) - and \\({\beta}\\), which are learned parameters: - - \\({y_i} = {\gamma * \hat{x_i} + \beta}\\) - + - `training=True`: The layer will normalize its inputs using the mean and + variance of the current batch of inputs. + - `training=False`: The layer will normalize its inputs using the mean and + variance of its moving statistics, learned during training. + Input shape: Arbitrary. Use the keyword argument `input_shape` (tuple of + integers, does not include the samples axis) when using this layer as the + first layer in a model. + Output shape: Same shape as input. {{TRAINABLE_ATTRIBUTE_NOTE}} + Normalization equations: Consider the intermediate activations \(x\) of a + mini-batch of size + \\(m\\): We can compute the mean and variance of the batch \\({\mu_B} = + \frac{1}{m} \sum_{i=1}^{m} {x_i}\\) \\({\sigma_B^2} = \frac{1}{m} + \sum_{i=1}^{m} ({x_i} - {\mu_B})^2\\) and then compute a normalized + \\(x\\), including a small factor \\({\epsilon}\\) for numerical + stability. \\(\hat{x_i} = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + + \epsilon}}\\) And finally \\(\hat{x}\) is linearly transformed by + \({\gamma}\\) + and \\({\beta}\\), which are learned parameters: \\({y_i} = {\gamma * + \hat{x_i} + \beta}\\) Reference: - - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167). """ @@ -195,8 +172,7 @@ class BatchNormalizationBase(Layer): adjustment=None, name=None, **kwargs): - super(BatchNormalizationBase, self).__init__( - name=name, **kwargs) + super(BatchNormalizationBase, self).__init__(name=name, **kwargs) if isinstance(axis, (list, tuple)): self.axis = axis[:] elif isinstance(axis, int): @@ -275,8 +251,8 @@ class BatchNormalizationBase(Layer): # TODO(reedwm): Support fp64 in FusedBatchNorm then remove this check. if self._compute_dtype not in ('float16', 'bfloat16', 'float32', None): raise ValueError('Passing fused=True is only supported when the compute ' - 'dtype is float16, bfloat16, or float32. Got dtype: %s' - % (self._compute_dtype,)) + 'dtype is float16, bfloat16, or float32. Got dtype: %s' % + (self._compute_dtype,)) def _fused_can_be_used(self): try: @@ -380,13 +356,14 @@ class BatchNormalizationBase(Layer): param_shape = (list(axis_to_dim.values())[0],) else: # Parameter shape is the original shape but with 1 in all non-axis dims - param_shape = [axis_to_dim[i] if i in axis_to_dim - else 1 for i in range(ndims)] + param_shape = [ + axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims) + ] if self.virtual_batch_size is not None: # When using virtual batches, add an extra dim at index 1 param_shape.insert(1, 1) for idx, x in enumerate(self.axis): - self.axis[idx] = x + 1 # Account for added dimension + self.axis[idx] = x + 1 # Account for added dimension if self.scale: self.gamma = self.add_weight( @@ -507,8 +484,7 @@ class BatchNormalizationBase(Layer): decay = ops.convert_to_tensor_v2(1.0 - momentum, name='decay') if decay.dtype != variable.dtype.base_dtype: decay = math_ops.cast(decay, variable.dtype.base_dtype) - update_delta = ( - variable - math_ops.cast(value, variable.dtype)) * decay + update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay if inputs_size is not None: update_delta = array_ops.where(inputs_size > 0, update_delta, K.zeros_like(update_delta)) @@ -650,8 +626,9 @@ class BatchNormalizationBase(Layer): with ops.control_dependencies([r, d]): mean = array_ops.identity(mean) stddev = array_ops.identity(stddev) - rmin, rmax, dmax = [self.renorm_clipping.get(key) - for key in ['rmin', 'rmax', 'dmax']] + rmin, rmax, dmax = [ + self.renorm_clipping.get(key) for key in ['rmin', 'rmax', 'dmax'] + ] if rmin is not None: r = math_ops.maximum(r, rmin) if rmax is not None: @@ -661,13 +638,13 @@ class BatchNormalizationBase(Layer): d = math_ops.minimum(d, dmax) # When not training, use r=1, d=0. r = tf_utils.smart_cond(training, lambda: r, lambda: array_ops.ones_like(r)) - d = tf_utils.smart_cond(training, - lambda: d, + d = tf_utils.smart_cond(training, lambda: d, lambda: array_ops.zeros_like(d)) def _update_renorm_variable(var, value, inputs_size): """Updates a moving average and weight, returns the unbiased value.""" value = array_ops.identity(value) + def _do_update(): """Updates the var, returns the updated value.""" new_var = self._assign_moving_average(var, value, self.renorm_momentum, @@ -676,6 +653,7 @@ class BatchNormalizationBase(Layer): def _fake_update(): return array_ops.identity(var) + return tf_utils.smart_cond(training, _do_update, _fake_update) # TODO(yuefengz): colocate the operations @@ -712,9 +690,10 @@ class BatchNormalizationBase(Layer): if self._USE_V2_BEHAVIOR: if isinstance(training, int): training = bool(training) - # When the layer is not trainable, it overrides the value passed from - # model. - training = math_ops.logical_and(training, self.trainable) + if not self.trainable: + # When the layer is not trainable, it overrides the value passed from + # model. + training = False return training def call(self, inputs, training=None): @@ -752,12 +731,13 @@ class BatchNormalizationBase(Layer): ndims = len(input_shape) reduction_axes = [i for i in range(ndims) if i not in self.axis] if self.virtual_batch_size is not None: - del reduction_axes[1] # Do not reduce along virtual batch dim + del reduction_axes[1] # Do not reduce along virtual batch dim # Broadcasting only necessary for single-axis batch norm where the axis is # not the last dimension broadcast_shape = [1] * ndims broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value + def _broadcast(v): if (v is not None and len(v.shape) != ndims and reduction_axes != list(range(ndims - 1))): @@ -782,11 +762,9 @@ class BatchNormalizationBase(Layer): if self.adjustment: adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs)) # Adjust only during training. - adj_scale = tf_utils.smart_cond(training, - lambda: adj_scale, + adj_scale = tf_utils.smart_cond(training, lambda: adj_scale, lambda: array_ops.ones_like(adj_scale)) - adj_bias = tf_utils.smart_cond(training, - lambda: adj_bias, + adj_bias = tf_utils.smart_cond(training, lambda: adj_bias, lambda: array_ops.zeros_like(adj_bias)) scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset) @@ -878,11 +856,8 @@ class BatchNormalizationBase(Layer): scale = math_ops.cast(scale, inputs.dtype) # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing # math in float16 hurts validation accuracy of popular models like resnet. - outputs = nn.batch_normalization(inputs, - _broadcast(mean), - _broadcast(variance), - offset, - scale, + outputs = nn.batch_normalization(inputs, _broadcast(mean), + _broadcast(variance), offset, scale, self.epsilon) # If some components of the shape got lost due to adjustments, fix that. outputs.set_shape(input_shape) @@ -896,21 +871,32 @@ class BatchNormalizationBase(Layer): def get_config(self): config = { - 'axis': self.axis, - 'momentum': self.momentum, - 'epsilon': self.epsilon, - 'center': self.center, - 'scale': self.scale, - 'beta_initializer': initializers.serialize(self.beta_initializer), - 'gamma_initializer': initializers.serialize(self.gamma_initializer), + 'axis': + self.axis, + 'momentum': + self.momentum, + 'epsilon': + self.epsilon, + 'center': + self.center, + 'scale': + self.scale, + 'beta_initializer': + initializers.serialize(self.beta_initializer), + 'gamma_initializer': + initializers.serialize(self.gamma_initializer), 'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer), 'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer), - 'beta_regularizer': regularizers.serialize(self.beta_regularizer), - 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), - 'beta_constraint': constraints.serialize(self.beta_constraint), - 'gamma_constraint': constraints.serialize(self.gamma_constraint) + 'beta_regularizer': + regularizers.serialize(self.beta_regularizer), + 'gamma_regularizer': + regularizers.serialize(self.gamma_regularizer), + 'beta_constraint': + constraints.serialize(self.beta_constraint), + 'gamma_constraint': + constraints.serialize(self.gamma_constraint) } # Only add TensorFlow-specific parameters if they are set, so as to preserve # model compatibility with external Keras. @@ -941,16 +927,14 @@ def replace_in_base_docstring(replacements): @keras_export(v1=['keras.layers.BatchNormalization']) # pylint: disable=missing-docstring class BatchNormalization(BatchNormalizationBase): - __doc__ = replace_in_base_docstring( - [(''' + __doc__ = replace_in_base_docstring([(""" fused: if `True`, use a faster, fused implementation, or raise a ValueError if the fused implementation cannot be used. If `None`, use the faster implementation if possible. If False, do not used the fused - implementation.''', - ''' + implementation.""", """ fused: if `None` or `True`, use a faster, fused implementation if possible. - If `False`, use the system recommended implementation.'''), - ('{{TRAINABLE_ATTRIBUTE_NOTE}}', '')]) + If `False`, use the system recommended implementation."""), + ('{{TRAINABLE_ATTRIBUTE_NOTE}}', '')]) _USE_V2_BEHAVIOR = False @@ -1047,37 +1031,30 @@ class LayerNormalization(Layer): Arguments: - axis: Integer or List/Tuple. The axis or axes - to normalize across. Typically this is the features axis/axes. The - left-out axes are typically the batch axis/axes. - This argument defaults to `-1`, the last dimension in the input. - epsilon: Small float added to variance to avoid dividing by zero. - Defaults to 1e-3 - center: If True, add offset of `beta` to normalized tensor. - If False, `beta` is ignored. Defaults to True. - scale: If True, multiply by `gamma`. - If False, `gamma` is not used. Defaults to True. - When the next layer is linear (also e.g. `nn.relu`), - this can be disabled since the scaling - will be done by the next layer. + axis: Integer or List/Tuple. The axis or axes to normalize across. Typically + this is the features axis/axes. The left-out axes are typically the batch + axis/axes. This argument defaults to `-1`, the last dimension in the + input. + epsilon: Small float added to variance to avoid dividing by zero. Defaults + to 1e-3 + center: If True, add offset of `beta` to normalized tensor. If False, `beta` + is ignored. Defaults to True. + scale: If True, multiply by `gamma`. If False, `gamma` is not used. Defaults + to True. When the next layer is linear (also e.g. `nn.relu`), this can be + disabled since the scaling will be done by the next layer. beta_initializer: Initializer for the beta weight. Defaults to zeros. gamma_initializer: Initializer for the gamma weight. Defaults to ones. beta_regularizer: Optional regularizer for the beta weight. None by default. - gamma_regularizer: Optional regularizer for the gamma weight. - None by default. + gamma_regularizer: Optional regularizer for the gamma weight. None by + default. beta_constraint: Optional constraint for the beta weight. None by default. gamma_constraint: Optional constraint for the gamma weight. None by default. trainable: Boolean, if `True` the variables will be marked as trainable. Defaults to True. - - Input shape: - Arbitrary. Use the keyword argument `input_shape` - (tuple of integers, does not include the samples axis) - when using this layer as the first layer in a model. - - Output shape: - Same shape as input. - + Input shape: Arbitrary. Use the keyword argument `input_shape` (tuple of + integers, does not include the samples axis) when using this layer as the + first layer in a model. + Output shape: Same shape as input. Reference: - [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450). """ @@ -1203,9 +1180,9 @@ class LayerNormalization(Layer): broadcast_shape = [1] * ndims for dim in self.axis: broadcast_shape[dim] = input_shape.dims[dim].value + def _broadcast(v): - if (v is not None and len(v.shape) != ndims and - self.axis != [ndims - 1]): + if (v is not None and len(v.shape) != ndims and self.axis != [ndims - 1]): return array_ops.reshape(v, broadcast_shape) return v diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD index c1e1d5573e5..af7f6392219 100644 --- a/tensorflow/python/keras/layers/preprocessing/BUILD +++ b/tensorflow/python/keras/layers/preprocessing/BUILD @@ -2,6 +2,8 @@ # Contains the Keras preprocess layers (internal TensorFlow version). load("//tensorflow:tensorflow.bzl", "tf_py_test") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test") load("//tensorflow/core/platform/default:distribute.bzl", "distribute_py_test") @@ -23,7 +25,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":categorical_crossing", + ":category_crossing", ":discretization", ":hashing", ":image_preprocessing", @@ -50,9 +52,9 @@ py_library( ) py_library( - name = "categorical_crossing", + name = "category_crossing", srcs = [ - "categorical_crossing.py", + "category_crossing.py", ], srcs_version = "PY2AND3", deps = [ @@ -194,7 +196,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":categorical_encoding", + ":category_encoding", ":string_lookup", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -214,10 +216,10 @@ py_library( ) py_library( - name = "categorical_encoding", + name = "category_encoding", srcs = [ - "categorical_encoding.py", - "categorical_encoding_v1.py", + "category_encoding.py", + "category_encoding_v1.py", ], srcs_version = "PY2AND3", deps = [ @@ -289,16 +291,16 @@ py_library( ) cuda_py_test( - name = "categorical_crossing_test", + name = "category_crossing_test", size = "medium", - srcs = ["categorical_crossing_test.py"], + srcs = ["category_crossing_test.py"], python_version = "PY3", shard_count = 4, tags = [ "no_windows", # b/149031156 ], deps = [ - ":categorical_crossing", + ":category_crossing", "//tensorflow/python:client_testlib", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -306,12 +308,12 @@ cuda_py_test( ) tf_py_test( - name = "categorical_encoding_test", + name = "category_encoding_test", size = "medium", - srcs = ["categorical_encoding_test.py"], + srcs = ["category_encoding_test.py"], python_version = "PY3", deps = [ - ":categorical_encoding", + ":category_encoding", ":preprocessing_test_utils", "//tensorflow/python:client_testlib", "//tensorflow/python/keras", @@ -322,9 +324,9 @@ tf_py_test( ) distribute_py_test( - name = "categorical_encoding_distribution_test", - srcs = ["categorical_encoding_distribution_test.py"], - main = "categorical_encoding_distribution_test.py", + name = "category_encoding_distribution_test", + srcs = ["category_encoding_distribution_test.py"], + main = "category_encoding_distribution_test.py", python_version = "PY3", tags = [ "multi_and_single_gpu", @@ -333,7 +335,7 @@ distribute_py_test( "no_oss", # b/155502591 ], deps = [ - ":categorical_encoding", + ":category_encoding", "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/keras", @@ -341,9 +343,9 @@ distribute_py_test( ) distribute_py_test( - name = "categorical_crossing_distribution_test", - srcs = ["categorical_crossing_distribution_test.py"], - main = "categorical_crossing_distribution_test.py", + name = "category_crossing_distribution_test", + srcs = ["category_crossing_distribution_test.py"], + main = "category_crossing_distribution_test.py", python_version = "PY3", tags = [ "multi_and_single_gpu", @@ -352,7 +354,7 @@ distribute_py_test( "no_oss", # b/155502591 ], deps = [ - ":categorical_crossing", + ":category_crossing", "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:strategy_combinations", "//tensorflow/python/keras", @@ -521,6 +523,7 @@ tf_py_test( size = "medium", srcs = ["text_vectorization_test.py"], python_version = "PY3", + shard_count = 4, deps = [ ":preprocessing_test_utils", ":text_vectorization", diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD index 0c7e6ba856d..7c976880059 100644 --- a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD @@ -1,4 +1,7 @@ # Benchmarks for Keras preprocessing layers. +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_py_test") package( @@ -8,22 +11,32 @@ package( exports_files(["LICENSE"]) tf_py_test( - name = "categorical_encoding_benchmark", - srcs = ["categorical_encoding_benchmark.py"], + name = "category_encoding_benchmark", + srcs = ["category_encoding_benchmark.py"], python_version = "PY3", deps = [ "//tensorflow:tensorflow_py", - "//tensorflow/python/keras/layers/preprocessing:categorical_encoding", + "//tensorflow/python/keras/layers/preprocessing:category_encoding", ], ) tf_py_test( - name = "categorical_crossing_benchmark", - srcs = ["categorical_crossing_benchmark.py"], + name = "category_crossing_benchmark", + srcs = ["category_crossing_benchmark.py"], python_version = "PY3", deps = [ "//tensorflow:tensorflow_py", - "//tensorflow/python/keras/layers/preprocessing:categorical_crossing", + "//tensorflow/python/keras/layers/preprocessing:category_crossing", + ], +) + +tf_py_test( + name = "hashing_benchmark", + srcs = ["hashing_benchmark.py"], + python_version = "PY3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python/keras/layers/preprocessing:hashing", ], ) @@ -46,3 +59,13 @@ tf_py_test( "//tensorflow/python/keras/layers/preprocessing:normalization", ], ) + +cuda_py_test( + name = "image_preproc_benchmark", + srcs = ["image_preproc_benchmark.py"], + python_version = "PY3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python/keras/layers/preprocessing:image_preprocessing", + ], +) diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/categorical_crossing_benchmark.py b/tensorflow/python/keras/layers/preprocessing/benchmarks/category_crossing_benchmark.py similarity index 97% rename from tensorflow/python/keras/layers/preprocessing/benchmarks/categorical_crossing_benchmark.py rename to tensorflow/python/keras/layers/preprocessing/benchmarks/category_crossing_benchmark.py index 80a7903f0b9..efc0ca3766f 100644 --- a/tensorflow/python/keras/layers/preprocessing/benchmarks/categorical_crossing_benchmark.py +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/category_crossing_benchmark.py @@ -28,7 +28,7 @@ from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras.layers.preprocessing import categorical_crossing +from tensorflow.python.keras.layers.preprocessing import category_crossing from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import benchmark from tensorflow.python.platform import test @@ -74,7 +74,7 @@ class BenchmarkLayer(benchmark.Benchmark): def bm_layer_implementation(self, batch_size): input_1 = keras.Input(shape=(1,), dtype=dtypes.int64, name="word") input_2 = keras.Input(shape=(1,), dtype=dtypes.int64, name="int") - layer = categorical_crossing.CategoryCrossing() + layer = category_crossing.CategoryCrossing() _ = layer([input_1, input_2]) num_repeats = 5 @@ -97,7 +97,7 @@ class BenchmarkLayer(benchmark.Benchmark): ends.append(time.time()) avg_time = np.mean(np.array(ends) - np.array(starts)) / num_batches - name = "categorical_crossing|batch_%s" % batch_size + name = "category_crossing|batch_%s" % batch_size baseline = self.run_dataset_implementation(batch_size) extras = { "dataset implementation baseline": baseline, diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/categorical_encoding_benchmark.py b/tensorflow/python/keras/layers/preprocessing/benchmarks/category_encoding_benchmark.py similarity index 93% rename from tensorflow/python/keras/layers/preprocessing/benchmarks/categorical_encoding_benchmark.py rename to tensorflow/python/keras/layers/preprocessing/benchmarks/category_encoding_benchmark.py index e68b77ebef9..71b4c7b6b61 100644 --- a/tensorflow/python/keras/layers/preprocessing/benchmarks/categorical_encoding_benchmark.py +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/category_encoding_benchmark.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Benchmark for Keras categorical_encoding preprocessing layer.""" +"""Benchmark for Keras category_encoding preprocessing layer.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -26,7 +26,7 @@ from tensorflow.python import keras from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes -from tensorflow.python.keras.layers.preprocessing import categorical_encoding +from tensorflow.python.keras.layers.preprocessing import category_encoding from tensorflow.python.ops import random_ops from tensorflow.python.platform import benchmark from tensorflow.python.platform import test @@ -42,7 +42,7 @@ class BenchmarkLayer(benchmark.Benchmark): def run_dataset_implementation(self, output_mode, batch_size, sequence_length, max_tokens): input_t = keras.Input(shape=(sequence_length,), dtype=dtypes.int32) - layer = categorical_encoding.CategoricalEncoding( + layer = category_encoding.CategoryEncoding( max_tokens=max_tokens, output_mode=output_mode) _ = layer(input_t) @@ -68,7 +68,7 @@ class BenchmarkLayer(benchmark.Benchmark): ends.append(time.time()) avg_time = np.mean(np.array(ends) - np.array(starts)) / num_batches - name = "categorical_encoding|batch_%s|seq_length_%s|%s_max_tokens" % ( + name = "category_encoding|batch_%s|seq_length_%s|%s_max_tokens" % ( batch_size, sequence_length, max_tokens) self.report_benchmark(iters=num_repeats, wall_time=avg_time, name=name) diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/hashing_benchmark.py b/tensorflow/python/keras/layers/preprocessing/benchmarks/hashing_benchmark.py new file mode 100644 index 00000000000..68ab28c7f6c --- /dev/null +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/hashing_benchmark.py @@ -0,0 +1,115 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Benchmark for Keras hashing preprocessing layer.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import random +import string +import time + +from absl import flags +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.compat import v2_compat +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras.layers.preprocessing import hashing +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import benchmark +from tensorflow.python.platform import test + +FLAGS = flags.FLAGS + +v2_compat.enable_v2_behavior() + + +# word_gen creates random sequences of ASCII letters (both lowercase and upper). +# The number of unique strings is ~2,700. +def word_gen(): + for _ in itertools.count(1): + yield "".join(random.choice(string.ascii_letters) for i in range(2)) + + +class BenchmarkLayer(benchmark.Benchmark): + """Benchmark the layer forward pass.""" + + def run_dataset_implementation(self, batch_size): + num_repeats = 5 + starts = [] + ends = [] + for _ in range(num_repeats): + ds = dataset_ops.Dataset.from_generator(word_gen, dtypes.string, + tensor_shape.TensorShape([])) + ds = ds.shuffle(batch_size * 100) + ds = ds.batch(batch_size) + num_batches = 5 + ds = ds.take(num_batches) + ds = ds.prefetch(num_batches) + starts.append(time.time()) + # Benchmarked code begins here. + for i in ds: + _ = string_ops.string_to_hash_bucket(i, num_buckets=2) + # Benchmarked code ends here. + ends.append(time.time()) + + avg_time = np.mean(np.array(ends) - np.array(starts)) / num_batches + return avg_time + + def bm_layer_implementation(self, batch_size): + input_1 = keras.Input(shape=(None,), dtype=dtypes.string, name="word") + layer = hashing.Hashing(num_bins=2) + _ = layer(input_1) + + num_repeats = 5 + starts = [] + ends = [] + for _ in range(num_repeats): + ds = dataset_ops.Dataset.from_generator(word_gen, dtypes.string, + tensor_shape.TensorShape([])) + ds = ds.shuffle(batch_size * 100) + ds = ds.batch(batch_size) + num_batches = 5 + ds = ds.take(num_batches) + ds = ds.prefetch(num_batches) + starts.append(time.time()) + # Benchmarked code begins here. + for i in ds: + _ = layer(i) + # Benchmarked code ends here. + ends.append(time.time()) + + avg_time = np.mean(np.array(ends) - np.array(starts)) / num_batches + name = "hashing|batch_%s" % batch_size + baseline = self.run_dataset_implementation(batch_size) + extras = { + "dataset implementation baseline": baseline, + "delta seconds": (baseline - avg_time), + "delta percent": ((baseline - avg_time) / baseline) * 100 + } + self.report_benchmark( + iters=num_repeats, wall_time=avg_time, extras=extras, name=name) + + def benchmark_vocab_size_by_batch(self): + for batch in [32, 64, 256]: + self.bm_layer_implementation(batch_size=batch) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/image_preproc_benchmark.py b/tensorflow/python/keras/layers/preprocessing/benchmarks/image_preproc_benchmark.py new file mode 100644 index 00000000000..302c890c823 --- /dev/null +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/image_preproc_benchmark.py @@ -0,0 +1,163 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Benchmark for Keras image preprocessing layer.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import time + +from absl import flags +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.compat import v2_compat +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.keras.layers.preprocessing import image_preprocessing +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import image_ops_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import benchmark +from tensorflow.python.platform import test + +FLAGS = flags.FLAGS + +v2_compat.enable_v2_behavior() + +LOWER = .2 +UPPER = .4 +BATCH_SIZE = 32 + + +def rotate(inputs): + """rotate image.""" + inputs_shape = array_ops.shape(inputs) + batch_size = inputs_shape[0] + img_hd = math_ops.cast(inputs_shape[1], dtypes.float32) + img_wd = math_ops.cast(inputs_shape[2], dtypes.float32) + min_angle = LOWER * 2. * np.pi + max_angle = UPPER * 2. * np.pi + angles = random_ops.random_uniform( + shape=[batch_size], minval=min_angle, maxval=max_angle) + return image_preprocessing.transform( + inputs, image_preprocessing.get_rotation_matrix(angles, img_hd, img_wd)) + + +def zoom(inputs): + """zoom image.""" + inputs_shape = array_ops.shape(inputs) + batch_size = inputs_shape[0] + img_hd = math_ops.cast(inputs_shape[1], dtypes.float32) + img_wd = math_ops.cast(inputs_shape[2], dtypes.float32) + height_zoom = random_ops.random_uniform( + shape=[batch_size, 1], minval=1. + LOWER, maxval=1. + UPPER) + width_zoom = random_ops.random_uniform( + shape=[batch_size, 1], minval=1. + LOWER, maxval=1. + UPPER) + zooms = math_ops.cast( + array_ops.concat([width_zoom, height_zoom], axis=1), dtype=dtypes.float32) + return image_preprocessing.transform( + inputs, image_preprocessing.get_zoom_matrix(zooms, img_hd, img_wd)) + + +def image_augmentation(inputs, batch_size): + """image augmentation.""" + img = inputs + img = image_ops_impl.resize_images_v2(img, size=[224, 224]) + img = random_ops.random_crop(img, size=[batch_size, 224, 224, 3]) + img = rotate(img) + img = zoom(img) + return img + + +class BenchmarkLayer(benchmark.Benchmark): + """Benchmark the layer forward pass.""" + + def run_dataset_implementation(self, batch_size): + num_repeats = 5 + starts = [] + ends = [] + for _ in range(num_repeats): + ds = dataset_ops.Dataset.from_tensor_slices( + np.random.random((batch_size, 256, 256, 3))) + ds = ds.shuffle(batch_size * 100) + ds = ds.batch(batch_size) + ds = ds.prefetch(batch_size) + img_augmentation = functools.partial( + image_augmentation, batch_size=batch_size) + ds = ds.map(img_augmentation) + starts.append(time.time()) + count = 0 + # Benchmarked code begins here. + for i in ds: + _ = i + count += 1 + # Benchmarked code ends here. + ends.append(time.time()) + + avg_time = np.mean(np.array(ends) - np.array(starts)) / count + return avg_time + + def bm_layer_implementation(self, batch_size): + with ops.device_v2("/gpu:0"): + img = keras.Input(shape=(256, 256, 3), dtype=dtypes.float32) + preprocessor = keras.Sequential([ + image_preprocessing.Resizing(224, 224), + image_preprocessing.RandomCrop(height=224, width=224), + image_preprocessing.RandomRotation(factor=(.2, .4)), + image_preprocessing.RandomFlip(mode="horizontal"), + image_preprocessing.RandomZoom(.2, .2) + ]) + _ = preprocessor(img) + + num_repeats = 5 + starts = [] + ends = [] + for _ in range(num_repeats): + ds = dataset_ops.Dataset.from_tensor_slices( + np.random.random((batch_size, 256, 256, 3))) + ds = ds.shuffle(batch_size * 100) + ds = ds.batch(batch_size) + ds = ds.prefetch(batch_size) + starts.append(time.time()) + count = 0 + # Benchmarked code begins here. + for i in ds: + _ = preprocessor(i) + count += 1 + # Benchmarked code ends here. + ends.append(time.time()) + + avg_time = np.mean(np.array(ends) - np.array(starts)) / count + name = "image_preprocessing|batch_%s" % batch_size + baseline = self.run_dataset_implementation(batch_size) + extras = { + "dataset implementation baseline": baseline, + "delta seconds": (baseline - avg_time), + "delta percent": ((baseline - avg_time) / baseline) * 100 + } + self.report_benchmark( + iters=num_repeats, wall_time=avg_time, extras=extras, name=name) + + def benchmark_vocab_size_by_batch(self): + for batch in [32, 64, 256]: + self.bm_layer_implementation(batch_size=batch) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/keras/layers/preprocessing/categorical_crossing.py b/tensorflow/python/keras/layers/preprocessing/category_crossing.py similarity index 87% rename from tensorflow/python/keras/layers/preprocessing/categorical_crossing.py rename to tensorflow/python/keras/layers/preprocessing/category_crossing.py index 68848458bb2..79c27d9ec36 100644 --- a/tensorflow/python/keras/layers/preprocessing/categorical_crossing.py +++ b/tensorflow/python/keras/layers/preprocessing/category_crossing.py @@ -49,6 +49,17 @@ class CategoryCrossing(Layer): [b'b_X_e'], [b'c_X_f']], dtype=object)> + + >>> inp_1 = tf.constant([['a'], ['b'], ['c']]) + >>> inp_2 = tf.constant([['d'], ['e'], ['f']]) + >>> layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing( + ... separator='-') + >>> layer([inp_1, inp_2]) + <tf.Tensor: shape=(3, 1), dtype=string, numpy= + array([[b'a-d'], + [b'b-e'], + [b'c-f']], dtype=object)> + Arguments: depth: depth of input crossing. By default None, all inputs are crossed into one output. It can also be an int or tuple/list of ints. Passing an @@ -59,6 +70,8 @@ class CategoryCrossing(Layer): equal to N1 or N2. Passing `None` means a single crossed output with all inputs. For example, with inputs `a`, `b` and `c`, `depth=2` means the output will be [a;b;c;cross(a, b);cross(bc);cross(ca)]. + separator: A string added between each input being joined. Defaults to + '_X_'. name: Name to give to the layer. **kwargs: Keyword arguments to construct a layer. @@ -98,13 +111,12 @@ class CategoryCrossing(Layer): `[[b'1_X_2_X_3'], [b'4_X_5_X_6']]` """ - def __init__(self, - depth=None, - name=None, - **kwargs): - # TODO(tanzheny): Consider making seperator configurable. + def __init__(self, depth=None, name=None, separator=None, **kwargs): super(CategoryCrossing, self).__init__(name=name, **kwargs) self.depth = depth + if separator is None: + separator = '_X_' + self.separator = separator if isinstance(depth, (tuple, list)): self._depth_tuple = depth elif depth is not None: @@ -114,12 +126,16 @@ class CategoryCrossing(Layer): """Gets the crossed output from a partial list/tuple of inputs.""" # If ragged_out=True, convert output from sparse to ragged. if ragged_out: + # TODO(momernick): Support separator with ragged_cross. + if self.separator != '_X_': + raise ValueError('Non-default separator with ragged input is not ' + 'supported yet, given {}'.format(self.separator)) return ragged_array_ops.cross(partial_inputs) elif sparse_out: - return sparse_ops.sparse_cross(partial_inputs) + return sparse_ops.sparse_cross(partial_inputs, separator=self.separator) else: return sparse_ops.sparse_tensor_to_dense( - sparse_ops.sparse_cross(partial_inputs)) + sparse_ops.sparse_cross(partial_inputs, separator=self.separator)) def call(self, inputs): depth_tuple = self._depth_tuple if self.depth else (len(inputs),) @@ -178,6 +194,7 @@ class CategoryCrossing(Layer): def get_config(self): config = { 'depth': self.depth, + 'separator': self.separator, } base_config = super(CategoryCrossing, self).get_config() return dict(list(base_config.items()) + list(config.items())) diff --git a/tensorflow/python/keras/layers/preprocessing/categorical_crossing_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/category_crossing_distribution_test.py similarity index 98% rename from tensorflow/python/keras/layers/preprocessing/categorical_crossing_distribution_test.py rename to tensorflow/python/keras/layers/preprocessing/category_crossing_distribution_test.py index 57dea6edf4a..1ccc7fe2296 100644 --- a/tensorflow/python/keras/layers/preprocessing/categorical_crossing_distribution_test.py +++ b/tensorflow/python/keras/layers/preprocessing/category_crossing_distribution_test.py @@ -28,7 +28,7 @@ from tensorflow.python.distribute import tpu_strategy from tensorflow.python.framework import config from tensorflow.python.framework import dtypes from tensorflow.python.keras import keras_parameterized -from tensorflow.python.keras.layers.preprocessing import categorical_crossing +from tensorflow.python.keras.layers.preprocessing import category_crossing from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils from tensorflow.python.platform import test @@ -72,7 +72,7 @@ class CategoryCrossingDistributionTest( input_data_2 = keras.Input(shape=(2,), dtype=dtypes.string, name='input_2') input_data = [input_data_1, input_data_2] - layer = categorical_crossing.CategoryCrossing() + layer = category_crossing.CategoryCrossing() int_data = layer(input_data) model = keras.Model(inputs=input_data, outputs=int_data) output_dataset = model.predict(inp_dataset) diff --git a/tensorflow/python/keras/layers/preprocessing/categorical_crossing_test.py b/tensorflow/python/keras/layers/preprocessing/category_crossing_test.py similarity index 82% rename from tensorflow/python/keras/layers/preprocessing/categorical_crossing_test.py rename to tensorflow/python/keras/layers/preprocessing/category_crossing_test.py index 5bbcf5ce022..f076c9ea865 100644 --- a/tensorflow/python/keras/layers/preprocessing/categorical_crossing_test.py +++ b/tensorflow/python/keras/layers/preprocessing/category_crossing_test.py @@ -29,7 +29,7 @@ from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras.engine import input_layer from tensorflow.python.keras.engine import training -from tensorflow.python.keras.layers.preprocessing import categorical_crossing +from tensorflow.python.keras.layers.preprocessing import category_crossing from tensorflow.python.ops import array_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.ragged import ragged_factory_ops @@ -41,7 +41,7 @@ from tensorflow.python.platform import test class CategoryCrossingTest(keras_parameterized.TestCase): def test_crossing_sparse_inputs(self): - layer = categorical_crossing.CategoryCrossing() + layer = category_crossing.CategoryCrossing() inputs_0 = sparse_tensor.SparseTensor( indices=[[0, 0], [1, 0], [1, 1]], values=['a', 'b', 'c'], @@ -52,8 +52,32 @@ class CategoryCrossingTest(keras_parameterized.TestCase): self.assertAllClose(np.asarray([[0, 0], [1, 0], [1, 1]]), output.indices) self.assertAllEqual([b'a_X_d', b'b_X_e', b'c_X_e'], output.values) + def test_crossing_sparse_inputs_custom_sep(self): + layer = category_crossing.CategoryCrossing(separator='_Y_') + inputs_0 = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 0], [1, 1]], + values=['a', 'b', 'c'], + dense_shape=[2, 2]) + inputs_1 = sparse_tensor.SparseTensor( + indices=[[0, 1], [1, 2]], values=['d', 'e'], dense_shape=[2, 3]) + output = layer([inputs_0, inputs_1]) + self.assertAllClose(np.asarray([[0, 0], [1, 0], [1, 1]]), output.indices) + self.assertAllEqual([b'a_Y_d', b'b_Y_e', b'c_Y_e'], output.values) + + def test_crossing_sparse_inputs_empty_sep(self): + layer = category_crossing.CategoryCrossing(separator='') + inputs_0 = sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 0], [1, 1]], + values=['a', 'b', 'c'], + dense_shape=[2, 2]) + inputs_1 = sparse_tensor.SparseTensor( + indices=[[0, 1], [1, 2]], values=['d', 'e'], dense_shape=[2, 3]) + output = layer([inputs_0, inputs_1]) + self.assertAllClose(np.asarray([[0, 0], [1, 0], [1, 1]]), output.indices) + self.assertAllEqual([b'ad', b'be', b'ce'], output.values) + def test_crossing_sparse_inputs_depth_int(self): - layer = categorical_crossing.CategoryCrossing(depth=1) + layer = category_crossing.CategoryCrossing(depth=1) inputs_0 = sparse_tensor.SparseTensor( indices=[[0, 0], [1, 0], [2, 0]], values=['a', 'b', 'c'], @@ -69,7 +93,7 @@ class CategoryCrossingTest(keras_parameterized.TestCase): self.assertAllEqual(expected_out, output) def test_crossing_sparse_inputs_depth_tuple(self): - layer = categorical_crossing.CategoryCrossing(depth=(2, 3)) + layer = category_crossing.CategoryCrossing(depth=(2, 3)) inputs_0 = sparse_tensor.SparseTensor( indices=[[0, 0], [1, 0], [2, 0]], values=['a', 'b', 'c'], @@ -107,14 +131,14 @@ class CategoryCrossingTest(keras_parameterized.TestCase): inp_0_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string) inp_1_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string) - non_hashed_layer = categorical_crossing.CategoryCrossing() + non_hashed_layer = category_crossing.CategoryCrossing() out_t = non_hashed_layer([inp_0_t, inp_1_t]) model = training.Model(inputs=[inp_0_t, inp_1_t], outputs=out_t) expected_output = [[b'omar_X_a', b'skywalker_X_a'], [b'marlo_X_b']] self.assertAllEqual(expected_output, model.predict([inputs_0, inputs_1])) def test_crossing_ragged_inputs_depth_int(self): - layer = categorical_crossing.CategoryCrossing(depth=1) + layer = category_crossing.CategoryCrossing(depth=1) inputs_0 = ragged_factory_ops.constant([['a'], ['b'], ['c']]) inputs_1 = ragged_factory_ops.constant([['d'], ['e'], ['f']]) output = layer([inputs_0, inputs_1]) @@ -122,7 +146,7 @@ class CategoryCrossingTest(keras_parameterized.TestCase): self.assertIsInstance(output, ragged_tensor.RaggedTensor) self.assertAllEqual(expected_output, output) - layer = categorical_crossing.CategoryCrossing(depth=2) + layer = category_crossing.CategoryCrossing(depth=2) inp_0_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string) inp_1_t = input_layer.Input(shape=(None,), ragged=True, dtype=dtypes.string) out_t = layer([inp_0_t, inp_1_t]) @@ -132,7 +156,7 @@ class CategoryCrossingTest(keras_parameterized.TestCase): self.assertAllEqual(expected_output, model.predict([inputs_0, inputs_1])) def test_crossing_ragged_inputs_depth_tuple(self): - layer = categorical_crossing.CategoryCrossing(depth=[2, 3]) + layer = category_crossing.CategoryCrossing(depth=[2, 3]) inputs_0 = ragged_factory_ops.constant([['a'], ['b'], ['c']]) inputs_1 = ragged_factory_ops.constant([['d'], ['e'], ['f']]) inputs_2 = ragged_factory_ops.constant([['g'], ['h'], ['i']]) @@ -149,21 +173,21 @@ class CategoryCrossingTest(keras_parameterized.TestCase): self.assertAllEqual(expected_output, output) def test_crossing_with_dense_inputs(self): - layer = categorical_crossing.CategoryCrossing() + layer = category_crossing.CategoryCrossing() inputs_0 = np.asarray([[1, 2]]) inputs_1 = np.asarray([[1, 3]]) output = layer([inputs_0, inputs_1]) self.assertAllEqual([[b'1_X_1', b'1_X_3', b'2_X_1', b'2_X_3']], output) def test_crossing_dense_inputs_depth_int(self): - layer = categorical_crossing.CategoryCrossing(depth=1) + layer = category_crossing.CategoryCrossing(depth=1) inputs_0 = constant_op.constant([['a'], ['b'], ['c']]) inputs_1 = constant_op.constant([['d'], ['e'], ['f']]) output = layer([inputs_0, inputs_1]) expected_output = [[b'a', b'd'], [b'b', b'e'], [b'c', b'f']] self.assertAllEqual(expected_output, output) - layer = categorical_crossing.CategoryCrossing(depth=2) + layer = category_crossing.CategoryCrossing(depth=2) inp_0_t = input_layer.Input(shape=(1,), dtype=dtypes.string) inp_1_t = input_layer.Input(shape=(1,), dtype=dtypes.string) out_t = layer([inp_0_t, inp_1_t]) @@ -174,7 +198,7 @@ class CategoryCrossingTest(keras_parameterized.TestCase): self.assertAllEqual(expected_output, model.predict([inputs_0, inputs_1])) def test_crossing_dense_inputs_depth_tuple(self): - layer = categorical_crossing.CategoryCrossing(depth=[2, 3]) + layer = category_crossing.CategoryCrossing(depth=[2, 3]) inputs_0 = constant_op.constant([['a'], ['b'], ['c']]) inputs_1 = constant_op.constant([['d'], ['e'], ['f']]) inputs_2 = constant_op.constant([['g'], ['h'], ['i']]) @@ -200,21 +224,21 @@ class CategoryCrossingTest(keras_parameterized.TestCase): tensor_spec.TensorSpec(input_shape, dtypes.string) for input_shape in input_shapes ] - layer = categorical_crossing.CategoryCrossing() + layer = category_crossing.CategoryCrossing() output_spec = layer.compute_output_signature(input_specs) self.assertEqual(output_spec.shape.dims[0], input_shapes[0].dims[0]) self.assertEqual(output_spec.dtype, dtypes.string) @tf_test_util.run_v2_only def test_config_with_custom_name(self): - layer = categorical_crossing.CategoryCrossing(depth=2, name='hashing') + layer = category_crossing.CategoryCrossing(depth=2, name='hashing') config = layer.get_config() - layer_1 = categorical_crossing.CategoryCrossing.from_config(config) + layer_1 = category_crossing.CategoryCrossing.from_config(config) self.assertEqual(layer_1.name, layer.name) - layer = categorical_crossing.CategoryCrossing(name='hashing') + layer = category_crossing.CategoryCrossing(name='hashing') config = layer.get_config() - layer_1 = categorical_crossing.CategoryCrossing.from_config(config) + layer_1 = category_crossing.CategoryCrossing.from_config(config) self.assertEqual(layer_1.name, layer.name) diff --git a/tensorflow/python/keras/layers/preprocessing/categorical_encoding.py b/tensorflow/python/keras/layers/preprocessing/category_encoding.py similarity index 82% rename from tensorflow/python/keras/layers/preprocessing/categorical_encoding.py rename to tensorflow/python/keras/layers/preprocessing/category_encoding.py index 466405a27a9..b0a7e746074 100644 --- a/tensorflow/python/keras/layers/preprocessing/categorical_encoding.py +++ b/tensorflow/python/keras/layers/preprocessing/category_encoding.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Keras text CategoricalEncoding preprocessing layer.""" +"""Keras text CategoryEncoding preprocessing layer.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -32,11 +32,13 @@ from tensorflow.python.keras import backend as K from tensorflow.python.keras.engine import base_preprocessing_layer from tensorflow.python.keras.utils import layer_utils from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bincount_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import keras_export TFIDF = "tf-idf" INT = "int" @@ -49,14 +51,26 @@ _NUM_ELEMENTS_NAME = "num_elements" _IDF_NAME = "idf" -class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): - """Categorical encoding layer. +@keras_export("keras.layers.experimental.preprocessing.CategoryEncoding", v1=[]) +class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): + """Category encoding layer. This layer provides options for condensing data into a categorical encoding. It accepts integer values as inputs and outputs a dense representation (one sample = 1-index tensor of float values representing data about the sample's tokens) of those inputs. + Examples: + + >>> layer = tf.keras.layers.experimental.preprocessing.CategoryEncoding( + ... max_tokens=4) + >>> layer([[0, 1], [0, 0], [1, 2], [3, 1]]) + <tf.Tensor: shape=(4, 4), dtype=int64, numpy= + array([[1, 1, 0, 0], + [2, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 0, 1]])> + Attributes: max_tokens: The maximum size of the vocabulary for this layer. If None, there is no cap on the size of the vocabulary. @@ -72,7 +86,6 @@ class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): sparse: Boolean. If true, returns a `SparseTensor` instead of a dense `Tensor`. Defaults to `False`. """ - # TODO(momernick): Add an examples section to the docstring. def __init__(self, max_tokens=None, @@ -83,7 +96,7 @@ class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): layer_utils.validate_string_arg( output_mode, allowable_strings=(COUNT, BINARY, TFIDF), - layer_name="CategoricalEncoding", + layer_name="CategoryEncoding", arg_name="output_mode") # If max_tokens is set, the value must be greater than 1 - otherwise we @@ -92,10 +105,10 @@ class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): raise ValueError("max_tokens must be > 1.") # We need to call super() before we call _add_state_variable(). - combiner = _CategoricalEncodingCombiner( + combiner = _CategoryEncodingCombiner( compute_max_element=max_tokens is None, compute_idf=output_mode == TFIDF) - super(CategoricalEncoding, self).__init__(combiner=combiner, **kwargs) + super(CategoryEncoding, self).__init__(combiner=combiner, **kwargs) self._max_tokens = max_tokens self._output_mode = output_mode @@ -158,13 +171,12 @@ class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): RuntimeError: if the layer cannot be adapted at this time. """ if not reset_state: - raise ValueError("CategoricalEncoding does not support streaming adapts.") + raise ValueError("CategoryEncoding does not support streaming adapts.") if self._called and self._max_tokens is None: - raise RuntimeError( - "CategoricalEncoding can't be adapted after being called " - "if max_tokens is None.") - super(CategoricalEncoding, self).adapt(data, reset_state) + raise RuntimeError("CategoryEncoding can't be adapted after being called " + "if max_tokens is None.") + super(CategoryEncoding, self).adapt(data, reset_state) def _set_state_variables(self, updates): if not self.built: @@ -180,7 +192,7 @@ class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): "output_mode": self._output_mode, "sparse": self._sparse, } - base_config = super(CategoricalEncoding, self).get_config() + base_config = super(CategoryEncoding, self).get_config() return dict(list(base_config.items()) + list(config.items())) def _convert_to_ndarray(self, x): @@ -237,65 +249,40 @@ class CategoricalEncoding(base_preprocessing_layer.CombinerPreprocessingLayer): else: out_depth = self._max_tokens - if self._sparse: - if self._output_mode != COUNT: - raise ValueError("Only supports `sparse=True` when `output_mode` " - ' is \"count\", got {}'.format(self._output_mode)) - inputs = self._convert_to_sparse_inputs(inputs) - - # Consider having sparse.one_hot - # Append values to indices, and reduce sum to get the counts. - tokens = array_ops.expand_dims( - math_ops.cast(inputs.values, dtypes.int64), axis=1) - count_tokens = array_ops.concat([inputs.indices, tokens], axis=1) - count_values = array_ops.ones_like(inputs.values, dtype=dtypes.int64) - unreduced_count_shape = array_ops.concat( - [inputs.dense_shape, [out_depth]], axis=0) - counts = sparse_tensor.SparseTensor( - indices=count_tokens, - values=count_values, - dense_shape=unreduced_count_shape) - count_data = sparse_ops.sparse_reduce_sum_v2( - counts, axis=1, output_is_sparse=True) - return count_data - - # If the input is a sparse tensor, we densify it with the default value of - # -1. Because -1 is ignored by one_hot, this effectively drops the non-set - # positions from the output encoding. - if isinstance(inputs, sparse_tensor.SparseTensor): - inputs = sparse_ops.sparse_tensor_to_dense(inputs, default_value=-1) - - if self._output_mode == BINARY: - bool_one_hot_data = array_ops.one_hot( - inputs, depth=out_depth, on_value=True, off_value=False) - reduced_bool_data = math_ops.reduce_any(bool_one_hot_data, axis=1) - binary_data = math_ops.cast(reduced_bool_data, dtypes.int64) - binary_data.set_shape(tensor_shape.TensorShape((None, out_depth))) - return binary_data - - one_hot_data = array_ops.one_hot(inputs, depth=out_depth) - counts = math_ops.reduce_sum(one_hot_data, axis=1) - if self._output_mode == COUNT: - count_data = math_ops.cast(counts, dtypes.int64) - count_data.set_shape(tensor_shape.TensorShape((None, out_depth))) - return count_data - - tf_idf_data = math_ops.multiply(counts, self.tf_idf_weights) - tf_idf_data.set_shape(tensor_shape.TensorShape((None, out_depth))) if self._output_mode == TFIDF: + # If the input is a sparse tensor, we densify it with the default value of + # -1. Because -1 is ignored by one_hot, this effectively drops the non-set + # positions from the output encoding. + if isinstance(inputs, sparse_tensor.SparseTensor): + inputs = sparse_ops.sparse_tensor_to_dense(inputs, default_value=-1) + one_hot_data = array_ops.one_hot(inputs, depth=out_depth) + counts = math_ops.reduce_sum(one_hot_data, axis=1) + tf_idf_data = math_ops.multiply(counts, self.tf_idf_weights) + tf_idf_data.set_shape(tensor_shape.TensorShape((None, out_depth))) return tf_idf_data - # We can only get here if we didn't recognize the passed mode. - raise ValueError("Unknown output mode %s" % self._output_mode) + binary_output = (self._output_mode == BINARY) + if self._sparse: + return bincount_ops.sparse_bincount( + inputs, minlength=out_depth, axis=-1, binary_output=binary_output) + else: + result = bincount_ops.bincount( + inputs, + minlength=out_depth, + dtype=dtypes.int64, + axis=-1, + binary_output=binary_output) + result.set_shape(tensor_shape.TensorShape((None, out_depth))) + return result -class _CategoricalEncodingAccumulator( +class _CategoryEncodingAccumulator( collections.namedtuple("Accumulator", ["data", "per_doc_count_dict"])): pass -class _CategoricalEncodingCombiner(base_preprocessing_layer.Combiner): - """Combiner for the CategoricalEncoding preprocessing layer. +class _CategoryEncodingCombiner(base_preprocessing_layer.Combiner): + """Combiner for the CategoryEncoding preprocessing layer. This class encapsulates the logic for computing the number of elements in the input dataset and the document frequency for each element. @@ -411,7 +398,7 @@ class _CategoricalEncodingCombiner(base_preprocessing_layer.Combiner): def restore(self, output): """Creates an accumulator based on 'output'.""" raise NotImplementedError( - "CategoricalEncoding does not restore or support streaming updates.") + "CategoryEncoding does not restore or support streaming updates.") def serialize(self, accumulator): """Serializes an accumulator for a remote call.""" @@ -452,4 +439,4 @@ class _CategoricalEncodingCombiner(base_preprocessing_layer.Combiner): else: per_doc_count_dict = None data = [0, 0] - return _CategoricalEncodingAccumulator(data, per_doc_count_dict) + return _CategoryEncodingAccumulator(data, per_doc_count_dict) diff --git a/tensorflow/python/keras/layers/preprocessing/categorical_encoding_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/category_encoding_distribution_test.py similarity index 64% rename from tensorflow/python/keras/layers/preprocessing/categorical_encoding_distribution_test.py rename to tensorflow/python/keras/layers/preprocessing/category_encoding_distribution_test.py index c5214533f94..011495b9314 100644 --- a/tensorflow/python/keras/layers/preprocessing/categorical_encoding_distribution_test.py +++ b/tensorflow/python/keras/layers/preprocessing/category_encoding_distribution_test.py @@ -21,39 +21,58 @@ from __future__ import print_function import numpy as np from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import combinations from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import tpu_strategy +from tensorflow.python.framework import config from tensorflow.python.framework import dtypes from tensorflow.python.keras import keras_parameterized -from tensorflow.python.keras.layers.preprocessing import categorical_encoding +from tensorflow.python.keras.layers.preprocessing import category_encoding from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils from tensorflow.python.platform import test +def batch_wrapper(dataset, batch_size, distribution, repeat=None): + if repeat: + dataset = dataset.repeat(repeat) + # TPUs currently require fully defined input shapes, drop_remainder ensures + # the input will have fully defined shapes. + if isinstance(distribution, + (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)): + return dataset.batch(batch_size, drop_remainder=True) + else: + return dataset.batch(batch_size) + + @combinations.generate( combinations.combine( - distribution=strategy_combinations.all_strategies, + # (b/156783625): Outside compilation failed for eager mode only. + distribution=strategy_combinations.strategies_minus_tpu, mode=["eager", "graph"])) -class CategoricalEncodingDistributionTest( +class CategoryEncodingDistributionTest( keras_parameterized.TestCase, preprocessing_test_utils.PreprocessingLayerTest): def test_distribution(self, distribution): input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]]) + inp_dataset = dataset_ops.DatasetV2.from_tensor_slices(input_array) + inp_dataset = batch_wrapper(inp_dataset, 2, distribution) # pyformat: disable expected_output = [[0, 1, 1, 1, 0, 0], [1, 1, 0, 1, 0, 0]] # pyformat: enable max_tokens = 6 + config.set_soft_device_placement(True) with distribution.scope(): input_data = keras.Input(shape=(4,), dtype=dtypes.int32) - layer = categorical_encoding.CategoricalEncoding( - max_tokens=max_tokens, output_mode=categorical_encoding.BINARY) + layer = category_encoding.CategoryEncoding( + max_tokens=max_tokens, output_mode=category_encoding.BINARY) int_data = layer(input_data) model = keras.Model(inputs=input_data, outputs=int_data) - output_dataset = model.predict(input_array) + output_dataset = model.predict(inp_dataset) self.assertAllEqual(expected_output, output_dataset) diff --git a/tensorflow/python/keras/layers/preprocessing/categorical_encoding_test.py b/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py similarity index 88% rename from tensorflow/python/keras/layers/preprocessing/categorical_encoding_test.py rename to tensorflow/python/keras/layers/preprocessing/category_encoding_test.py index e21e95a0078..08aa6d4871b 100644 --- a/tensorflow/python/keras/layers/preprocessing/categorical_encoding_test.py +++ b/tensorflow/python/keras/layers/preprocessing/category_encoding_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for Keras text categorical_encoding preprocessing layer.""" +"""Tests for Keras text category_encoding preprocessing layer.""" from __future__ import absolute_import from __future__ import division @@ -32,8 +32,8 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.keras import backend from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras.layers import core -from tensorflow.python.keras.layers.preprocessing import categorical_encoding -from tensorflow.python.keras.layers.preprocessing import categorical_encoding_v1 +from tensorflow.python.keras.layers.preprocessing import category_encoding +from tensorflow.python.keras.layers.preprocessing import category_encoding_v1 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops @@ -44,15 +44,15 @@ from tensorflow.python.platform import test def get_layer_class(): if context.executing_eagerly(): - return categorical_encoding.CategoricalEncoding + return category_encoding.CategoryEncoding else: - return categorical_encoding_v1.CategoricalEncoding + return category_encoding_v1.CategoryEncoding @keras_parameterized.run_all_keras_modes(always_skip_v1=True) -class CategoricalEncodingInputTest( - keras_parameterized.TestCase, - preprocessing_test_utils.PreprocessingLayerTest): +class CategoryEncodingInputTest(keras_parameterized.TestCase, + preprocessing_test_utils.PreprocessingLayerTest + ): def test_dense_input_sparse_output(self): input_array = constant_op.constant([[1, 2, 3], [3, 3, 0]]) @@ -67,9 +67,7 @@ class CategoricalEncodingInputTest( input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( - max_tokens=max_tokens, - output_mode=categorical_encoding.COUNT, - sparse=True) + max_tokens=max_tokens, output_mode=category_encoding.COUNT, sparse=True) int_data = layer(input_data) model = keras.Model(inputs=input_data, outputs=int_data) @@ -80,7 +78,7 @@ class CategoricalEncodingInputTest( # Assert sparse output is same as dense output. layer = get_layer_class()( max_tokens=max_tokens, - output_mode=categorical_encoding.COUNT, + output_mode=category_encoding.COUNT, sparse=False) int_data = layer(input_data) model = keras.Model(inputs=input_data, outputs=int_data) @@ -103,7 +101,7 @@ class CategoricalEncodingInputTest( input_data = keras.Input(shape=(None,), dtype=dtypes.int64, sparse=True) layer = get_layer_class()( - max_tokens=max_tokens, output_mode=categorical_encoding.BINARY) + max_tokens=max_tokens, output_mode=category_encoding.BINARY) int_data = layer(input_data) self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) @@ -128,9 +126,7 @@ class CategoricalEncodingInputTest( max_tokens = 6 layer = get_layer_class()( - max_tokens=max_tokens, - output_mode=categorical_encoding.COUNT, - sparse=True) + max_tokens=max_tokens, output_mode=category_encoding.COUNT, sparse=True) int_data = layer(input_data) model = keras.Model(inputs=input_data, outputs=int_data) @@ -141,7 +137,7 @@ class CategoricalEncodingInputTest( # Assert sparse output is same as dense output. layer = get_layer_class()( max_tokens=max_tokens, - output_mode=categorical_encoding.COUNT, + output_mode=category_encoding.COUNT, sparse=False) int_data = layer(input_data) model = keras.Model(inputs=input_data, outputs=int_data) @@ -163,7 +159,7 @@ class CategoricalEncodingInputTest( input_data = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True) layer = get_layer_class()( - max_tokens=max_tokens, output_mode=categorical_encoding.BINARY) + max_tokens=max_tokens, output_mode=category_encoding.BINARY) int_data = layer(input_data) self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) @@ -184,9 +180,7 @@ class CategoricalEncodingInputTest( input_data = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True) layer = get_layer_class()( - max_tokens=max_tokens, - output_mode=categorical_encoding.COUNT, - sparse=True) + max_tokens=max_tokens, output_mode=category_encoding.COUNT, sparse=True) int_data = layer(input_data) model = keras.Model(inputs=input_data, outputs=int_data) @@ -197,7 +191,7 @@ class CategoricalEncodingInputTest( # Assert sparse output is same as dense output. layer = get_layer_class()( max_tokens=max_tokens, - output_mode=categorical_encoding.COUNT, + output_mode=category_encoding.COUNT, sparse=False) int_data = layer(input_data) model = keras.Model(inputs=input_data, outputs=int_data) @@ -214,9 +208,7 @@ class CategoricalEncodingInputTest( input_data = keras.Input(shape=(None,), dtype=dtypes.int32) encoding_layer = get_layer_class()( - max_tokens=max_tokens, - output_mode=categorical_encoding.COUNT, - sparse=True) + max_tokens=max_tokens, output_mode=category_encoding.COUNT, sparse=True) int_data = encoding_layer(input_data) output_data = math_ops.cast(int_data, dtypes.float32) weights = variables.Variable([[.1], [.2], [.3], [.4]], dtype=dtypes.float32) @@ -228,9 +220,9 @@ class CategoricalEncodingInputTest( @keras_parameterized.run_all_keras_modes -class CategoricalEncodingAdaptTest( - keras_parameterized.TestCase, - preprocessing_test_utils.PreprocessingLayerTest): +class CategoryEncodingAdaptTest(keras_parameterized.TestCase, + preprocessing_test_utils.PreprocessingLayerTest + ): def test_sparse_adapt(self): vocab_data = sparse_ops.from_dense( @@ -248,7 +240,7 @@ class CategoricalEncodingAdaptTest( input_data = keras.Input(shape=(None,), dtype=dtypes.int64, sparse=True) layer = get_layer_class()( - max_tokens=None, output_mode=categorical_encoding.BINARY) + max_tokens=None, output_mode=category_encoding.BINARY) layer.adapt(vocab_dataset) int_data = layer(input_data) self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) @@ -273,7 +265,7 @@ class CategoricalEncodingAdaptTest( input_data = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True) layer = get_layer_class()( - max_tokens=None, output_mode=categorical_encoding.BINARY) + max_tokens=None, output_mode=category_encoding.BINARY) layer.adapt(vocab_dataset) int_data = layer(input_data) @@ -296,7 +288,7 @@ class CategoricalEncodingAdaptTest( input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( - max_tokens=max_tokens, output_mode=categorical_encoding.BINARY) + max_tokens=max_tokens, output_mode=category_encoding.BINARY) int_data = layer(input_data) layer.adapt(vocab_data) self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) @@ -306,7 +298,7 @@ class CategoricalEncodingAdaptTest( self.assertAllEqual(expected_output, output_dataset) def test_hard_maximum_set_state_variables_after_build(self): - state_variables = {categorical_encoding._NUM_ELEMENTS_NAME: 5} + state_variables = {category_encoding._NUM_ELEMENTS_NAME: 5} input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]]) # pyformat: disable @@ -318,7 +310,7 @@ class CategoricalEncodingAdaptTest( input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( - max_tokens=max_tokens, output_mode=categorical_encoding.BINARY) + max_tokens=max_tokens, output_mode=category_encoding.BINARY) int_data = layer(input_data) layer._set_state_variables(state_variables) self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) @@ -339,7 +331,7 @@ class CategoricalEncodingAdaptTest( input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( - max_tokens=None, output_mode=categorical_encoding.BINARY) + max_tokens=None, output_mode=category_encoding.BINARY) layer.build(input_data.shape) layer.set_num_elements(max_tokens) int_data = layer(input_data) @@ -351,8 +343,7 @@ class CategoricalEncodingAdaptTest( def test_set_weights_fails_on_wrong_size_weights(self): tfidf_data = [.05, .5, .25, .2, .125] - layer = get_layer_class()( - max_tokens=6, output_mode=categorical_encoding.TFIDF) + layer = get_layer_class()(max_tokens=6, output_mode=category_encoding.TFIDF) with self.assertRaisesRegex(ValueError, ".*Layer weight shape.*"): layer.set_weights([np.array(tfidf_data)]) @@ -360,7 +351,7 @@ class CategoricalEncodingAdaptTest( def test_set_num_elements_after_call_fails(self): input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( - max_tokens=None, output_mode=categorical_encoding.BINARY) + max_tokens=None, output_mode=category_encoding.BINARY) _ = layer(input_data) with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"): layer.set_num_elements(5) @@ -370,17 +361,17 @@ class CategoricalEncodingAdaptTest( input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( - max_tokens=None, output_mode=categorical_encoding.BINARY) + max_tokens=None, output_mode=category_encoding.BINARY) _ = layer(input_data) with self.assertRaisesRegex(RuntimeError, "can't be adapted"): layer.adapt(vocab_data) def test_set_state_variables_after_call_fails(self): - state_variables = {categorical_encoding._NUM_ELEMENTS_NAME: 5} + state_variables = {category_encoding._NUM_ELEMENTS_NAME: 5} input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( - max_tokens=None, output_mode=categorical_encoding.BINARY) + max_tokens=None, output_mode=category_encoding.BINARY) _ = layer(input_data) with self.assertRaisesRegex(RuntimeError, "num_elements cannot be changed"): layer._set_state_variables(state_variables) @@ -388,9 +379,9 @@ class CategoricalEncodingAdaptTest( @keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes -class CategoricalEncodingOutputTest( - keras_parameterized.TestCase, - preprocessing_test_utils.PreprocessingLayerTest): +class CategoryEncodingOutputTest(keras_parameterized.TestCase, + preprocessing_test_utils.PreprocessingLayerTest + ): def test_binary_output_hard_maximum(self): input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]]) @@ -404,7 +395,7 @@ class CategoricalEncodingOutputTest( input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( - max_tokens=max_tokens, output_mode=categorical_encoding.BINARY) + max_tokens=max_tokens, output_mode=category_encoding.BINARY) int_data = layer(input_data) self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) @@ -424,7 +415,7 @@ class CategoricalEncodingOutputTest( input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( - max_tokens=None, output_mode=categorical_encoding.BINARY) + max_tokens=None, output_mode=category_encoding.BINARY) layer.set_weights([np.array(max_tokens)]) int_data = layer(input_data) self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) @@ -444,8 +435,7 @@ class CategoricalEncodingOutputTest( expected_output_shape = [None, max_tokens] input_data = keras.Input(shape=(None,), dtype=dtypes.int32) - layer = get_layer_class()( - max_tokens=6, output_mode=categorical_encoding.COUNT) + layer = get_layer_class()(max_tokens=6, output_mode=category_encoding.COUNT) int_data = layer(input_data) self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) @@ -465,7 +455,7 @@ class CategoricalEncodingOutputTest( input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( - max_tokens=None, output_mode=categorical_encoding.COUNT) + max_tokens=None, output_mode=category_encoding.COUNT) layer.set_weights([np.array(max_tokens)]) int_data = layer(input_data) self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) @@ -488,8 +478,7 @@ class CategoricalEncodingOutputTest( expected_output_shape = [None, max_tokens] input_data = keras.Input(shape=(None,), dtype=dtypes.int32) - layer = get_layer_class()( - max_tokens=6, output_mode=categorical_encoding.TFIDF) + layer = get_layer_class()(max_tokens=6, output_mode=category_encoding.TFIDF) layer.set_tfidf_data(tfidf_data) int_data = layer(input_data) self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) @@ -513,7 +502,7 @@ class CategoricalEncodingOutputTest( input_data = keras.Input(shape=(None,), dtype=dtypes.int32) layer = get_layer_class()( - max_tokens=None, output_mode=categorical_encoding.TFIDF) + max_tokens=None, output_mode=category_encoding.TFIDF) layer.set_num_elements(max_tokens) layer.set_tfidf_data(tfidf_data) int_data = layer(input_data) @@ -524,7 +513,7 @@ class CategoricalEncodingOutputTest( self.assertAllClose(expected_output, output_dataset) -class CategoricalEncodingModelBuildingTest( +class CategoryEncodingModelBuildingTest( keras_parameterized.TestCase, preprocessing_test_utils.PreprocessingLayerTest): @@ -532,27 +521,27 @@ class CategoricalEncodingModelBuildingTest( { "testcase_name": "count_hard_max", "max_tokens": 5, - "output_mode": categorical_encoding.COUNT + "output_mode": category_encoding.COUNT }, { "testcase_name": "count_soft_max", "max_tokens": None, - "output_mode": categorical_encoding.COUNT + "output_mode": category_encoding.COUNT }, { "testcase_name": "binary_hard_max", "max_tokens": 5, - "output_mode": categorical_encoding.BINARY + "output_mode": category_encoding.BINARY }, { "testcase_name": "binary_soft_max", "max_tokens": None, - "output_mode": categorical_encoding.BINARY + "output_mode": category_encoding.BINARY }, { "testcase_name": "tfidf_hard_max", "max_tokens": 5, - "output_mode": categorical_encoding.TFIDF + "output_mode": category_encoding.TFIDF }, { "testcase_name": "tfidf_soft_max", "max_tokens": None, - "output_mode": categorical_encoding.TFIDF + "output_mode": category_encoding.TFIDF }) def test_end_to_end_bagged_modeling(self, output_mode, max_tokens): tfidf_data = np.array([.03, .5, .25, .2, .125]) @@ -564,7 +553,7 @@ class CategoricalEncodingModelBuildingTest( weights = [] if max_tokens is None: weights.append(np.array(5)) - if output_mode == categorical_encoding.TFIDF: + if output_mode == category_encoding.TFIDF: weights.append(tfidf_data) layer.set_weights(weights) @@ -577,7 +566,7 @@ class CategoricalEncodingModelBuildingTest( @keras_parameterized.run_all_keras_modes -class CategoricalEncodingCombinerTest( +class CategoryEncodingCombinerTest( keras_parameterized.TestCase, preprocessing_test_utils.PreprocessingLayerTest): @@ -617,8 +606,7 @@ class CategoricalEncodingCombinerTest( def test_combiner_api_compatibility_int_mode(self): data = np.array([[1, 2, 3, 4], [1, 2, 3, 0]]) - combiner = categorical_encoding._CategoricalEncodingCombiner( - compute_idf=False) + combiner = category_encoding._CategoryEncodingCombiner(compute_idf=False) expected_accumulator_output = { "max_element": np.array(4), "num_documents": np.array(2), @@ -636,8 +624,7 @@ class CategoricalEncodingCombinerTest( def test_combiner_api_compatibility_tfidf_mode(self): data = np.array([[1, 2, 3, 4], [1, 2, 3, 0]]) - combiner = categorical_encoding._CategoricalEncodingCombiner( - compute_idf=True) + combiner = category_encoding._CategoryEncodingCombiner(compute_idf=True) expected_accumulator_output = { "max_element": np.array(4), "document_counts": np.array([1, 2, 2, 2, 1]), @@ -693,7 +680,7 @@ class CategoricalEncodingCombinerTest( expected_accumulator_output, expected_extract_output, compute_idf=True): - combiner = categorical_encoding._CategoricalEncodingCombiner( + combiner = category_encoding._CategoryEncodingCombiner( compute_idf=compute_idf) expected_accumulator = combiner._create_accumulator() expected_accumulator = self.update_accumulator(expected_accumulator, @@ -702,6 +689,5 @@ class CategoricalEncodingCombinerTest( self.validate_accumulator_extract(combiner, data, expected_extract_output) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/keras/layers/preprocessing/categorical_encoding_v1.py b/tensorflow/python/keras/layers/preprocessing/category_encoding_v1.py similarity index 89% rename from tensorflow/python/keras/layers/preprocessing/categorical_encoding_v1.py rename to tensorflow/python/keras/layers/preprocessing/category_encoding_v1.py index 83128ed5095..3afb86b344f 100644 --- a/tensorflow/python/keras/layers/preprocessing/categorical_encoding_v1.py +++ b/tensorflow/python/keras/layers/preprocessing/category_encoding_v1.py @@ -12,20 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tensorflow V1 version of the text categorical_encoding preprocessing layer.""" +"""Tensorflow V1 version of the text category_encoding preprocessing layer.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.keras.engine import base_preprocessing_layer_v1 -from tensorflow.python.keras.layers.preprocessing import categorical_encoding +from tensorflow.python.keras.layers.preprocessing import category_encoding +from tensorflow.python.util.tf_export import keras_export -class CategoricalEncoding(categorical_encoding.CategoricalEncoding, - base_preprocessing_layer_v1.CombinerPreprocessingLayer - ): - """CategoricalEncoding layer. +@keras_export(v1=["keras.layers.experimental.preprocessing.CategoryEncoding"]) +class CategoryEncoding(category_encoding.CategoryEncoding, + base_preprocessing_layer_v1.CombinerPreprocessingLayer): + """CategoryEncoding layer. This layer provides options for condensing input data into denser representations. It accepts either integer values or strings as inputs, diff --git a/tensorflow/python/keras/layers/preprocessing/discretization.py b/tensorflow/python/keras/layers/preprocessing/discretization.py index 003b6e64f90..3052cfb4369 100644 --- a/tensorflow/python/keras/layers/preprocessing/discretization.py +++ b/tensorflow/python/keras/layers/preprocessing/discretization.py @@ -52,6 +52,16 @@ class Discretization(Layer): exclude the right boundary, so `bins=[0., 1., 2.]` generates bins `(-inf, 0.)`, `[0., 1.)`, `[1., 2.)`, and `[2., +inf)`. output_mode: One of 'int', 'binary'. Defaults to 'int'. + + Examples: + + Bucketize float values based on provided buckets. + >>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]]) + >>> layer = Discretization(bins=[0., 1., 2.]) + >>> layer(input) + <tf.Tensor: shape=(2, 4), dtype=int32, numpy= + array([[0, 2, 3, 1], + [1, 3, 2, 1]], dtype=int32)> """ def __init__(self, bins, output_mode=INTEGER, **kwargs): diff --git a/tensorflow/python/keras/layers/preprocessing/hashing.py b/tensorflow/python/keras/layers/preprocessing/hashing.py index dfd4761f193..05b4445829a 100644 --- a/tensorflow/python/keras/layers/preprocessing/hashing.py +++ b/tensorflow/python/keras/layers/preprocessing/hashing.py @@ -22,20 +22,28 @@ import functools from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.ops import gen_sparse_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops.ragged import ragged_functional_ops from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.util.tf_export import keras_export + +# Default key from tf.sparse.cross_hashed +_DEFAULT_SALT_KEY = [0xDECAFCAFFE, 0xDECAFCAFFE] +@keras_export('keras.layers.experimental.preprocessing.Hashing') class Hashing(Layer): """Implements categorical feature hashing, also known as "hashing trick". - This layer transforms categorical inputs to hashed output. It converts a - sequence of int or string to a sequence of int. The stable hash function uses - tensorflow::ops::Fingerprint to produce universal output that is consistent - across platforms. + This layer transforms single or multiple categorical inputs to hashed output. + It converts a sequence of int or string to a sequence of int. The stable hash + function uses tensorflow::ops::Fingerprint to produce universal output that + is consistent across platforms. This layer uses [FarmHash64](https://github.com/google/farmhash) by default, which provides a consistent hashed output across different platforms and is @@ -48,50 +56,91 @@ class Hashing(Layer): the `salt` value serving as additional input to the hash function. Example (FarmHash64): - ```python - layer = Hashing(num_bins=3) - inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']]) - layer(inputs) - [[1], [0], [1], [1], [2]] - ``` + + >>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3) + >>> inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']]) + >>> layer(inp) + <tf.Tensor: shape=(5, 1), dtype=int64, numpy= + array([[1], + [0], + [1], + [1], + [2]])> + Example (SipHash64): - ```python - layer = Hashing(num_bins=3, salt=[133, 137]) - inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']]) - layer(inputs) - [[1], [2], [1], [0], [2]] - ``` + + >>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3, + ... salt=[133, 137]) + >>> inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']]) + >>> layer(inp) + <tf.Tensor: shape=(5, 1), dtype=int64, numpy= + array([[1], + [2], + [1], + [0], + [2]])> + + Example (Siphash64 with a single integer, same as `salt=[133, 133]` + + >>> layer = tf.keras.layers.experimental.preprocessing.Hashing(num_bins=3, + ... salt=133) + >>> inp = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']]) + >>> layer(inp) + <tf.Tensor: shape=(5, 1), dtype=int64, numpy= + array([[0], + [0], + [2], + [1], + [0]])> + + Reference: [SipHash with salt](https://www.131002.net/siphash/siphash.pdf) Arguments: num_bins: Number of hash bins. - salt: A tuple/list of 2 unsigned integer numbers. If passed, the hash - function used will be SipHash64, with these values used as an additional - input (known as a "salt" in cryptography). + salt: A single unsigned integer or None. + If passed, the hash function used will be SipHash64, with these values + used as an additional input (known as a "salt" in cryptography). These should be non-zero. Defaults to `None` (in that - case, the FarmHash64 hash function is used). + case, the FarmHash64 hash function is used). It also supports + tuple/list of 2 unsigned integer numbers, see reference paper for details. name: Name to give to the layer. **kwargs: Keyword arguments to construct a layer. - Input shape: A string, int32 or int64 tensor of shape - `[batch_size, d1, ..., dm]` + Input shape: A single or list of string, int32 or int64 `Tensor`, + `SparseTensor` or `RaggedTensor` of shape `[batch_size, ...,]` - Output shape: An int64 tensor of shape `[batch_size, d1, ..., dm]` + Output shape: An int64 `Tensor`, `SparseTensor` or `RaggedTensor` of shape + `[batch_size, ...]`. If any input is `RaggedTensor` then output is + `RaggedTensor`, otherwise if any input is `SparseTensor` then output is + `SparseTensor`, otherwise the output is `Tensor`. """ def __init__(self, num_bins, salt=None, name=None, **kwargs): if num_bins is None or num_bins <= 0: raise ValueError('`num_bins` cannot be `None` or non-positive values.') - if salt is not None: - if not isinstance(salt, (tuple, list)) or len(salt) != 2: - raise ValueError('`salt` must be a tuple or list of 2 unsigned ' - 'integer numbers, got {}'.format(salt)) super(Hashing, self).__init__(name=name, **kwargs) self.num_bins = num_bins - self.salt = salt + self.strong_hash = True if salt is not None else False + if salt is not None: + if isinstance(salt, (tuple, list)) and len(salt) == 2: + self.salt = salt + elif isinstance(salt, int): + self.salt = [salt, salt] + else: + raise ValueError('`salt can only be a tuple of size 2 integers, or a ' + 'single integer, given {}'.format(salt)) + else: + self.salt = _DEFAULT_SALT_KEY def call(self, inputs): + if isinstance(inputs, (tuple, list)): + return self._process_input_list(inputs) + else: + return self._process_single_input(inputs) + + def _process_single_input(self, inputs): # Converts integer inputs to string. if inputs.dtype.is_integer: if isinstance(inputs, sparse_tensor.SparseTensor): @@ -116,10 +165,38 @@ class Hashing(Layer): else: return str_to_hash_bucket(inputs, self.num_bins, name='hash') + def _process_input_list(self, inputs): + # TODO(momernick): support ragged_cross_hashed with corrected fingerprint + # and siphash. + if any([isinstance(inp, ragged_tensor.RaggedTensor) for inp in inputs]): + raise ValueError('Hashing with ragged input is not supported yet.') + sparse_inputs = [ + inp for inp in inputs if isinstance(inp, sparse_tensor.SparseTensor) + ] + dense_inputs = [ + inp for inp in inputs if not isinstance(inp, sparse_tensor.SparseTensor) + ] + all_dense = True if not sparse_inputs else False + indices = [sp_inp.indices for sp_inp in sparse_inputs] + values = [sp_inp.values for sp_inp in sparse_inputs] + shapes = [sp_inp.dense_shape for sp_inp in sparse_inputs] + indices_out, values_out, shapes_out = gen_sparse_ops.sparse_cross_hashed( + indices=indices, + values=values, + shapes=shapes, + dense_inputs=dense_inputs, + num_buckets=self.num_bins, + strong_hash=self.strong_hash, + salt=self.salt) + sparse_out = sparse_tensor.SparseTensor(indices_out, values_out, shapes_out) + if all_dense: + return sparse_ops.sparse_tensor_to_dense(sparse_out) + return sparse_out + def _get_string_to_hash_bucket_fn(self): """Returns the string_to_hash_bucket op to use based on `hasher_key`.""" # string_to_hash_bucket_fast uses FarmHash64 as hash function. - if self.salt is None: + if not self.strong_hash: return string_ops.string_to_hash_bucket_fast # string_to_hash_bucket_strong uses SipHash64 as hash function. else: @@ -127,16 +204,43 @@ class Hashing(Layer): string_ops.string_to_hash_bucket_strong, key=self.salt) def compute_output_shape(self, input_shape): - return input_shape + if not isinstance(input_shape, (tuple, list)): + return input_shape + input_shapes = input_shape + batch_size = None + for inp_shape in input_shapes: + inp_tensor_shape = tensor_shape.TensorShape(inp_shape).as_list() + if len(inp_tensor_shape) != 2: + raise ValueError('Inputs must be rank 2, get {}'.format(input_shapes)) + if batch_size is None: + batch_size = inp_tensor_shape[0] + # The second dimension is dynamic based on inputs. + output_shape = [batch_size, None] + return tensor_shape.TensorShape(output_shape) def compute_output_signature(self, input_spec): - output_shape = self.compute_output_shape(input_spec.shape.as_list()) - output_dtype = dtypes.int64 - if isinstance(input_spec, sparse_tensor.SparseTensorSpec): + if not isinstance(input_spec, (tuple, list)): + output_shape = self.compute_output_shape(input_spec.shape) + output_dtype = dtypes.int64 + if isinstance(input_spec, sparse_tensor.SparseTensorSpec): + return sparse_tensor.SparseTensorSpec( + shape=output_shape, dtype=output_dtype) + else: + return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype) + input_shapes = [x.shape for x in input_spec] + output_shape = self.compute_output_shape(input_shapes) + if any([ + isinstance(inp_spec, ragged_tensor.RaggedTensorSpec) + for inp_spec in input_spec + ]): + return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.int64) + elif any([ + isinstance(inp_spec, sparse_tensor.SparseTensorSpec) + for inp_spec in input_spec + ]): return sparse_tensor.SparseTensorSpec( - shape=output_shape, dtype=output_dtype) - else: - return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype) + shape=output_shape, dtype=dtypes.int64) + return tensor_spec.TensorSpec(shape=output_shape, dtype=dtypes.int64) def get_config(self): config = {'num_bins': self.num_bins, 'salt': self.salt} diff --git a/tensorflow/python/keras/layers/preprocessing/hashing_test.py b/tensorflow/python/keras/layers/preprocessing/hashing_test.py index 147e4bc371b..4c3fd9c7501 100644 --- a/tensorflow/python/keras/layers/preprocessing/hashing_test.py +++ b/tensorflow/python/keras/layers/preprocessing/hashing_test.py @@ -51,6 +51,15 @@ class HashingTest(keras_parameterized.TestCase): # Assert equal for hashed output that should be true on all platforms. self.assertAllClose([[0], [0], [1], [0], [0]], output) + def test_hash_dense_multi_inputs_farmhash(self): + layer = hashing.Hashing(num_bins=2) + inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'], + ['skywalker']]) + inp_2 = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']]) + output = layer([inp_1, inp_2]) + # Assert equal for hashed output that should be true on all platforms. + self.assertAllClose([[0], [0], [1], [1], [0]], output) + def test_hash_dense_int_input_farmhash(self): layer = hashing.Hashing(num_bins=3) inp = np.asarray([[0], [1], [2], [3], [4]]) @@ -72,6 +81,21 @@ class HashingTest(keras_parameterized.TestCase): # Note the result is different from (133, 137). self.assertAllClose([[1], [0], [1], [0], [1]], output_2) + def test_hash_dense_multi_inputs_siphash(self): + layer = hashing.Hashing(num_bins=2, salt=[133, 137]) + inp_1 = np.asarray([['omar'], ['stringer'], ['marlo'], ['wire'], + ['skywalker']]) + inp_2 = np.asarray([['A'], ['B'], ['C'], ['D'], ['E']]) + output = layer([inp_1, inp_2]) + # Assert equal for hashed output that should be true on all platforms. + # Note the result is different from FarmHash. + self.assertAllClose([[0], [1], [0], [0], [1]], output) + + layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137]) + output_2 = layer_2([inp_1, inp_2]) + # Note the result is different from (133, 137). + self.assertAllClose([[1], [1], [1], [0], [1]], output_2) + def test_hash_dense_int_input_siphash(self): layer = hashing.Hashing(num_bins=3, salt=[133, 137]) inp = np.asarray([[0], [1], [2], [3], [4]]) @@ -90,6 +114,19 @@ class HashingTest(keras_parameterized.TestCase): self.assertAllClose(indices, output.indices) self.assertAllClose([0, 0, 1, 0, 0], output.values) + def test_hash_sparse_multi_inputs_farmhash(self): + layer = hashing.Hashing(num_bins=2) + indices = [[0, 0], [1, 0], [2, 0]] + inp_1 = sparse_tensor.SparseTensor( + indices=indices, + values=['omar', 'stringer', 'marlo'], + dense_shape=[3, 1]) + inp_2 = sparse_tensor.SparseTensor( + indices=indices, values=['A', 'B', 'C'], dense_shape=[3, 1]) + output = layer([inp_1, inp_2]) + self.assertAllClose(indices, output.indices) + self.assertAllClose([0, 0, 1], output.values) + def test_hash_sparse_int_input_farmhash(self): layer = hashing.Hashing(num_bins=3) indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]] @@ -116,6 +153,25 @@ class HashingTest(keras_parameterized.TestCase): # The result should be same with test_hash_dense_input_siphash. self.assertAllClose([1, 0, 1, 0, 1], output.values) + def test_hash_sparse_multi_inputs_siphash(self): + layer = hashing.Hashing(num_bins=2, salt=[133, 137]) + indices = [[0, 0], [1, 0], [2, 0]] + inp_1 = sparse_tensor.SparseTensor( + indices=indices, + values=['omar', 'stringer', 'marlo'], + dense_shape=[3, 1]) + inp_2 = sparse_tensor.SparseTensor( + indices=indices, values=['A', 'B', 'C'], dense_shape=[3, 1]) + output = layer([inp_1, inp_2]) + # The result should be same with test_hash_dense_input_siphash. + self.assertAllClose(indices, output.indices) + self.assertAllClose([0, 1, 0], output.values) + + layer_2 = hashing.Hashing(num_bins=2, salt=[211, 137]) + output = layer_2([inp_1, inp_2]) + # The result should be same with test_hash_dense_input_siphash. + self.assertAllClose([1, 1, 1], output.values) + def test_hash_sparse_int_input_siphash(self): layer = hashing.Hashing(num_bins=3, salt=[133, 137]) indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]] @@ -140,6 +196,17 @@ class HashingTest(keras_parameterized.TestCase): model = training.Model(inputs=inp_t, outputs=out_t) self.assertAllClose(out_data, model.predict(inp_data)) + def test_hash_ragged_string_multi_inputs_farmhash(self): + layer = hashing.Hashing(num_bins=2) + inp_data_1 = ragged_factory_ops.constant( + [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']], + dtype=dtypes.string) + inp_data_2 = ragged_factory_ops.constant( + [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']], + dtype=dtypes.string) + with self.assertRaisesRegexp(ValueError, 'not supported yet'): + _ = layer([inp_data_1, inp_data_2]) + def test_hash_ragged_int_input_farmhash(self): layer = hashing.Hashing(num_bins=3) inp_data = ragged_factory_ops.constant([[0, 1, 3, 4], [2, 1, 0]], @@ -178,6 +245,17 @@ class HashingTest(keras_parameterized.TestCase): model = training.Model(inputs=inp_t, outputs=out_t) self.assertAllClose(out_data, model.predict(inp_data)) + def test_hash_ragged_string_multi_inputs_siphash(self): + layer = hashing.Hashing(num_bins=2, salt=[133, 137]) + inp_data_1 = ragged_factory_ops.constant( + [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']], + dtype=dtypes.string) + inp_data_2 = ragged_factory_ops.constant( + [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']], + dtype=dtypes.string) + with self.assertRaisesRegexp(ValueError, 'not supported yet'): + _ = layer([inp_data_1, inp_data_2]) + def test_hash_ragged_int_input_siphash(self): layer = hashing.Hashing(num_bins=3, salt=[133, 137]) inp_data = ragged_factory_ops.constant([[0, 1, 3, 4], [2, 1, 0]], @@ -197,11 +275,11 @@ class HashingTest(keras_parameterized.TestCase): _ = hashing.Hashing(num_bins=None) with self.assertRaisesRegexp(ValueError, 'cannot be `None`'): _ = hashing.Hashing(num_bins=-1) - with self.assertRaisesRegexp(ValueError, 'must be a tuple'): + with self.assertRaisesRegexp(ValueError, 'can only be a tuple of size 2'): _ = hashing.Hashing(num_bins=2, salt='string') - with self.assertRaisesRegexp(ValueError, 'must be a tuple'): + with self.assertRaisesRegexp(ValueError, 'can only be a tuple of size 2'): _ = hashing.Hashing(num_bins=2, salt=[1]) - with self.assertRaisesRegexp(ValueError, 'must be a tuple'): + with self.assertRaisesRegexp(ValueError, 'can only be a tuple of size 2'): _ = hashing.Hashing(num_bins=1, salt=constant_op.constant([133, 137])) def test_hash_compute_output_signature(self): diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py index 832915dac68..e4b92e44e69 100644 --- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py +++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing.py @@ -292,11 +292,16 @@ class RandomCrop(Layer): @keras_export('keras.layers.experimental.preprocessing.Rescaling') class Rescaling(Layer): - """Multiply inputs by `scale`. + """Multiply inputs by `scale` and adds `offset`. - For instance, to rescale an input in the `[0, 255]` range + For instance: + + 1. To rescale an input in the `[0, 255]` range to be in the `[0, 1]` range, you would pass `scale=1./255`. + 2. To rescale an input in the `[0, 255]` range to be in the `[-1, 1]` range, + you would pass `scale=1./127.5, offset=-1`. + The rescaling is applied both during training and inference. Input shape: @@ -307,16 +312,20 @@ class Rescaling(Layer): Arguments: scale: Float, the scale to apply to the inputs. + offset: Float, the offset to apply to the inputs. name: A string, the name of the layer. """ - def __init__(self, scale, name=None, **kwargs): + def __init__(self, scale, offset=0., name=None, **kwargs): self.scale = scale + self.offset = offset super(Rescaling, self).__init__(name=name, **kwargs) def call(self, inputs): dtype = self._compute_dtype - return math_ops.cast(inputs, dtype) * math_ops.cast(self.scale, dtype) + scale = math_ops.cast(self.scale, dtype) + offset = math_ops.cast(self.offset, dtype) + return math_ops.cast(inputs, dtype) * scale + offset def compute_output_shape(self, input_shape): return input_shape @@ -324,6 +333,7 @@ class Rescaling(Layer): def get_config(self): config = { 'scale': self.scale, + 'offset': self.offset, } base_config = super(Rescaling, self).get_config() return dict(list(base_config.items()) + list(config.items())) diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py index 38d2d25916a..14720d3541d 100644 --- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py +++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py @@ -306,7 +306,7 @@ class RescalingTest(keras_parameterized.TestCase): @keras_parameterized.run_all_keras_modes(always_skip_v1=True) def test_rescaling_base(self): - kwargs = {'scale': 0.004} + kwargs = {'scale': 1./127.5, 'offset': -1.} testing_utils.layer_test( image_preprocessing.Rescaling, kwargs=kwargs, @@ -315,18 +315,18 @@ class RescalingTest(keras_parameterized.TestCase): @tf_test_util.run_v2_only def test_rescaling_correctness_float(self): - layer = image_preprocessing.Rescaling(0.004) + layer = image_preprocessing.Rescaling(scale=1./127.5, offset=-1.) inputs = random_ops.random_uniform((2, 4, 5, 3)) outputs = layer(inputs) - self.assertAllClose(outputs.numpy(), inputs.numpy() * 0.004) + self.assertAllClose(outputs.numpy(), inputs.numpy() * (1./127.5) - 1) @tf_test_util.run_v2_only def test_rescaling_correctness_int(self): - layer = image_preprocessing.Rescaling(0.004) + layer = image_preprocessing.Rescaling(scale=1./127.5, offset=-1) inputs = random_ops.random_uniform((2, 4, 5, 3), 0, 100, dtype='int32') outputs = layer(inputs) self.assertEqual(outputs.dtype.name, 'float32') - self.assertAllClose(outputs.numpy(), inputs.numpy() * 0.004) + self.assertAllClose(outputs.numpy(), inputs.numpy() * (1./127.5) - 1) def test_config_with_custom_name(self): layer = image_preprocessing.Rescaling(0.5, name='rescaling') diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup.py b/tensorflow/python/keras/layers/preprocessing/index_lookup.py index ba9b0d740e1..7d11feae341 100644 --- a/tensorflow/python/keras/layers/preprocessing/index_lookup.py +++ b/tensorflow/python/keras/layers/preprocessing/index_lookup.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.keras.engine import base_preprocessing_layer from tensorflow.python.keras.layers.preprocessing import table_utils from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import math_ops from tensorflow.python.util import compat # The string tokens in the extracted vocabulary @@ -75,8 +76,9 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): only used when performing an inverse lookup. vocabulary: An optional list of vocabulary terms. If the list contains the same token multiple times, an error will be thrown. + invert: If true, this layer will map indices to vocabulary items instead + of mapping vocabulary items to indices. """ - # TODO(momernick): Add an examples section to the docstring. def __init__(self, max_tokens, @@ -84,17 +86,22 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): mask_token, oov_token, vocabulary=None, + invert=False, **kwargs): # If max_tokens is set, the value must be greater than 1 - otherwise we # are creating a 0-element vocab, which doesn't make sense. if max_tokens is not None and max_tokens <= 1: - raise ValueError("If set, max_tokens must be greater than 1.") + raise ValueError("If set, `max_tokens` must be greater than 1.") if num_oov_indices < 0: - raise ValueError("num_oov_indices must be greater than 0. You passed %s" % - num_oov_indices) + raise ValueError("`num_oov_indices` must be greater than 0. You passed " + "%s" % num_oov_indices) + if invert and num_oov_indices != 1: + raise ValueError("`num_oov_tokens` must be 1 when `invert` is True.") + + self.invert = invert self.max_tokens = max_tokens self.num_oov_indices = num_oov_indices self.oov_token = oov_token @@ -111,16 +118,32 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): else: self._oov_value = -1 + if max_tokens is not None: + num_mask_tokens = (0 if mask_token is None else 1) + vocab_size = max_tokens - (num_oov_indices + num_mask_tokens) + else: + vocab_size = None + super(IndexLookup, self).__init__( - combiner=_IndexLookupCombiner(self.max_tokens, self.mask_token), - **kwargs) + combiner=_IndexLookupCombiner(vocab_size, self.mask_token), **kwargs) self._output_dtype = dtypes.int64 + # We need to save the key dtype so that we know if we're expecting int64 + # keys. If we are, we will cast int32 inputs to int64 as well. + if invert: + self._key_dtype = self._output_dtype + value_dtype = self.dtype + oov_value = self.oov_token + else: + self._key_dtype = self.dtype + value_dtype = self._output_dtype + oov_value = self._oov_value + self._table = lookup_ops.MutableHashTable( - key_dtype=self.dtype, - value_dtype=self._output_dtype, - default_value=self._oov_value, + key_dtype=self._key_dtype, + value_dtype=value_dtype, + default_value=oov_value, name=(self._name + "_index_table")) tracked_table = self._add_trackable(self._table, trainable=False) # This is a workaround for summary() on this layer. Because the table is @@ -149,7 +172,7 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): def compute_output_signature(self, input_spec): output_shape = self.compute_output_shape(input_spec.shape.as_list()) - output_dtype = dtypes.int64 + output_dtype = self.dtype if self.invert else self._output_dtype return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype) def adapt(self, data, reset_state=True): @@ -176,13 +199,18 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): keys, values = self._table_handler.data() # This is required because the MutableHashTable doesn't preserve insertion # order, but we rely on the order of the array to assign indices. - return [x for _, x in sorted(zip(values, keys))] + if self.invert: + # If we are inverting, the vocabulary is in the values instead of keys. + return [x for _, x in sorted(zip(keys, values))] + else: + return [x for _, x in sorted(zip(values, keys))] def vocab_size(self): return self._table_handler.vocab_size() def get_config(self): config = { + "invert": self.invert, "max_tokens": self.max_tokens, "num_oov_indices": self.num_oov_indices, "oov_token": self.oov_token, @@ -198,33 +226,15 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): # abstraction for ease of saving!) we return 0. return 0 - def set_vocabulary(self, vocab): - """Sets vocabulary (and optionally document frequency) data for this layer. - - This method sets the vocabulary for this layer directly, instead of - analyzing a dataset through 'adapt'. It should be used whenever the vocab - information is already known. If vocabulary data is already present in the - layer, this method will either replace it - - Arguments: - vocab: An array of string tokens. - - Raises: - ValueError: If there are too many inputs, the inputs do not match, or - input data is missing. - """ - + def _set_forward_vocabulary(self, vocab): + """Sets vocabulary data for this layer when inverse is False.""" table_utils.validate_vocabulary_is_unique(vocab) should_have_mask = self.mask_token is not None - if should_have_mask: - has_mask = vocab[0] == self.mask_token - oov_start = 1 - else: - has_mask = False - oov_start = 0 + has_mask = vocab[0] == self.mask_token + oov_start = 1 if should_have_mask else 0 - should_have_oov = self.num_oov_indices > 0 + should_have_oov = (self.num_oov_indices > 0) and not self.invert if should_have_oov: oov_end = oov_start + self.num_oov_indices expected_oov = [self.oov_token] * self.num_oov_indices @@ -293,12 +303,73 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer): special_token_values = np.arange(num_special_tokens, dtype=np.int64) self._table_handler.insert(special_tokens, special_token_values) + def _set_inverse_vocabulary(self, vocab): + """Sets vocabulary data for this layer when inverse is True.""" + table_utils.validate_vocabulary_is_unique(vocab) + + should_have_mask = self.mask_token is not None + has_mask = vocab[0] == self.mask_token + + insert_special_tokens = should_have_mask and not has_mask + special_tokens = [] if self.mask_token is None else [self.mask_token] + + num_special_tokens = len(special_tokens) + tokens = vocab if insert_special_tokens else vocab[num_special_tokens:] + if self.mask_token in tokens: + raise ValueError("Reserved mask token %s was found in the passed " + "vocabulary at index %s. Please either remove the " + "reserved token from the vocabulary or change the " + "mask token for this layer." % + (self.mask_token, tokens.index(self.mask_token))) + + if insert_special_tokens: + total_vocab_size = len(vocab) + num_special_tokens + else: + total_vocab_size = len(vocab) + if self.max_tokens is not None and total_vocab_size > self.max_tokens: + raise ValueError( + "Attempted to set a vocabulary larger than the maximum vocab size. " + "Passed vocab size is %s, max vocab size is %s." % + (total_vocab_size, self.max_tokens)) + + start_index = num_special_tokens if insert_special_tokens else 0 + values = np.arange(start_index, len(vocab) + start_index, dtype=np.int64) + + self._table_handler.clear() + self._table_handler.insert(values, vocab) + + if insert_special_tokens and num_special_tokens > 0: + special_token_values = np.arange(num_special_tokens, dtype=np.int64) + self._table_handler.insert(special_token_values, special_tokens) + + def set_vocabulary(self, vocab): + """Sets vocabulary data for this layer with inverse=False. + + This method sets the vocabulary for this layer directly, instead of + analyzing a dataset through 'adapt'. It should be used whenever the vocab + information is already known. If vocabulary data is already present in the + layer, this method will either replace it + + Arguments: + vocab: An array of string tokens. + + Raises: + ValueError: If there are too many inputs, the inputs do not match, or + input data is missing. + """ + if self.invert: + self._set_inverse_vocabulary(vocab) + else: + self._set_forward_vocabulary(vocab) + def _set_state_variables(self, updates): if not self.built: raise RuntimeError("_set_state_variables() must be called after build().") self.set_vocabulary(updates[_VOCAB_NAME]) def call(self, inputs): + if self._key_dtype == dtypes.int64 and inputs.dtype == dtypes.int32: + inputs = math_ops.cast(inputs, dtypes.int64) return self._table_handler.lookup(inputs) def _use_v1_apis(self): diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py index a95834233b3..a61cef6121f 100644 --- a/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py +++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py @@ -77,6 +77,31 @@ def _get_end_to_end_test_cases(): "input_dtype": dtypes.string }, + { + "testcase_name": + "test_inverse_strings_soft_vocab_cap", + # Create an array where 'earth' is the most frequent term, followed by + # 'wind', then 'and', then 'fire'. This ensures that the vocab + # accumulator is sorting by frequency. + "vocab_data": + np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"], + ["wind"], ["wind"], ["wind"], ["and"], ["and"]]), + "input_data": + np.array([[1], [2], [3], [4], [4], [3], [1], [5]]), + "kwargs": { + "max_tokens": None, + "num_oov_indices": 1, + "mask_token": "", + "oov_token": "[OOV]", + "dtype": dtypes.string, + "invert": True + }, + "expected_output": + np.array([[b"earth"], [b"wind"], [b"and"], [b"fire"], [b"fire"], + [b"and"], [b"earth"], [b"[OOV]"]]), + "input_dtype": + dtypes.int64 + }, { "testcase_name": "test_ints_soft_vocab_cap", @@ -101,6 +126,78 @@ def _get_end_to_end_test_cases(): "input_dtype": dtypes.int64 }, + { + "testcase_name": + "test_strings_hard_vocab_cap", + # Create an array where 'earth' is the most frequent term, followed by + # 'wind', then 'and', then 'fire'. This ensures that the vocab + # accumulator is sorting by frequency. + "vocab_data": + np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"], + ["wind"], ["wind"], ["wind"], ["and"], ["and"]]), + "input_data": + np.array([["earth"], ["wind"], ["and"], ["fire"], ["fire"], + ["and"], ["earth"], ["michigan"]]), + "kwargs": { + "max_tokens": 5, + "num_oov_indices": 1, + "mask_token": "", + "oov_token": "[OOV]", + "dtype": dtypes.string, + }, + "expected_output": [[2], [3], [4], [1], [1], [4], [2], [1]], + "input_dtype": + dtypes.string + }, + { + "testcase_name": + "test_inverse_strings_hard_vocab_cap", + # Create an array where 'earth' is the most frequent term, followed by + # 'wind', then 'and', then 'fire'. This ensures that the vocab + # accumulator is sorting by frequency. + "vocab_data": + np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"], + ["wind"], ["wind"], ["wind"], ["and"], ["and"]]), + "input_data": + np.array([[1], [2], [3], [4], [4], [3], [1], [5]]), + "kwargs": { + "max_tokens": 5, + "num_oov_indices": 1, + "mask_token": "", + "oov_token": "[OOV]", + "dtype": dtypes.string, + "invert": True + }, + "expected_output": + np.array([[b"earth"], [b"wind"], [b"and"], [b"[OOV]"], [b"[OOV]"], + [b"and"], [b"earth"], [b"[OOV]"]]), + "input_dtype": + dtypes.int64 + }, + { + "testcase_name": + "test_ints_hard_vocab_cap", + # Create an array where 1138 is the most frequent term, followed by + # 1729, then 725, then 42. This ensures that the vocab accumulator + # is sorting by frequency. + "vocab_data": + np.array([[42], [1138], [1138], [1138], [1138], [1729], [1729], + [1729], [725], [725]], + dtype=np.int64), + "input_data": + np.array([[1138], [1729], [725], [42], [42], [725], [1138], [4]], + dtype=np.int64), + "kwargs": { + "max_tokens": 5, + "num_oov_indices": 1, + "mask_token": 0, + "oov_token": -1, + "dtype": dtypes.int64, + }, + "expected_output": [[2], [3], [4], [1], [1], [4], [2], [1]], + "input_dtype": + dtypes.int64 + }, ) crossed_test_cases = [] @@ -125,7 +222,11 @@ class IndexLookupLayerTest(keras_parameterized.TestCase, use_dataset, expected_output, input_dtype): cls = get_layer_class() - expected_output_dtype = dtypes.int64 + if "invert" in kwargs and kwargs["invert"]: + expected_output_dtype = kwargs["dtype"] + else: + expected_output_dtype = dtypes.int64 + input_shape = input_data.shape if use_dataset: @@ -156,7 +257,10 @@ class IndexLookupLayerTest(keras_parameterized.TestCase, expected_output_dtype=expected_output_dtype, validate_training=False, adapt_data=vocab_data) - self.assertAllClose(expected_output, output_data) + if "invert" in kwargs and kwargs["invert"]: + self.assertAllEqual(expected_output, output_data) + else: + self.assertAllClose(expected_output, output_data) @keras_parameterized.run_all_keras_modes @@ -254,6 +358,25 @@ class CategoricalEncodingInputTest( output_dataset = model.predict(input_array) self.assertAllEqual(expected_output, output_dataset) + def test_int32_input_with_int64_keys(self): + vocab_data = np.array([10, 11, 12, 13], dtype=np.int64) + input_array = ragged_factory_ops.constant([[10, 11, 13], [13, 12, 10, 42]], + dtype=np.int32) + expected_output = [[2, 3, 5], [5, 4, 2, 1]] + + input_data = keras.Input(shape=(None,), dtype=dtypes.int32, ragged=True) + layer = get_layer_class()( + max_tokens=None, + dtype=dtypes.int64, + num_oov_indices=1, + mask_token=0, + oov_token=-1) + layer.set_vocabulary(vocab_data) + int_data = layer(input_data) + model = keras.Model(inputs=input_data, outputs=int_data) + output_dataset = model.predict(input_array) + self.assertAllEqual(expected_output, output_dataset) + @keras_parameterized.run_all_keras_modes class CategoricalEncodingMultiOOVTest( @@ -748,6 +871,118 @@ class IndexLookupVocabularyTest(keras_parameterized.TestCase, layer.set_vocabulary(vocab_data) +@keras_parameterized.run_all_keras_modes +class IndexLookupInverseVocabularyTest( + keras_parameterized.TestCase, + preprocessing_test_utils.PreprocessingLayerTest): + + def test_int_output_explicit_vocab(self): + vocab_data = ["[OOV]", "earth", "wind", "and", "fire"] + input_array = np.array([[2, 3, 4, 5], [5, 4, 2, 1]]) + expected_output = np.array([["earth", "wind", "and", "fire"], + ["fire", "and", "earth", "[OOV]"]]) + + input_data = keras.Input(shape=(None,), dtype=dtypes.int64) + layer = get_layer_class()( + vocabulary=vocab_data, + max_tokens=None, + num_oov_indices=1, + mask_token="", + oov_token="[OOV]", + dtype=dtypes.string, + invert=True) + int_data = layer(input_data) + model = keras.Model(inputs=input_data, outputs=int_data) + output_dataset = model.predict(input_array) + self.assertAllEqual(expected_output, output_dataset) + + def test_vocab_with_max_cap(self): + vocab_data = ["", "[OOV]", "wind", "and", "fire"] + layer = get_layer_class()( + max_tokens=5, + num_oov_indices=1, + mask_token="", + oov_token="[OOV]", + dtype=dtypes.string, + invert=True) + layer.set_vocabulary(vocab_data) + returned_vocab = layer.get_vocabulary() + self.assertAllEqual(vocab_data, returned_vocab) + + def test_int_vocab_with_max_cap(self): + vocab_data = [0, -1, 42, 1276, 1138] + layer = get_layer_class()( + max_tokens=5, + num_oov_indices=1, + mask_token=0, + oov_token=-1, + dtype=dtypes.int64, + invert=True) + layer.set_vocabulary(vocab_data) + returned_vocab = layer.get_vocabulary() + self.assertAllEqual(vocab_data, returned_vocab) + + def test_non_unique_vocab_fails(self): + vocab_data = ["earth", "wind", "and", "fire", "fire"] + with self.assertRaisesRegex(ValueError, ".*repeated term.*fire.*"): + _ = get_layer_class()( + vocabulary=vocab_data, + max_tokens=None, + num_oov_indices=1, + mask_token="", + oov_token="[OOV]", + dtype=dtypes.string, + invert=True) + + def test_vocab_with_repeated_element_fails(self): + vocab_data = ["earth", "earth", "wind", "and", "fire"] + layer = get_layer_class()( + max_tokens=None, + num_oov_indices=1, + mask_token="", + oov_token="[OOV]", + dtype=dtypes.string, + invert=True) + with self.assertRaisesRegex(ValueError, ".*repeated term.*earth.*"): + layer.set_vocabulary(vocab_data) + + def test_vocab_with_reserved_mask_element_fails(self): + vocab_data = ["earth", "mask_token", "wind", "and", "fire"] + layer = get_layer_class()( + max_tokens=None, + num_oov_indices=1, + mask_token="mask_token", + oov_token="[OOV]", + dtype=dtypes.string, + invert=True) + with self.assertRaisesRegex(ValueError, ".*Reserved mask.*"): + layer.set_vocabulary(vocab_data) + + def test_non_unique_int_vocab_fails(self): + vocab_data = [12, 13, 14, 15, 15] + with self.assertRaisesRegex(ValueError, ".*repeated term.*15.*"): + _ = get_layer_class()( + vocabulary=vocab_data, + max_tokens=None, + num_oov_indices=1, + mask_token=0, + oov_token=-1, + dtype=dtypes.int64, + invert=True) + + def test_int_vocab_with_repeated_element_fails(self): + vocab_data = [11, 11, 34, 23, 124] + layer = get_layer_class()( + max_tokens=None, + num_oov_indices=1, + mask_token=0, + oov_token=-1, + dtype=dtypes.int64, + invert=True) + with self.assertRaisesRegex(ValueError, ".*repeated term.*11.*"): + layer.set_vocabulary(vocab_data) + + @keras_parameterized.run_all_keras_modes(always_skip_eager=True) class IndexLookupSaveableTest(keras_parameterized.TestCase, preprocessing_test_utils.PreprocessingLayerTest): diff --git a/tensorflow/python/keras/layers/preprocessing/integer_lookup.py b/tensorflow/python/keras/layers/preprocessing/integer_lookup.py index 671c02573db..6f497983408 100644 --- a/tensorflow/python/keras/layers/preprocessing/integer_lookup.py +++ b/tensorflow/python/keras/layers/preprocessing/integer_lookup.py @@ -57,6 +57,86 @@ class IntegerLookup(index_lookup.IndexLookup): a vocabulary to load into this layer. The file should contain one value per line. If the list or file contains the same token multiple times, an error will be thrown. + invert: If true, this layer will map indices to vocabulary items instead + of mapping vocabulary items to indices. + + Examples: + + Creating a lookup layer with a known vocabulary + + This example creates a lookup layer with a pre-existing vocabulary. + + >>> vocab = [12, 36, 1138, 42] + >>> data = tf.constant([[12, 1138, 42], [42, 1000, 36]]) + >>> layer = IntegerLookup(vocabulary=vocab) + >>> layer(data) + <tf.Tensor: shape=(2, 3), dtype=int64, numpy= + array([[2, 4, 5], + [5, 1, 3]])> + + + Creating a lookup layer with an adapted vocabulary + + This example creates a lookup layer and generates the vocabulary by analyzing + the dataset. + + >>> data = tf.constant([[12, 1138, 42], [42, 1000, 36]]) + >>> layer = IntegerLookup() + >>> layer.adapt(data) + >>> layer.get_vocabulary() + [0, -1, 42, 1138, 1000, 36, 12] + + Note how the mask value 0 and the OOV value -1 have been added to the + vocabulary. The remaining tokens are sorted by frequency (1138, which has + 2 occurrences, is first) then by inverse sort order. + + >>> data = tf.constant([[12, 1138, 42], [42, 1000, 36]]) + >>> layer = IntegerLookup() + >>> layer.adapt(data) + >>> layer(data) + <tf.Tensor: shape=(2, 3), dtype=int64, numpy= + array([[6, 3, 2], + [2, 4, 5]])> + + + Inverse lookup + + This example demonstrates how to map indices to values using this layer. (You + can also use adapt() with inverse=True, but for simplicity we'll pass the + vocab in this example.) + + >>> vocab = [12, 36, 1138, 42] + >>> data = tf.constant([[1, 3, 4], [4, 5, 2]]) + >>> layer = IntegerLookup(vocabulary=vocab, invert=True) + >>> layer(data) + <tf.Tensor: shape=(2, 3), dtype=int64, numpy= + array([[ 12, 1138, 42], + [ 42, -1, 36]])> + + Note that the integer 5, which is out of the vocabulary space, returns an OOV + token. + + + Forward and inverse lookup pairs + + This example demonstrates how to use the vocabulary of a standard lookup + layer to create an inverse lookup layer. + + >>> vocab = [12, 36, 1138, 42] + >>> data = tf.constant([[12, 1138, 42], [42, 1000, 36]]) + >>> layer = IntegerLookup(vocabulary=vocab) + >>> i_layer = IntegerLookup(vocabulary=layer.get_vocabulary(), invert=True) + >>> int_data = layer(data) + >>> i_layer(int_data) + <tf.Tensor: shape=(2, 3), dtype=int64, numpy= + array([[ 12, 1138, 42], + [ 42, -1, 36]])> + + In this example, the input value 1000 resulted in an output of -1, since + 1000 was not in the vocabulary - it got represented as an OOV, and all OOV + values are returned as -1 in the inverse layer. Also, note that for the + inverse to work, you must have already set the forward layer vocabulary + either directly or via fit() before calling get_vocabulary(). """ def __init__(self, @@ -65,6 +145,7 @@ class IntegerLookup(index_lookup.IndexLookup): mask_value=0, oov_value=-1, vocabulary=None, + invert=False, **kwargs): allowed_dtypes = [dtypes.int64] @@ -95,6 +176,7 @@ class IntegerLookup(index_lookup.IndexLookup): mask_token=mask_value, oov_token=oov_value, vocabulary=vocabulary, + invert=invert, **kwargs) def get_config(self): diff --git a/tensorflow/python/keras/layers/preprocessing/integer_lookup_test.py b/tensorflow/python/keras/layers/preprocessing/integer_lookup_test.py index 515a1ca6667..0b71c6aaecc 100644 --- a/tensorflow/python/keras/layers/preprocessing/integer_lookup_test.py +++ b/tensorflow/python/keras/layers/preprocessing/integer_lookup_test.py @@ -347,6 +347,36 @@ class IntegerLookupOutputTest(keras_parameterized.TestCase, output_dataset = model.predict(input_array) self.assertAllEqual(expected_output, output_dataset) + def test_inverse_output(self): + vocab_data = [0, -1, 42, 1138, 725, 1729] + input_array = np.array([[2, 3, 4, 5], [5, 4, 2, 1]]) + expected_output = np.array([[42, 1138, 725, 1729], [1729, 725, 42, -1]]) + + input_data = keras.Input(shape=(None,), dtype=dtypes.int64) + layer = get_layer_class()(invert=True) + layer.set_vocabulary(vocab_data) + int_data = layer(input_data) + model = keras.Model(inputs=input_data, outputs=int_data) + output_dataset = model.predict(input_array) + self.assertAllEqual(expected_output, output_dataset) + + def test_forward_backward_output(self): + vocab_data = [42, 1138, 725, 1729] + input_array = np.array([[42, 1138, 725, 1729], [1729, 725, 42, 203]]) + expected_output = np.array([[42, 1138, 725, 1729], [1729, 725, 42, -1]]) + + input_data = keras.Input(shape=(None,), dtype=dtypes.int64) + layer = get_layer_class()() + inverse_layer = get_layer_class()() + layer.set_vocabulary(vocab_data) + inverse_layer = get_layer_class()( + vocabulary=layer.get_vocabulary(), invert=True) + int_data = layer(input_data) + inverse_data = inverse_layer(int_data) + model = keras.Model(inputs=input_data, outputs=inverse_data) + output_dataset = model.predict(input_array) + self.assertAllEqual(expected_output, output_dataset) + @keras_parameterized.run_all_keras_modes class IntegerLookupVocabularyTest( diff --git a/tensorflow/python/keras/layers/preprocessing/normalization.py b/tensorflow/python/keras/layers/preprocessing/normalization.py index cf9600a63ab..be04e9947b8 100644 --- a/tensorflow/python/keras/layers/preprocessing/normalization.py +++ b/tensorflow/python/keras/layers/preprocessing/normalization.py @@ -55,6 +55,21 @@ class Normalization(CombinerPreprocessingLayer): in the specified axis. If set to 'None', the layer will perform scalar normalization (diving the input by a single scalar value). 0 (the batch axis) is not allowed. + + + Examples: + + Calculate the mean and variance by analyzing the dataset in `adapt`. + + >>> adapt_data = np.array([[1.], [2.], [3.], [4.], [5.]], dtype=np.float32) + >>> input_data = np.array([[1.], [2.], [3.]], np.float32) + >>> layer = Normalization() + >>> layer.adapt(adapt_data) + >>> layer(input_data) + <tf.Tensor: shape=(3, 1), dtype=float32, numpy= + array([[-1.4142135 ], + [-0.70710677], + [ 0. ]], dtype=float32)> """ def __init__(self, axis=-1, dtype=None, **kwargs): @@ -107,6 +122,10 @@ class Normalization(CombinerPreprocessingLayer): super(Normalization, self).build(input_shape) def call(self, inputs): + # If the inputs are not floats, cast them to floats. This avoids issues + # with int-float multiplication and division below. + if inputs.dtype != K.floatx(): + inputs = math_ops.cast(inputs, K.floatx()) # We need to reshape the mean and variance data to ensure that Tensorflow # broadcasts the data correctly. mean = array_ops.reshape(self.mean, self._broadcast_shape) diff --git a/tensorflow/python/keras/layers/preprocessing/normalization_test.py b/tensorflow/python/keras/layers/preprocessing/normalization_test.py index 2e6f4990cc5..e5a429751f4 100644 --- a/tensorflow/python/keras/layers/preprocessing/normalization_test.py +++ b/tensorflow/python/keras/layers/preprocessing/normalization_test.py @@ -48,6 +48,12 @@ def _get_layer_computation_test_cases(): "test_data": np.array([[1.], [2.], [3.]], np.float32), "expected": np.array([[-1.414214], [-.707107], [0]], np.float32), "testcase_name": "2d_single_element" + }, { + "adapt_data": np.array([[1], [2], [3], [4], [5]], dtype=np.int32), + "axis": -1, + "test_data": np.array([[1], [2], [3]], np.int32), + "expected": np.array([[-1.414214], [-.707107], [0]], np.float32), + "testcase_name": "2d_int_data" }, { "adapt_data": np.array([[1.], [2.], [3.], [4.], [5.]], dtype=np.float32), "axis": None, @@ -140,6 +146,7 @@ class NormalizationTest(keras_parameterized.TestCase, self.validate_accumulator_extract(combiner, data, expected) self.validate_accumulator_extract_and_restore(combiner, data, expected) + @parameterized.named_parameters( { "data": np.array([[1], [2], [3], [4], [5]]), diff --git a/tensorflow/python/keras/layers/preprocessing/string_lookup.py b/tensorflow/python/keras/layers/preprocessing/string_lookup.py index 4032486b5f0..a420de8678a 100644 --- a/tensorflow/python/keras/layers/preprocessing/string_lookup.py +++ b/tensorflow/python/keras/layers/preprocessing/string_lookup.py @@ -58,6 +58,86 @@ class StringLookup(index_lookup.IndexLookup): one token per line. If the list or file contains the same token multiple times, an error will be thrown. encoding: The Python string encoding to use. Defaults to `'utf-8'`. + invert: If true, this layer will map indices to vocabulary items instead + of mapping vocabulary items to indices. + + Examples: + + Creating a lookup layer with a known vocabulary + + This example creates a lookup layer with a pre-existing vocabulary. + + >>> vocab = ["a", "b", "c", "d"] + >>> data = tf.constant([["a", "c", "d"], ["d", "z", "b"]]) + >>> layer = StringLookup(vocabulary=vocab) + >>> layer(data) + <tf.Tensor: shape=(2, 3), dtype=int64, numpy= + array([[2, 4, 5], + [5, 1, 3]])> + + + Creating a lookup layer with an adapted vocabulary + + This example creates a lookup layer and generates the vocabulary by analyzing + the dataset. + + >>> data = tf.constant([["a", "c", "d"], ["d", "z", "b"]]) + >>> layer = StringLookup() + >>> layer.adapt(data) + >>> layer.get_vocabulary() + ['', '[OOV]', 'd', 'z', 'c', 'b', 'a'] + + Note how the mask token '' and the OOV token [OOV] have been added to the + vocabulary. The remaining tokens are sorted by frequency ('d', which has + 2 occurrences, is first) then by inverse sort order. + + >>> data = tf.constant([["a", "c", "d"], ["d", "z", "b"]]) + >>> layer = StringLookup() + >>> layer.adapt(data) + >>> layer(data) + <tf.Tensor: shape=(2, 3), dtype=int64, numpy= + array([[6, 4, 2], + [2, 3, 5]])> + + + Inverse lookup + + This example demonstrates how to map indices to strings using this layer. (You + can also use adapt() with inverse=True, but for simplicity we'll pass the + vocab in this example.) + + >>> vocab = ["a", "b", "c", "d"] + >>> data = tf.constant([[1, 3, 4], [4, 5, 2]]) + >>> layer = StringLookup(vocabulary=vocab, invert=True) + >>> layer(data) + <tf.Tensor: shape=(2, 3), dtype=string, numpy= + array([[b'a', b'c', b'd'], + [b'd', b'[OOV]', b'b']], dtype=object)> + + Note that the integer 5, which is out of the vocabulary space, returns an OOV + token. + + + Forward and inverse lookup pairs + + This example demonstrates how to use the vocabulary of a standard lookup + layer to create an inverse lookup layer. + + >>> vocab = ["a", "b", "c", "d"] + >>> data = tf.constant([["a", "c", "d"], ["d", "z", "b"]]) + >>> layer = StringLookup(vocabulary=vocab) + >>> i_layer = StringLookup(vocabulary=layer.get_vocabulary(), invert=True) + >>> int_data = layer(data) + >>> i_layer(int_data) + <tf.Tensor: shape=(2, 3), dtype=string, numpy= + array([[b'a', b'c', b'd'], + [b'd', b'[OOV]', b'b']], dtype=object)> + + In this example, the input value 'z' resulted in an output of '[OOV]', since + 1000 was not in the vocabulary - it got represented as an OOV, and all OOV + values are returned as '[OOV}' in the inverse layer. Also, note that for the + inverse to work, you must have already set the forward layer vocabulary + either directly or via fit() before calling get_vocabulary(). """ def __init__(self, @@ -67,6 +147,7 @@ class StringLookup(index_lookup.IndexLookup): oov_token="[OOV]", vocabulary=None, encoding="utf-8", + invert=False, **kwargs): allowed_dtypes = [dtypes.string] @@ -89,6 +170,7 @@ class StringLookup(index_lookup.IndexLookup): mask_token=mask_token, oov_token=oov_token, vocabulary=vocabulary, + invert=invert, **kwargs) def get_config(self): diff --git a/tensorflow/python/keras/layers/preprocessing/string_lookup_test.py b/tensorflow/python/keras/layers/preprocessing/string_lookup_test.py index b2a610ac328..0b9081d815c 100644 --- a/tensorflow/python/keras/layers/preprocessing/string_lookup_test.py +++ b/tensorflow/python/keras/layers/preprocessing/string_lookup_test.py @@ -187,6 +187,36 @@ class StringLookupVocabularyTest(keras_parameterized.TestCase, with self.assertRaisesRegex(ValueError, ".*repeated term.*earth.*"): _ = get_layer_class()(vocabulary=vocab_path) + def test_inverse_layer(self): + vocab_data = ["earth", "wind", "and", "fire"] + input_array = np.array([[1, 2, 3, 4], [4, 3, 1, 0]]) + expected_output = np.array([["earth", "wind", "and", "fire"], + ["fire", "and", "earth", ""]]) + + input_data = keras.Input(shape=(None,), dtype=dtypes.int64) + layer = get_layer_class()(vocabulary=vocab_data, invert=True) + int_data = layer(input_data) + model = keras.Model(inputs=input_data, outputs=int_data) + output_dataset = model.predict(input_array) + self.assertAllEqual(expected_output, output_dataset) + + def test_forward_backward_layer(self): + vocab_data = ["earth", "wind", "and", "fire"] + input_array = np.array([["earth", "wind", "and", "fire"], + ["fire", "and", "earth", "michigan"]]) + expected_output = np.array([["earth", "wind", "and", "fire"], + ["fire", "and", "earth", "[OOV]"]]) + + input_data = keras.Input(shape=(None,), dtype=dtypes.string) + layer = get_layer_class()(vocabulary=vocab_data) + invert_layer = get_layer_class()( + vocabulary=layer.get_vocabulary(), invert=True) + int_data = layer(input_data) + out_data = invert_layer(int_data) + model = keras.Model(inputs=input_data, outputs=out_data) + output_dataset = model.predict(input_array) + self.assertAllEqual(expected_output, output_dataset) + @keras_parameterized.run_all_keras_modes(always_skip_eager=True) class StringLookupSaveableTest(keras_parameterized.TestCase, diff --git a/tensorflow/python/keras/layers/preprocessing/table_utils.py b/tensorflow/python/keras/layers/preprocessing/table_utils.py index 05447f6e9ff..cf1bfd741c9 100644 --- a/tensorflow/python/keras/layers/preprocessing/table_utils.py +++ b/tensorflow/python/keras/layers/preprocessing/table_utils.py @@ -21,6 +21,7 @@ import collections import numpy as np from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.keras import backend as K from tensorflow.python.ops import array_ops @@ -60,6 +61,11 @@ class TableHandler(object): raise RuntimeError("Size mismatch between values and key arrays. " "Keys had size %s, values had size %s." % (len(keys), len(values))) + keys = ops.convert_to_tensor(keys, dtype=self.table._key_dtype) # pylint: disable=protected-access + values = ops.convert_to_tensor(values, dtype=self.table._value_dtype) # pylint: disable=protected-access + if values.shape.ndims != 1: + raise ValueError("`values` must be 1-dimensional, got an input with " + " %s dimensions." % values.shape.ndims) self._run(self.table.insert(keys, values)) def _replace_oov_buckets(self, inputs, lookups): @@ -87,6 +93,8 @@ class TableHandler(object): self.table.lookup, inputs) indexed_data = ragged_functional_ops.map_flat_values( self._replace_oov_buckets, inputs, indexed_data) + # table.lookup is not shape-preserving, so we need to set the shape here. + indexed_data._set_shape(inputs.shape) # pylint: disable=protected-access # Composite tensors can pass tensor values through, which will cause # errors if all operations in the TF graph do so. We can break this chain # with an identity here. diff --git a/tensorflow/python/keras/layers/preprocessing/table_utils_test.py b/tensorflow/python/keras/layers/preprocessing/table_utils_test.py index 60a891f6ba8..ab7e80b628c 100644 --- a/tensorflow/python/keras/layers/preprocessing/table_utils_test.py +++ b/tensorflow/python/keras/layers/preprocessing/table_utils_test.py @@ -108,6 +108,15 @@ class CategoricalEncodingInputTest( self.assertAllEqual(expected_output, output_data) + def test_tensor_multi_dim_values_fails(self): + key_data = np.array([0, 1], dtype=np.int64) + value_data = np.array([[11, 12], [21, 22]]) + + table = get_table(dtype=dtypes.int64, oov_tokens=[1, 2]) + + with self.assertRaisesRegexp(ValueError, "must be 1-dimensional"): + table.insert(key_data, value_data) + @keras_parameterized.run_all_keras_modes class CategoricalEncodingMultiOOVTest( diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py index 4156ba50c02..c80f998fe46 100644 --- a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py +++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py @@ -17,10 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -import json -import operator - import numpy as np from tensorflow.python.data.ops import dataset_ops @@ -29,9 +25,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.keras import backend as K -from tensorflow.python.keras.engine.base_preprocessing_layer import Combiner from tensorflow.python.keras.engine.base_preprocessing_layer import CombinerPreprocessingLayer -from tensorflow.python.keras.layers.preprocessing import categorical_encoding +from tensorflow.python.keras.layers.preprocessing import category_encoding from tensorflow.python.keras.layers.preprocessing import string_lookup from tensorflow.python.keras.utils import layer_utils from tensorflow.python.ops import array_ops @@ -41,17 +36,16 @@ from tensorflow.python.ops import string_ops from tensorflow.python.ops.ragged import ragged_functional_ops from tensorflow.python.ops.ragged import ragged_string_ops from tensorflow.python.ops.ragged import ragged_tensor -from tensorflow.python.util import compat from tensorflow.python.util.tf_export import keras_export LOWER_AND_STRIP_PUNCTUATION = "lower_and_strip_punctuation" SPLIT_ON_WHITESPACE = "whitespace" -TFIDF = categorical_encoding.TFIDF -INT = categorical_encoding.INT -BINARY = categorical_encoding.BINARY -COUNT = categorical_encoding.COUNT +TFIDF = category_encoding.TFIDF +INT = category_encoding.INT +BINARY = category_encoding.BINARY +COUNT = category_encoding.COUNT # This is an explicit regex of all the tokens that will be stripped if # LOWER_AND_STRIP_PUNCTUATION is set. If an application requires other @@ -122,7 +116,9 @@ class TextVectorization(CombinerPreprocessingLayer): Attributes: max_tokens: The maximum size of the vocabulary for this layer. If None, - there is no cap on the size of the vocabulary. + there is no cap on the size of the vocabulary. Note that this vocabulary + contains 1 OOV token, so the effective number of tokens is `(max_tokens - + 1 - (1 if output == "int" else 0))`. standardize: Optional specification for standardization to apply to the input text. Values can be None (no standardization), 'lower_and_strip_punctuation' (lowercase and remove punctuation) or a @@ -138,7 +134,8 @@ class TextVectorization(CombinerPreprocessingLayer): output_mode: Optional specification for the output of the layer. Values can be "int", "binary", "count" or "tf-idf", configuring the layer as follows: "int": Outputs integer indices, one integer index per split string - token. + token. When output == "int", 0 is reserved for masked locations; + this reduces the vocab size to max_tokens-2 instead of max_tokens-1 "binary": Outputs a single int array per batch, of either vocab_size or max_tokens size, containing 1s in all elements where the token mapped to that index exists at least once in the batch item. @@ -160,42 +157,43 @@ class TextVectorization(CombinerPreprocessingLayer): Example: This example instantiates a TextVectorization layer that lowercases text, splits on whitespace, strips punctuation, and outputs integer vocab indices. - ``` - max_features = 5000 # Maximum vocab size. - max_len = 40 # Sequence length to pad the outputs to. - # Create the layer. - vectorize_layer = text_vectorization.TextVectorization( - max_tokens=max_features, - output_mode='int', - output_sequence_length=max_len) + >>> text_dataset = tf.data.Dataset.from_tensor_slices(["foo", "bar", "baz"]) + >>> max_features = 5000 # Maximum vocab size. + >>> max_len = 4 # Sequence length to pad the outputs to. + >>> embedding_dims = 2 + >>> + >>> # Create the layer. + >>> vectorize_layer = TextVectorization( + ... max_tokens=max_features, + ... output_mode='int', + ... output_sequence_length=max_len) + >>> + >>> # Now that the vocab layer has been created, call `adapt` on the text-only + >>> # dataset to create the vocabulary. You don't have to batch, but for large + >>> # datasets this means we're not keeping spare copies of the dataset. + >>> vectorize_layer.adapt(text_dataset.batch(64)) + >>> + >>> # Create the model that uses the vectorize text layer + >>> model = tf.keras.models.Sequential() + >>> + >>> # Start by creating an explicit input layer. It needs to have a shape of + >>> # (1,) (because we need to guarantee that there is exactly one string + >>> # input per batch), and the dtype needs to be 'string'. + >>> model.add(tf.keras.Input(shape=(1,), dtype=tf.string)) + >>> + >>> # The first layer in our model is the vectorization layer. After this + >>> # layer, we have a tensor of shape (batch_size, max_len) containing vocab + >>> # indices. + >>> model.add(vectorize_layer) + >>> + >>> # Now, the model can map strings to integers, and you can add an embedding + >>> # layer to map these integers to learned embeddings. + >>> input_data = [["foo qux bar"], ["qux baz"]] + >>> model.predict(input_data) + array([[2, 1, 4, 0], + [1, 3, 0, 0]]) - # Now that the vocab layer has been created, call `adapt` on the text-only - # dataset to create the vocabulary. You don't have to batch, but for large - # datasets this means we're not keeping spare copies of the dataset in memory. - vectorize_layer.adapt(text_dataset.batch(64)) - - # Create the model that uses the vectorize text layer - model = tf.keras.models.Sequential() - - # Start by creating an explicit input layer. It needs to have a shape of (1,) - # (because we need to guarantee that there is exactly one string input per - # batch), and the dtype needs to be 'string'. - model.add(tf.keras.Input(shape=(1,), dtype=tf.string)) - - # The first layer in our model is the vectorization layer. After this layer, - # we have a tensor of shape (batch_size, max_len) containing vocab indices. - model.add(vectorize_layer) - - # Next, we add a layer to map those vocab indices into a space of - # dimensionality 'embedding_dims'. Note that we're using max_features+1 here, - # since there's an OOV token that gets added to the vocabulary in - # vectorize_layer. - model.add(tf.keras.layers.Embedding(max_features+1, embedding_dims)) - - # At this point, you have embedded float data representing your tokens, and - # can add whatever other layers you need to create your model. - ``` """ # TODO(momernick): Add an examples section to the docstring. @@ -274,12 +272,6 @@ class TextVectorization(CombinerPreprocessingLayer): # the OOV value to zero instead of one. self._oov_value = 1 if output_mode == INT else 0 - # We always reduce the max token number by 1 to account for the OOV token - # if it is set. Keras' use of the reserved number 0 for padding tokens, - # if the output is in INT mode, does not really count as a 'token' for - # vocabulary purposes, so we only reduce vocab size by 1 here. - self._max_vocab_size = max_tokens - 1 if max_tokens is not None else None - self._standardize = standardize self._split = split self._ngrams_arg = ngrams @@ -295,8 +287,7 @@ class TextVectorization(CombinerPreprocessingLayer): self._called = False super(TextVectorization, self).__init__( - combiner=_TextVectorizationCombiner( - self._max_vocab_size, compute_idf=output_mode == TFIDF), + combiner=None, **kwargs) mask_token = "" if output_mode in [None, INT] else None @@ -306,31 +297,21 @@ class TextVectorization(CombinerPreprocessingLayer): # If this layer is configured for string or integer output, we do not # create a vectorization layer (as the output is not vectorized). if self._output_mode in [None, INT]: - return - - if max_tokens is not None and self._pad_to_max: - max_elements = max_tokens + self._vectorize_layer = None else: - max_elements = None - self._vectorize_layer = self._get_vectorization_class()( - max_tokens=max_elements, output_mode=self._output_mode) + if max_tokens is not None and self._pad_to_max: + max_elements = max_tokens + else: + max_elements = None + self._vectorize_layer = self._get_vectorization_class()( + max_tokens=max_elements, output_mode=self._output_mode) # These are V1/V2 shim points. There are V1 implementations in the V1 class. def _get_vectorization_class(self): - return categorical_encoding.CategoricalEncoding - - def _get_table_data(self): - keys, values = self._table.export() - return (keys.numpy(), values.numpy()) + return category_encoding.CategoryEncoding def _get_index_lookup_class(self): return string_lookup.StringLookup - - def _to_numpy(self, preprocessed_data): - """Converts preprocessed inputs into numpy arrays.""" - if isinstance(preprocessed_data, np.ndarray): - return preprocessed_data - return np.array(preprocessed_data.to_list()) # End of V1/V2 shim points. def _assert_same_type(self, expected_type, values, value_name): @@ -346,11 +327,16 @@ class TextVectorization(CombinerPreprocessingLayer): return tensor_shape.TensorShape([input_shape[0], self._max_tokens]) if self._output_mode == INT and self._split is None: - return input_shape + if len(input_shape) == 1: + input_shape = tuple(input_shape) + (1,) + return tensor_shape.TensorShape(input_shape) if self._output_mode == INT and self._split is not None: input_shape = list(input_shape) - input_shape[1] = self._output_sequence_length + if len(input_shape) == 1: + input_shape = input_shape + [self._output_sequence_length] + else: + input_shape[1] = self._output_sequence_length return tensor_shape.TensorShape(input_shape) def compute_output_signature(self, input_spec): @@ -366,7 +352,7 @@ class TextVectorization(CombinerPreprocessingLayer): Arguments: data: The data to train on. It can be passed either as a tf.data Dataset, - or as a numpy array. + as a NumPy array, a string tensor, or as a list of texts. reset_state: Optional argument specifying whether to clear the state of the layer at the start of the call to `adapt`. This must be True for this layer, which does not support repeated calls to `adapt`. @@ -377,26 +363,39 @@ class TextVectorization(CombinerPreprocessingLayer): # Build the layer explicitly with the original data shape instead of relying # on an implicit call to `build` in the base layer's `adapt`, since # preprocessing changes the input shape. - if isinstance(data, np.ndarray): - if data.ndim == 1: - data = np.expand_dims(data, axis=-1) + if isinstance(data, (list, tuple, np.ndarray)): + data = ops.convert_to_tensor(data) + + if isinstance(data, ops.Tensor): + if data.shape.rank == 1: + data = array_ops.expand_dims(data, axis=-1) self.build(data.shape) - preprocessed_inputs = self._to_numpy(self._preprocess(data)) + preprocessed_inputs = self._preprocess(data) elif isinstance(data, dataset_ops.DatasetV2): # TODO(momernick): Replace this with a more V2-friendly API. shape = dataset_ops.get_legacy_output_shapes(data) if not isinstance(shape, tensor_shape.TensorShape): raise ValueError("The dataset passed to 'adapt' must contain a single " "tensor value.") + if shape.rank == 0: + data = data.map(lambda tensor: array_ops.expand_dims(tensor, 0)) + shape = dataset_ops.get_legacy_output_shapes(data) if shape.rank == 1: data = data.map(lambda tensor: array_ops.expand_dims(tensor, -1)) self.build(dataset_ops.get_legacy_output_shapes(data)) preprocessed_inputs = data.map(self._preprocess) else: raise ValueError( - "adapt() requires a Dataset or a Numpy array as input, got {}".format( + "adapt() requires a Dataset or an array as input, got {}".format( type(data))) - super(TextVectorization, self).adapt(preprocessed_inputs, reset_state) + + self._index_lookup_layer.adapt(preprocessed_inputs) + if self._vectorize_layer: + if isinstance(data, ops.Tensor): + integer_data = self._index_lookup_layer(preprocessed_inputs) + else: + integer_data = preprocessed_inputs.map(self._index_lookup_layer) + self._vectorize_layer.adapt(integer_data) def get_vocabulary(self): return self._index_lookup_layer.get_vocabulary() @@ -492,11 +491,12 @@ class TextVectorization(CombinerPreprocessingLayer): # in None for undefined shape axes. If using 'and !=', this causes the # expression to evaluate to False instead of True if the shape is undefined; # the expression needs to evaluate to True in that case. - if self._split is not None and not input_shape[1] == 1: # pylint: disable=g-comparison-negation - raise RuntimeError( - "When using TextVectorization to tokenize strings, the first " - "dimension of the input array must be 1, got shape " - "{}".format(input_shape)) + if self._split is not None: + if input_shape.ndims > 1 and not input_shape[-1] == 1: # pylint: disable=g-comparison-negation + raise RuntimeError( + "When using TextVectorization to tokenize strings, the innermost " + "dimension of the input array must be 1, got shape " + "{}".format(input_shape)) super(TextVectorization, self).build(input_shape) @@ -538,7 +538,8 @@ class TextVectorization(CombinerPreprocessingLayer): # If we are splitting, we validate that the 1st axis is of dimension 1 and # so can be squeezed out. We do this here instead of after splitting for # performance reasons - it's more expensive to squeeze a ragged tensor. - inputs = array_ops.squeeze(inputs, axis=1) + if inputs.shape.ndims > 1: + inputs = array_ops.squeeze(inputs, axis=-1) if self._split == SPLIT_ON_WHITESPACE: # This treats multiple whitespaces as one whitespace, and strips leading # and trailing whitespace. @@ -561,8 +562,8 @@ class TextVectorization(CombinerPreprocessingLayer): return inputs def call(self, inputs): - if inputs.shape.rank == 1: - inputs = array_ops.expand_dims(inputs, axis=-1) + if isinstance(inputs, (list, tuple, np.ndarray)): + inputs = ops.convert_to_tensor(inputs) self._called = True inputs = self._preprocess(inputs) @@ -570,9 +571,7 @@ class TextVectorization(CombinerPreprocessingLayer): # If we're not doing any output processing, return right away. if self._output_mode is None: return inputs - indexed_data = self._index_lookup_layer(inputs) - if self._output_mode == INT: # Once we have the dense tensor, we can return it if we weren't given a # fixed output sequence length. If we were, though, we have to dynamically @@ -585,7 +584,6 @@ class TextVectorization(CombinerPreprocessingLayer): dense_data = indexed_data if self._output_sequence_length is None: - dense_data.set_shape(tensor_shape.TensorShape((None, None))) return dense_data else: sequence_len = K.shape(dense_data)[1] @@ -596,198 +594,11 @@ class TextVectorization(CombinerPreprocessingLayer): sequence_len < self._output_sequence_length, true_fn=pad_fn, false_fn=slice_fn) - output_tensor.set_shape( - tensor_shape.TensorShape((None, self._output_sequence_length))) + output_shape = output_tensor.shape.as_list() + output_shape[-1] = self._output_sequence_length + output_tensor.set_shape(tensor_shape.TensorShape(output_shape)) return output_tensor # If we're not returning integers here, we rely on the vectorization layer # to create the output. return self._vectorize_layer(indexed_data) - - -class _TextVectorizationAccumulator( - collections.namedtuple("_TextVectorizationAccumulator", - ["count_dict", "per_doc_count_dict", "metadata"])): - pass - - -# A note on this combiner: This contains functionality that will be extracted -# into the Vectorization and IndexLookup combiner objects. At that point, -# TextVectorization can become a PreprocessingStage instead of a Layer and -# this combiner can be retired. Until then, we leave this as is instead of -# attempting a refactor of what will soon be deleted. -class _TextVectorizationCombiner(Combiner): - """Combiner for the TextVectorization preprocessing layer. - - This class encapsulates the logic for computing a vocabulary based on the - frequency of each token. - - Attributes: - vocab_size: (Optional) If set, only the top `vocab_size` tokens (based on - frequency across the dataset) are retained in the vocabulary. If None, or - set to a value greater than the total number of distinct tokens in the - dataset, all tokens are retained. - compute_idf: (Optional) If set, the inverse document frequency will be - computed for each value. - """ - - def __init__(self, vocab_size=None, compute_idf=False): - self._vocab_size = vocab_size - self._compute_idf = compute_idf - self._input_dtype = dtypes.string - - def compute(self, values, accumulator=None): - """Compute a step in this computation, returning a new accumulator.""" - if dtypes.as_dtype(self._input_dtype) != dtypes.as_dtype(values.dtype): - raise RuntimeError("Expected input type %s, got %s" % - (self._input_dtype, values.dtype)) - if ragged_tensor.is_ragged(values): - values = values.to_list() - if isinstance(values, ops.EagerTensor): - values = values.numpy() - if isinstance(values, np.ndarray): - values = values.tolist() - - if accumulator is None: - accumulator = self._create_accumulator() - - # If we are being passed raw strings or bytestrings, we need to wrap them - # in an array so we don't accidentally iterate over the bytes instead of - # treating the string as one object. - if isinstance(values, (str, bytes)): - values = [values] - - # TODO(momernick): Benchmark improvements to this algorithm. - for document in values: - current_doc_id = accumulator.metadata[0] - for token in document: - accumulator.count_dict[token] += 1 - if self._compute_idf: - doc_count = accumulator.per_doc_count_dict[token] - if doc_count["last_doc_id"] != current_doc_id: - doc_count["count"] += 1 - doc_count["last_doc_id"] = current_doc_id - accumulator.metadata[0] += 1 - - return accumulator - - def merge(self, accumulators): - """Merge several accumulators to a single accumulator.""" - if not accumulators: - return accumulators - - base_accumulator = accumulators[0] - - for accumulator in accumulators[1:]: - base_accumulator.metadata[0] += accumulator.metadata[0] - for token, value in accumulator.count_dict.items(): - base_accumulator.count_dict[token] += value - if self._compute_idf: - for token, value in accumulator.per_doc_count_dict.items(): - # Any newly created token counts in 'base_accumulator''s - # per_doc_count_dict will have a last_doc_id of -1. This is always - # less than the next doc id (which are strictly positive), so any - # future occurrences are guaranteed to be counted. - base_accumulator.per_doc_count_dict[token]["count"] += value["count"] - - return base_accumulator - - def _inverse_document_frequency(self, document_counts, num_documents): - """Compute the inverse-document-frequency (IDF) component of TFIDF. - - Uses the default weighting scheme described in - https://en.wikipedia.org/wiki/Tf%E2%80%93idf. - - Args: - document_counts: An array of the # of documents each token appears in. - num_documents: An int representing the total number of documents - - Returns: - An array of "inverse document frequency" weights. - """ - return np.log(1 + num_documents / (1 + np.array(document_counts))) - - def extract(self, accumulator): - """Convert an accumulator into a dict of output values. - - Args: - accumulator: An accumulator aggregating over the full dataset. - - Returns: - A dict of: - "vocab": A list of the retained items in the vocabulary. - "idf": The inverse-document-frequency for each item in vocab. - idf[vocab_idx] is the IDF value for the corresponding vocab item. - "oov_idf": The inverse-document-frequency for the OOV token. - """ - if self._compute_idf: - vocab_counts, document_counts, num_documents = accumulator - else: - vocab_counts, _, _ = accumulator - - sorted_counts = sorted( - vocab_counts.items(), key=operator.itemgetter(1, 0), reverse=True) - vocab_data = ( - sorted_counts[:self._vocab_size] if self._vocab_size else sorted_counts) - vocab = [data[0] for data in vocab_data] - - if self._compute_idf: - doc_counts = [document_counts[token]["count"] for token in vocab] - idf = self._inverse_document_frequency(doc_counts, num_documents[0]) - oov_idf = np.array([np.log(1 + num_documents[0])]) - return {_VOCAB_NAME: vocab, _IDF_NAME: idf, _OOV_IDF_NAME: oov_idf} - else: - return {_VOCAB_NAME: vocab} - - def restore(self, output): - """Create an accumulator based on 'output'.""" - raise NotImplementedError( - "TextVectorization does not restore or support streaming updates.") - - def serialize(self, accumulator): - """Serialize an accumulator for a remote call.""" - output_dict = {} - output_dict["metadata"] = accumulator.metadata - output_dict["vocab"] = list(accumulator.count_dict.keys()) - output_dict["vocab_counts"] = list(accumulator.count_dict.values()) - if self._compute_idf: - output_dict["idf_vocab"] = list(accumulator.per_doc_count_dict.keys()) - output_dict["idf_counts"] = [ - counter["count"] - for counter in accumulator.per_doc_count_dict.values() - ] - return compat.as_bytes(json.dumps(output_dict)) - - def deserialize(self, encoded_accumulator): - """Deserialize an accumulator received from 'serialize()'.""" - accumulator_dict = json.loads(compat.as_text(encoded_accumulator)) - - accumulator = self._create_accumulator() - accumulator.metadata[0] = accumulator_dict["metadata"][0] - - count_dict = dict( - zip(accumulator_dict["vocab"], accumulator_dict["vocab_counts"])) - accumulator.count_dict.update(count_dict) - - if self._compute_idf: - create_dict = lambda x: {"count": x, "last_doc_id": -1} - idf_count_dicts = [ - create_dict(count) for count in accumulator_dict["idf_counts"] - ] - idf_dict = dict(zip(accumulator_dict["idf_vocab"], idf_count_dicts)) - accumulator.per_doc_count_dict.update(idf_dict) - - return accumulator - - def _create_accumulator(self): - """Accumulate a sorted array of vocab tokens and corresponding counts.""" - - count_dict = collections.defaultdict(int) - if self._compute_idf: - create_default_dict = lambda: {"count": 0, "last_doc_id": -1} - per_doc_count_dict = collections.defaultdict(create_default_dict) - else: - per_doc_count_dict = None - metadata = [0] - return _TextVectorizationAccumulator(count_dict, per_doc_count_dict, - metadata) diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py index f8a1f5b9434..508f222eac7 100644 --- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py +++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py @@ -29,6 +29,7 @@ from tensorflow.python import tf2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import one_device_strategy from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.keras import backend from tensorflow.python.keras import keras_parameterized @@ -36,12 +37,10 @@ from tensorflow.python.keras import testing_utils from tensorflow.python.keras.layers import convolutional from tensorflow.python.keras.layers import core from tensorflow.python.keras.layers import embeddings +from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils from tensorflow.python.keras.layers.preprocessing import text_vectorization from tensorflow.python.keras.layers.preprocessing import text_vectorization_v1 -from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils -from tensorflow.python.keras.saving import saved_model_experimental as saving from tensorflow.python.keras.utils import generic_utils -from tensorflow.python.keras.utils.generic_utils import CustomObjectScope from tensorflow.python.ops import gen_string_ops from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_string_ops @@ -61,7 +60,7 @@ def _get_end_to_end_test_cases(): "testcase_name": "test_simple_tokens_int_mode", # Create an array where 'earth' is the most frequent term, followed by - # 'wind', then 'and', then 'fire'. This ensures that the vocab accumulator + # 'wind', then 'and', then 'fire'. This ensures that the vocab # is sorting by frequency. "vocab_data": np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"], @@ -77,6 +76,26 @@ def _get_end_to_end_test_cases(): }, "expected_output": [[2], [3], [4], [5], [5], [4], [2], [1]], }, + { + "testcase_name": + "test_simple_tokens_int_mode_hard_cap", + # Create an array where 'earth' is the most frequent term, followed by + # 'wind', then 'and', then 'fire'. This ensures that the vocab + # is sorting by frequency. + "vocab_data": + np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"], + ["wind"], ["wind"], ["wind"], ["and"], ["and"]]), + "input_data": + np.array([["earth"], ["wind"], ["and"], ["fire"], ["fire"], + ["and"], ["earth"], ["michigan"]]), + "kwargs": { + "max_tokens": 6, + "standardize": None, + "split": None, + "output_mode": text_vectorization.INT + }, + "expected_output": [[2], [3], [4], [5], [5], [4], [2], [1]], + }, { "testcase_name": "test_documents_int_mode", @@ -274,18 +293,121 @@ class TextVectorizationLayerTest(keras_parameterized.TestCase, vocab_data = dataset_ops.Dataset.from_tensor_slices(vocab_data).batch( input_shape[0]) - with CustomObjectScope({"TextVectorization": cls}): - output_data = testing_utils.layer_test( - cls, - kwargs=kwargs, - input_shape=input_shape, - input_data=input_data, - input_dtype=dtypes.string, - expected_output_dtype=expected_output_dtype, - validate_training=False, - adapt_data=vocab_data) + output_data = testing_utils.layer_test( + cls, + kwargs=kwargs, + input_shape=input_shape, + input_data=input_data, + input_dtype=dtypes.string, + expected_output_dtype=expected_output_dtype, + validate_training=False, + adapt_data=vocab_data) self.assertAllClose(expected_output, output_data) + def test_list_inputs_1d(self): + vocab_data = ["two two two", "two three three", "three four four five"] + input_data = ["two three", "four five"] + layer = get_layer_class()() + layer.adapt(vocab_data) + out = layer(input_data) + if context.executing_eagerly(): + self.assertAllClose(out.numpy(), [[2, 3], [4, 5]]) + layer.set_vocabulary(["two", "three", "four", "five"]) + out = layer(input_data) + if context.executing_eagerly(): + self.assertAllClose(out.numpy(), [[2, 3], [4, 5]]) + + def test_tensor_inputs(self): + vocab_data = constant_op.constant( + ["two two two", "two three three", "three four four five"]) + input_data = constant_op.constant(["two three", "four five"]) + layer = get_layer_class()() + layer.adapt(vocab_data) + out = layer(input_data) + if context.executing_eagerly(): + self.assertAllClose(out.numpy(), [[2, 3], [4, 5]]) + layer.set_vocabulary(["two", "three", "four", "five"]) + out = layer(input_data) + if context.executing_eagerly(): + self.assertAllClose(out.numpy(), [[2, 3], [4, 5]]) + + def test_list_inputs_2d(self): + vocab_data = [ + ["two two two"], ["two three three"], ["three four four five"]] + input_data = [["two three"], ["four five"]] + layer = get_layer_class()() + layer.adapt(vocab_data) + out = layer(input_data) + if context.executing_eagerly(): + self.assertAllClose(out.numpy(), [[2, 3], [4, 5]]) + layer.set_vocabulary(["two", "three", "four", "five"]) + out = layer(input_data) + if context.executing_eagerly(): + self.assertAllClose(out.numpy(), [[2, 3], [4, 5]]) + + def test_dataset_of_single_strings(self): + vocab_data = ["two two two", "two three three", "three four four five"] + input_data = ["two three", "four five"] + vocab_ds = dataset_ops.Dataset.from_tensor_slices(vocab_data) # unbatched + layer = get_layer_class()() + layer.adapt(vocab_ds) + out = layer(input_data) + if context.executing_eagerly(): + self.assertAllClose(out.numpy(), [[2, 3], [4, 5]]) + + @parameterized.named_parameters( + { + "testcase_name": "1d", + "data": ["0", "a", "b", "c", "d", "e", "a", "b", "c", "d", "f"], + "expected": [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1] + }, + { + "testcase_name": "2d", + "data": [["0", "a", "b", "c", "d"], ["e", "a", "b", "c", "d"], ["f"]], + "expected": [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 0, 0, 0, 0]] + }, + { + "testcase_name": + "3d", + "data": [[["0", "a", "b"], ["c", "d"]], [["e", "a"], ["b", "c", "d"]], + [["f"]]], + "expected": [[[1, 2, 3], [4, 5, 0]], [[1, 2, 0], [3, 4, 5]], + [[1, 0, 0], [0, 0, 0]]] + }, + ) + def test_layer_dimensionality_handling(self, data, expected): + vocab = ["a", "b", "c", "d"] + vectorization = get_layer_class()( + max_tokens=None, standardize=None, split=None, pad_to_max_tokens=False) + vectorization.set_vocabulary(vocab) + output = vectorization(ragged_factory_ops.constant(data)) + self.assertAllEqual(expected, output) + + @parameterized.named_parameters( + { + "testcase_name": "1d", + "data": ["0 a b c d e a b c d f"], + "expected": [[1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1]] + }, + { + "testcase_name": + "3d", + "data": [[["0 a b"], ["c d"]], [["e a"], ["b c d"]], [["f"]]], + "expected": [[[1, 2, 3], [4, 5, 0]], [[1, 2, 0], [3, 4, 5]], + [[1, 0, 0], [0, 0, 0]]] + }, + ) + def test_layer_dimensionality_handling_with_split(self, data, expected): + vocab = ["a", "b", "c", "d"] + vectorization = get_layer_class()( + max_tokens=None, + standardize=None, + split=text_vectorization.SPLIT_ON_WHITESPACE, + pad_to_max_tokens=False) + vectorization.set_vocabulary(vocab) + output = vectorization(ragged_factory_ops.constant(data, inner_shape=(1,))) + self.assertAllEqual(expected, output) + @keras_parameterized.run_all_keras_modes class TextVectorizationPreprocessingTest( @@ -511,7 +633,7 @@ class TextVectorizationPreprocessingTest( split=text_vectorization.SPLIT_ON_WHITESPACE, output_mode=None) with self.assertRaisesRegex(RuntimeError, - ".*tokenize strings, the first dimension.*"): + ".*tokenize strings, the innermost dime.*"): _ = layer(input_data) def test_string_splitting_with_non_1d_raggedarray_fails(self): @@ -522,7 +644,7 @@ class TextVectorizationPreprocessingTest( split=text_vectorization.SPLIT_ON_WHITESPACE, output_mode=None) with self.assertRaisesRegex(RuntimeError, - ".*tokenize strings, the first dimension.*"): + ".*tokenize strings, the innermost dime.*"): _ = layer(input_data) def test_standardization_with_invalid_standardize_arg(self): @@ -933,7 +1055,7 @@ class TextVectorizationOutputTest( output_mode=text_vectorization.BINARY, pad_to_max_tokens=False) _ = layer(input_data) - with self.assertRaisesRegex(RuntimeError, "vocabulary cannot be changed"): + with self.assertRaisesRegex(RuntimeError, "can't be adapted after being"): layer.adapt(vocab_data) def test_bag_output_soft_maximum_set_state_variables_after_call_fails(self): @@ -1295,6 +1417,7 @@ class TextVectorizationErrorTest(keras_parameterized.TestCase, ".*`output_sequence_length` must not be set.*"): _ = get_layer_class()(output_mode="count", output_sequence_length=2) + # Custom functions for the custom callable serialization test. Declared here # to avoid multiple registrations from run_all_keras_modes(). @generic_utils.register_keras_serializable(package="Test") @@ -1340,8 +1463,7 @@ class TextVectorizationSavingTest( if tf2.enabled(): keras.backend.clear_session() - loaded_model = keras.models.load_model( - output_path, custom_objects={"TextVectorization": get_layer_class()}) + loaded_model = keras.models.load_model(output_path) self.assertAllEqual(loaded_model.predict(input_array), expected_output) def test_saving_when_nested(self): @@ -1375,67 +1497,10 @@ class TextVectorizationSavingTest( if tf2.enabled(): keras.backend.clear_session() - loaded_model = keras.models.load_model( - output_path, custom_objects={"TextVectorization": get_layer_class()}) + loaded_model = keras.models.load_model(output_path) self.assertAllEqual(loaded_model.predict(input_array), expected_output) - def test_serialization_with_custom_callables(self): - input_array = np.array([["earth>wind>and Fire"], - ["\tfire>And\nearth>michigan"]]) - expected_output = [[b"earth", b"wind", b"and fire"], - [b"\tfire", b"and\nearth", b"michigan"]] - - input_data = keras.Input(shape=(1,), dtype=dtypes.string) - layer = get_layer_class()( - max_tokens=None, - standardize=custom_standardize_fn, - split=custom_split_fn, - ngrams=None, - output_mode=None) - int_data = layer(input_data) - model = keras.Model(inputs=input_data, outputs=int_data) - output_dataset = model.predict(input_array) - self.assertAllEqual(expected_output, output_dataset) - - serialized_model_data = model.get_config() - with CustomObjectScope({"TextVectorization": get_layer_class()}): - new_model = keras.Model.from_config(serialized_model_data) - new_output_dataset = new_model.predict(input_array) - self.assertAllEqual(expected_output, new_output_dataset) - - def DISABLED_test_vocabulary_persistence_across_saving(self): - vocab_data = ["earth", "wind", "and", "fire"] - input_array = np.array([["earth", "wind", "and", "fire"], - ["fire", "and", "earth", "michigan"]]) - expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]] - - # Build and validate a golden model. - input_data = keras.Input(shape=(None,), dtype=dtypes.string) - layer = get_layer_class()( - max_tokens=None, - standardize=None, - split=None, - output_mode=text_vectorization.INT) - layer.set_vocabulary(vocab_data) - int_data = layer(input_data) - model = keras.Model(inputs=input_data, outputs=int_data) - output_dataset = model.predict(input_array) - self.assertAllEqual(output_dataset, expected_output) - - # Save the model to disk. - output_path = os.path.join(self.get_temp_dir(), "tf_keras_saved_model") - model.save(output_path, save_format="tf") - loaded_model = saving.load_from_saved_model( - output_path, custom_objects={"TextVectorization": get_layer_class()}) - - # Ensure that the loaded model is unique (so that the save/load is real) - self.assertIsNot(model, loaded_model) - - # Validate correctness of the new model. - new_output_dataset = loaded_model.predict(input_array) - self.assertAllEqual(new_output_dataset, expected_output) - - def DISABLED_test_vocabulary_persistence_across_saving_with_tfidf(self): + def test_saving_with_tfidf(self): vocab_data = ["earth", "wind", "and", "fire"] tfidf_data = [.5, .25, .2, .125] input_array = np.array([["earth", "wind", "and", "earth"], @@ -1465,8 +1530,7 @@ class TextVectorizationSavingTest( # Save the model to disk. output_path = os.path.join(self.get_temp_dir(), "tf_keras_saved_model") model.save(output_path, save_format="tf") - loaded_model = saving.load_from_saved_model( - output_path, custom_objects={"TextVectorization": get_layer_class()}) + loaded_model = keras.models.load_model(output_path) # Ensure that the loaded model is unique (so that the save/load is real) self.assertIsNot(model, loaded_model) @@ -1475,208 +1539,62 @@ class TextVectorizationSavingTest( new_output_dataset = loaded_model.predict(input_array) self.assertAllClose(new_output_dataset, expected_output) + def test_serialization_with_custom_callables(self): + input_array = np.array([["earth>wind>and Fire"], + ["\tfire>And\nearth>michigan"]]) + expected_output = [[b"earth", b"wind", b"and fire"], + [b"\tfire", b"and\nearth", b"michigan"]] + + input_data = keras.Input(shape=(1,), dtype=dtypes.string) + layer = get_layer_class()( + max_tokens=None, + standardize=custom_standardize_fn, + split=custom_split_fn, + ngrams=None, + output_mode=None) + int_data = layer(input_data) + model = keras.Model(inputs=input_data, outputs=int_data) + output_dataset = model.predict(input_array) + self.assertAllEqual(expected_output, output_dataset) + + serialized_model_data = model.get_config() + new_model = keras.Model.from_config(serialized_model_data) + new_output_dataset = new_model.predict(input_array) + self.assertAllEqual(expected_output, new_output_dataset) + @keras_parameterized.run_all_keras_modes -class TextVectorizationCombinerTest( - keras_parameterized.TestCase, - preprocessing_test_utils.PreprocessingLayerTest): +class TextVectorizationE2ETest(keras_parameterized.TestCase, + preprocessing_test_utils.PreprocessingLayerTest): - def compare_text_accumulators(self, a, b, msg=None): - if a is None or b is None: - self.assertAllEqual(a, b, msg=msg) + def test_keras_vocab_trimming_example(self): + vocab_data = np.array([ + "earth", "earth", "earth", "earth", "wind", "wind", "wind", "and", + "and", "fire" + ]) + input_array = np.array([["earth", "wind", "and", "earth"], + ["ohio", "and", "earth", "michigan"]]) - self.assertAllEqual(a.count_dict, b.count_dict, msg=msg) - self.assertAllEqual(a.metadata, b.metadata, msg=msg) - - if a.per_doc_count_dict is not None: - - def per_doc_counts(accumulator): - count_values = [ - count_dict["count"] - for count_dict in accumulator.per_doc_count_dict.values() - ] - return dict(zip(accumulator.per_doc_count_dict.keys(), count_values)) - - self.assertAllEqual(per_doc_counts(a), per_doc_counts(b), msg=msg) - - compare_accumulators = compare_text_accumulators - - def update_accumulator(self, accumulator, data): - accumulator.count_dict.update(dict(zip(data["vocab"], data["counts"]))) - accumulator.metadata[0] = data["num_documents"] - - if "document_counts" in data: - create_dict = lambda x: {"count": x, "last_doc_id": -1} - idf_count_dicts = [ - create_dict(count) for count in data["document_counts"] - ] - idf_dict = dict(zip(data["vocab"], idf_count_dicts)) - - accumulator.per_doc_count_dict.update(idf_dict) - - return accumulator - - def test_combiner_api_compatibility_int_mode(self): - data = np.array([["earth", "wind", "and", "fire"], - ["earth", "wind", "and", "michigan"]]) - combiner = text_vectorization._TextVectorizationCombiner(compute_idf=False) - expected_accumulator_output = { - "vocab": np.array(["and", "earth", "wind", "fire", "michigan"]), - "counts": np.array([2, 2, 2, 1, 1]), - "num_documents": np.array(2), - } - expected_extract_output = { - "vocab": np.array(["wind", "earth", "and", "michigan", "fire"]), - } - expected_accumulator = combiner._create_accumulator() - expected_accumulator = self.update_accumulator(expected_accumulator, - expected_accumulator_output) - self.validate_accumulator_serialize_and_deserialize(combiner, data, - expected_accumulator) - self.validate_accumulator_uniqueness(combiner, data) - self.validate_accumulator_extract(combiner, data, expected_extract_output) - - def test_combiner_api_compatibility_tfidf_mode(self): - data = np.array([["earth", "wind", "and", "fire"], - ["earth", "wind", "and", "michigan"]]) - combiner = text_vectorization._TextVectorizationCombiner(compute_idf=True) - expected_extract_output = { - "vocab": np.array(["wind", "earth", "and", "michigan", "fire"]), - "idf": np.array([0.510826, 0.510826, 0.510826, 0.693147, 0.693147]), - "oov_idf": np.array([1.098612]) - } - expected_accumulator_output = { - "vocab": np.array(["wind", "earth", "and", "michigan", "fire"]), - "counts": np.array([2, 2, 2, 1, 1]), - "document_counts": np.array([2, 2, 2, 1, 1]), - "num_documents": np.array(2), - } - - expected_accumulator = combiner._create_accumulator() - expected_accumulator = self.update_accumulator(expected_accumulator, - expected_accumulator_output) - self.validate_accumulator_serialize_and_deserialize(combiner, data, - expected_accumulator) - self.validate_accumulator_uniqueness(combiner, data) - self.validate_accumulator_extract(combiner, data, expected_extract_output) - - # TODO(askerryryan): Add tests confirming equivalence to behavior of - # existing tf.keras.preprocessing.text.Tokenizer. - @parameterized.named_parameters( - { - "testcase_name": - "top_k_smaller_than_full_vocab", - "data": - np.array([["earth", "wind"], ["fire", "wind"], ["and"], - ["fire", "wind"]]), - "vocab_size": - 3, - "expected_accumulator_output": { - "vocab": np.array(["wind", "fire", "earth", "and"]), - "counts": np.array([3, 2, 1, 1]), - "document_counts": np.array([3, 2, 1, 1]), - "num_documents": np.array(4), - }, - "expected_extract_output": { - "vocab": np.array(["wind", "fire", "earth"]), - "idf": np.array([0.693147, 0.847298, 1.098612]), - "oov_idf": np.array([1.609438]), - }, - }, - { - "testcase_name": - "top_k_larger_than_full_vocab", - "data": - np.array([["earth", "wind"], ["fire", "wind"], ["and"], - ["fire", "wind"]]), - "vocab_size": - 10, - "expected_accumulator_output": { - "vocab": np.array(["wind", "fire", "earth", "and"]), - "counts": np.array([3, 2, 1, 1]), - "document_counts": np.array([3, 2, 1, 1]), - "num_documents": np.array(4), - }, - "expected_extract_output": { - "vocab": np.array(["wind", "fire", "earth", "and"]), - "idf": np.array([0.693147, 0.847298, 1.098612, 1.098612]), - "oov_idf": np.array([1.609438]), - }, - }, - { - "testcase_name": - "no_top_k", - "data": - np.array([["earth", "wind"], ["fire", "wind"], ["and"], - ["fire", "wind"]]), - "vocab_size": - None, - "expected_accumulator_output": { - "vocab": np.array(["wind", "fire", "earth", "and"]), - "counts": np.array([3, 2, 1, 1]), - "document_counts": np.array([3, 2, 1, 1]), - "num_documents": np.array(4), - }, - "expected_extract_output": { - "vocab": np.array(["wind", "fire", "earth", "and"]), - "idf": np.array([0.693147, 0.847298, 1.098612, 1.098612]), - "oov_idf": np.array([1.609438]), - }, - }, - { - "testcase_name": "single_element_per_row", - "data": np.array([["earth"], ["wind"], ["fire"], ["wind"], ["and"]]), - "vocab_size": 3, - "expected_accumulator_output": { - "vocab": np.array(["wind", "and", "earth", "fire"]), - "counts": np.array([2, 1, 1, 1]), - "document_counts": np.array([2, 1, 1, 1]), - "num_documents": np.array(5), - }, - "expected_extract_output": { - "vocab": np.array(["wind", "fire", "earth"]), - "idf": np.array([0.980829, 1.252763, 1.252763]), - "oov_idf": np.array([1.791759]), - }, - }, - # Which tokens are retained are based on global frequency, and thus are - # sensitive to frequency within a document. In contrast, because idf only - # considers the presence of a token in a document, it is insensitive - # to the frequency of the token within the document. - { - "testcase_name": - "retained_tokens_sensitive_to_within_document_frequency", - "data": - np.array([["earth", "earth"], ["wind", "wind"], ["fire", "fire"], - ["wind", "wind"], ["and", "michigan"]]), - "vocab_size": - 3, - "expected_accumulator_output": { - "vocab": np.array(["wind", "earth", "fire", "and", "michigan"]), - "counts": np.array([4, 2, 2, 1, 1]), - "document_counts": np.array([2, 1, 1, 1, 1]), - "num_documents": np.array(5), - }, - "expected_extract_output": { - "vocab": np.array(["wind", "fire", "earth"]), - "idf": np.array([0.980829, 1.252763, 1.252763]), - "oov_idf": np.array([1.791759]), - }, - }) - def test_combiner_computation(self, - data, - vocab_size, - expected_accumulator_output, - expected_extract_output, - compute_idf=True): - combiner = text_vectorization._TextVectorizationCombiner( - vocab_size=vocab_size, compute_idf=compute_idf) - expected_accumulator = combiner._create_accumulator() - expected_accumulator = self.update_accumulator(expected_accumulator, - expected_accumulator_output) - self.validate_accumulator_computation(combiner, data, expected_accumulator) - self.validate_accumulator_extract(combiner, data, expected_extract_output) + # pyformat: disable + expected_output = [[1, 2, 1], + [3, 1, 0]] + # pyformat: enable + max_tokens = 3 + expected_output_shape = [None, max_tokens] + input_data = keras.Input(shape=(None,), dtype=dtypes.string) + layer = get_layer_class()( + max_tokens=max_tokens, + standardize=None, + split=None, + output_mode=text_vectorization.COUNT, + pad_to_max_tokens=True) + int_data = layer(input_data) + layer.adapt(vocab_data) + self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) + model = keras.Model(input_data, int_data) + output = model.predict(input_array) + self.assertAllEqual(expected_output, output) if __name__ == "__main__": diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py index 59cf2c61288..505cdc39547 100644 --- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py +++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py @@ -18,14 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from tensorflow.python.keras import backend as K from tensorflow.python.keras.engine import base_preprocessing_layer_v1 -from tensorflow.python.keras.layers.preprocessing import categorical_encoding_v1 +from tensorflow.python.keras.layers.preprocessing import category_encoding_v1 from tensorflow.python.keras.layers.preprocessing import string_lookup_v1 from tensorflow.python.keras.layers.preprocessing import text_vectorization -from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.util.tf_export import keras_export @@ -81,17 +77,7 @@ class TextVectorization(text_vectorization.TextVectorization, """ def _get_vectorization_class(self): - return categorical_encoding_v1.CategoricalEncoding + return category_encoding_v1.CategoryEncoding def _get_index_lookup_class(self): return string_lookup_v1.StringLookup - - def _to_numpy(self, data): - """Converts preprocessed inputs into numpy arrays.""" - if isinstance(data, np.ndarray): - return data - session = K.get_session() - data = session.run(data) - if isinstance(data, ragged_tensor_value.RaggedTensorValue): - data = np.array(data.to_list()) - return data diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py index 0a90441d8a0..992ff562755 100644 --- a/tensorflow/python/keras/layers/serialization.py +++ b/tensorflow/python/keras/layers/serialization.py @@ -45,9 +45,15 @@ from tensorflow.python.keras.layers import recurrent from tensorflow.python.keras.layers import recurrent_v2 from tensorflow.python.keras.layers import rnn_cell_wrapper_v2 from tensorflow.python.keras.layers import wrappers +from tensorflow.python.keras.layers.preprocessing import category_crossing +from tensorflow.python.keras.layers.preprocessing import category_encoding +from tensorflow.python.keras.layers.preprocessing import category_encoding_v1 +from tensorflow.python.keras.layers.preprocessing import hashing from tensorflow.python.keras.layers.preprocessing import image_preprocessing from tensorflow.python.keras.layers.preprocessing import normalization as preprocessing_normalization from tensorflow.python.keras.layers.preprocessing import normalization_v1 as preprocessing_normalization_v1 +from tensorflow.python.keras.layers.preprocessing import text_vectorization as preprocessing_text_vectorization +from tensorflow.python.keras.layers.preprocessing import text_vectorization_v1 as preprocessing_text_vectorization_v1 from tensorflow.python.keras.utils import generic_utils from tensorflow.python.util import tf_inspect as inspect from tensorflow.python.util.tf_export import keras_export @@ -57,30 +63,16 @@ ALL_MODULES = (base_layer, input_layer, advanced_activations, convolutional, convolutional_recurrent, core, cudnn_recurrent, dense_attention, embeddings, einsum_dense, local, merge, noise, normalization, pooling, image_preprocessing, preprocessing_normalization_v1, - recurrent, wrappers) -ALL_V2_MODULES = ( - rnn_cell_wrapper_v2, - normalization_v2, - recurrent_v2, - preprocessing_normalization -) -FEATURE_COLUMN_V1_OBJECTS = {} -FEATURE_COLUMN_V2_OBJECTS = {} + preprocessing_text_vectorization_v1, recurrent, wrappers, + hashing, category_crossing, category_encoding_v1) +ALL_V2_MODULES = (rnn_cell_wrapper_v2, normalization_v2, recurrent_v2, + preprocessing_normalization, preprocessing_text_vectorization, + category_encoding) # ALL_OBJECTS is meant to be a global mutable. Hence we need to make it # thread-local to avoid concurrent mutations. LOCAL = threading.local() -def inject_feature_column_v1_objects(name, cls): - global FEATURE_COLUMN_V1_OBJECTS - FEATURE_COLUMN_V1_OBJECTS[name] = cls - - -def inject_feature_column_v2_objects(name, cls): - global FEATURE_COLUMN_V2_OBJECTS - FEATURE_COLUMN_V2_OBJECTS[name] = cls - - def populate_deserializable_objects(): """Populates dict ALL_OBJECTS with every built-in layer. """ @@ -134,9 +126,11 @@ def populate_deserializable_objects(): LOCAL.ALL_OBJECTS['WideDeepModel'] = WideDeepModel if tf2.enabled(): - LOCAL.ALL_OBJECTS.update(FEATURE_COLUMN_V2_OBJECTS) + from tensorflow.python.keras.feature_column.dense_features_v2 import DenseFeatures # pylint: disable=g-import-not-at-top + LOCAL.ALL_OBJECTS['DenseFeatures'] = DenseFeatures else: - LOCAL.ALL_OBJECTS.update(FEATURE_COLUMN_V1_OBJECTS) + from tensorflow.python.keras.feature_column.dense_features import DenseFeatures # pylint: disable=g-import-not-at-top + LOCAL.ALL_OBJECTS['DenseFeatures'] = DenseFeatures # Merge layers, function versions. LOCAL.ALL_OBJECTS['add'] = merge.add diff --git a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py index 73e395f5715..1a328995a80 100644 --- a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py +++ b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py @@ -288,9 +288,10 @@ class AutoLambdaTest(keras_parameterized.TestCase): constant_op.constant(40.0, shape=(1, 1))) def test_no_tracking(self): - x = keras.backend.placeholder((10, 10)) - keras.layers.Dense(1)(x) - self.assertTrue(x._keras_history_checked) + if not context.executing_eagerly(): + x = constant_op.constant(1.0, shape=(10, 10)) + keras.layers.Dense(1)(x) + self.assertTrue(x._keras_history_checked) def test_timing_scales_linearly(self): diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index bb22db25591..a73177fff12 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -33,6 +33,7 @@ from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import base_layer_utils +from tensorflow.python.keras.layers import core from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper from tensorflow.python.keras.utils import generic_utils from tensorflow.python.ops import array_ops @@ -1213,9 +1214,14 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase): f_merged = keras.backend.function([inputs], layer(inputs)) f_forward = keras.backend.function([inputs], layer.forward_layer(inputs)) + + # TODO(kaftan): after KerasTensor refactor TF op layers should work + # with many composite tensors, and this shouldn't need to be a lambda + # layer. + reverse_layer = core.Lambda(array_ops.reverse, arguments=dict(axis=[1])) f_backward = keras.backend.function( [inputs], - array_ops.reverse(layer.backward_layer(inputs), axis=[1])) + reverse_layer(layer.backward_layer(inputs))) y_merged = f_merged(x) y_expected = merge_func( diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py index 99fb015288b..2bb53dcfaa5 100644 --- a/tensorflow/python/keras/losses.py +++ b/tensorflow/python/keras/losses.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops.losses import losses_impl from tensorflow.python.ops.losses import util as tf_losses_util +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls @@ -1164,6 +1165,7 @@ class Huber(LossFunctionWrapper): 'keras.losses.mean_squared_error', 'keras.losses.mse', 'keras.losses.MSE') +@dispatch.add_dispatch_support def mean_squared_error(y_true, y_pred): """Computes the mean squared error between labels and predictions. @@ -1199,6 +1201,7 @@ def mean_squared_error(y_true, y_pred): 'keras.losses.mean_absolute_error', 'keras.losses.mae', 'keras.losses.MAE') +@dispatch.add_dispatch_support def mean_absolute_error(y_true, y_pred): """Computes the mean absolute error between labels and predictions. @@ -1231,6 +1234,7 @@ def mean_absolute_error(y_true, y_pred): 'keras.losses.mean_absolute_percentage_error', 'keras.losses.mape', 'keras.losses.MAPE') +@dispatch.add_dispatch_support def mean_absolute_percentage_error(y_true, y_pred): """Computes the mean absolute percentage error between `y_true` and `y_pred`. @@ -1267,6 +1271,7 @@ def mean_absolute_percentage_error(y_true, y_pred): 'keras.losses.mean_squared_logarithmic_error', 'keras.losses.msle', 'keras.losses.MSLE') +@dispatch.add_dispatch_support def mean_squared_logarithmic_error(y_true, y_pred): """Computes the mean squared logarithmic error between `y_true` and `y_pred`. @@ -1315,6 +1320,7 @@ def _maybe_convert_labels(y_true): @keras_export('keras.metrics.squared_hinge', 'keras.losses.squared_hinge') +@dispatch.add_dispatch_support def squared_hinge(y_true, y_pred): """Computes the squared hinge loss between `y_true` and `y_pred`. @@ -1347,6 +1353,7 @@ def squared_hinge(y_true, y_pred): @keras_export('keras.metrics.hinge', 'keras.losses.hinge') +@dispatch.add_dispatch_support def hinge(y_true, y_pred): """Computes the hinge loss between `y_true` and `y_pred`. @@ -1378,6 +1385,7 @@ def hinge(y_true, y_pred): @keras_export('keras.losses.categorical_hinge') +@dispatch.add_dispatch_support def categorical_hinge(y_true, y_pred): """Computes the categorical hinge loss between `y_true` and `y_pred`. @@ -1410,6 +1418,7 @@ def categorical_hinge(y_true, y_pred): @keras_export('keras.losses.huber', v1=[]) +@dispatch.add_dispatch_support def huber(y_true, y_pred, delta=1.0): """Computes Huber loss value. @@ -1447,6 +1456,7 @@ def huber(y_true, y_pred, delta=1.0): @keras_export('keras.losses.log_cosh', 'keras.losses.logcosh') +@dispatch.add_dispatch_support def log_cosh(y_true, y_pred): """Logarithm of the hyperbolic cosine of the prediction error. @@ -1485,6 +1495,7 @@ def log_cosh(y_true, y_pred): @keras_export('keras.metrics.categorical_crossentropy', 'keras.losses.categorical_crossentropy') +@dispatch.add_dispatch_support def categorical_crossentropy(y_true, y_pred, from_logits=False, @@ -1525,6 +1536,7 @@ def categorical_crossentropy(y_true, @keras_export('keras.metrics.sparse_categorical_crossentropy', 'keras.losses.sparse_categorical_crossentropy') +@dispatch.add_dispatch_support def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1): """Computes the sparse categorical crossentropy loss. @@ -1556,6 +1568,7 @@ def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1): @keras_export('keras.metrics.binary_crossentropy', 'keras.losses.binary_crossentropy') +@dispatch.add_dispatch_support def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0): """Computes the binary crossentropy loss. @@ -1599,6 +1612,7 @@ def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0): 'keras.losses.kullback_leibler_divergence', 'keras.losses.kld', 'keras.losses.KLD') +@dispatch.add_dispatch_support def kl_divergence(y_true, y_pred): """Computes Kullback-Leibler divergence loss between `y_true` and `y_pred`. @@ -1635,6 +1649,7 @@ def kl_divergence(y_true, y_pred): @keras_export('keras.metrics.poisson', 'keras.losses.poisson') +@dispatch.add_dispatch_support def poisson(y_true, y_pred): """Computes the Poisson loss between y_true and y_pred. @@ -1676,6 +1691,7 @@ def poisson(y_true, y_pred): 'keras.losses.cosine', 'keras.losses.cosine_similarity', ]) +@dispatch.add_dispatch_support def cosine_similarity(y_true, y_pred, axis=-1): """Computes the cosine similarity between labels and predictions. diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py index 574d3d3f756..26a586b872b 100644 --- a/tensorflow/python/keras/losses_test.py +++ b/tensorflow/python/keras/losses_test.py @@ -125,8 +125,10 @@ class KerasLossesTest(test.TestCase, parameterized.TestCase): backend.eval(output_from_softmax), atol=1e-5) - @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + @combinations.generate(combinations.combine(mode=['graph'])) def test_sparse_categorical_crossentropy_loss_with_unknown_rank_tensor(self): + # This test only runs in graph because the TF op layer is not supported yet + # for sparse ops. t = backend.placeholder() p = backend.placeholder() o = losses.sparse_categorical_crossentropy(t, p) diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index 63cf7c578bc..a67755b9333 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -69,6 +69,7 @@ from tensorflow.python.ops import variables as tf_variables from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.ops.losses import util as tf_losses_utils from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.util import dispatch from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import keras_export @@ -3212,6 +3213,7 @@ def accuracy(y_true, y_pred): @keras_export('keras.metrics.binary_accuracy') +@dispatch.add_dispatch_support def binary_accuracy(y_true, y_pred, threshold=0.5): """Calculates how often predictions matches binary labels. @@ -3239,6 +3241,7 @@ def binary_accuracy(y_true, y_pred, threshold=0.5): @keras_export('keras.metrics.categorical_accuracy') +@dispatch.add_dispatch_support def categorical_accuracy(y_true, y_pred): """Calculates how often predictions matches one-hot labels. @@ -3267,6 +3270,7 @@ def categorical_accuracy(y_true, y_pred): @keras_export('keras.metrics.sparse_categorical_accuracy') +@dispatch.add_dispatch_support def sparse_categorical_accuracy(y_true, y_pred): """Calculates how often predictions matches integer labels. @@ -3307,6 +3311,7 @@ def sparse_categorical_accuracy(y_true, y_pred): @keras_export('keras.metrics.top_k_categorical_accuracy') +@dispatch.add_dispatch_support def top_k_categorical_accuracy(y_true, y_pred, k=5): """Computes how often targets are in the top `K` predictions. @@ -3332,6 +3337,7 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5): @keras_export('keras.metrics.sparse_top_k_categorical_accuracy') +@dispatch.add_dispatch_support def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5): """Computes how often integer targets are in the top `K` predictions. diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py index 30a93e2bba3..4ada84191dc 100644 --- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py @@ -39,7 +39,6 @@ from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.feature_column import feature_column_v2 as fc -from tensorflow.python.feature_column.dense_features import DenseFeatures from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -48,6 +47,7 @@ from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import regularizers from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.feature_column.dense_features import DenseFeatures from tensorflow.python.keras.saving.saved_model import load as keras_load from tensorflow.python.keras.saving.saved_model import save_impl as keras_save from tensorflow.python.keras.utils import generic_utils diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py index 5da6aeef391..b41abbdf1f5 100644 --- a/tensorflow/python/keras/testing_utils.py +++ b/tensorflow/python/keras/testing_utils.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.python import tf2 from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util @@ -44,6 +45,14 @@ from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect +def string_test(actual, expected): + np.testing.assert_array_equal(actual, expected) + + +def numeric_test(actual, expected): + np.testing.assert_allclose(actual, expected, rtol=1e-3, atol=1e-6) + + def get_test_data(train_samples, test_samples, input_shape, @@ -132,6 +141,11 @@ def layer_test(layer_cls, if expected_output_dtype is None: expected_output_dtype = input_dtype + if dtypes.as_dtype(expected_output_dtype) == dtypes.string: + assert_equal = string_test + else: + assert_equal = numeric_test + # instantiation kwargs = kwargs or {} layer = layer_cls(**kwargs) @@ -199,8 +213,7 @@ def layer_test(layer_cls, (layer_cls.__name__, x, actual_output.dtype, computed_output_signature.dtype, kwargs)) if expected_output is not None: - np.testing.assert_allclose(actual_output, expected_output, - rtol=1e-3, atol=1e-6) + assert_equal(actual_output, expected_output) # test serialization, weight setting at model level model_config = model.get_config() @@ -209,7 +222,7 @@ def layer_test(layer_cls, weights = model.get_weights() recovered_model.set_weights(weights) output = recovered_model.predict(input_data) - np.testing.assert_allclose(output, actual_output, rtol=1e-3, atol=1e-6) + assert_equal(output, actual_output) # test training mode (e.g. useful for dropout tests) # Rebuild the model to avoid the graph being reused between predict() and @@ -254,8 +267,7 @@ def layer_test(layer_cls, computed_output_shape, kwargs)) if expected_output is not None: - np.testing.assert_allclose(actual_output, expected_output, - rtol=1e-3, atol=1e-6) + assert_equal(actual_output, expected_output) # test serialization, weight setting at model level model_config = model.get_config() @@ -264,7 +276,7 @@ def layer_test(layer_cls, weights = model.get_weights() recovered_model.set_weights(weights) output = recovered_model.predict(input_data) - np.testing.assert_allclose(output, actual_output, rtol=1e-3, atol=1e-6) + assert_equal(output, actual_output) # for further checks in the caller function return actual_output diff --git a/tensorflow/python/keras/utils/vis_utils.py b/tensorflow/python/keras/utils/vis_utils.py index 158f6c83748..e56f07e4bb7 100644 --- a/tensorflow/python/keras/utils/vis_utils.py +++ b/tensorflow/python/keras/utils/vis_utils.py @@ -129,6 +129,7 @@ def model_to_dot(model, sub_w_first_node = {} sub_w_last_node = {} + layers = model.layers if not model._is_graph_network: node = pydot.Node(str(id(model)), label=model.name) dot.add_node(node) @@ -136,7 +137,7 @@ def model_to_dot(model, elif isinstance(model, sequential.Sequential): if not model.built: model.build() - layers = model._layers + layers = super(sequential.Sequential, model).layers # Create graph nodes. for i, layer in enumerate(layers): diff --git a/tensorflow/python/keras/utils/vis_utils_test.py b/tensorflow/python/keras/utils/vis_utils_test.py index 34bc835da32..984014216be 100644 --- a/tensorflow/python/keras/utils/vis_utils_test.py +++ b/tensorflow/python/keras/utils/vis_utils_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python import keras from tensorflow.python.keras.utils import vis_utils from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -67,6 +68,32 @@ class ModelToDotFormatTest(test.TestCase): except ImportError: pass + def test_plot_model_with_add_loss(self): + inputs = keras.Input(shape=(None, 3)) + outputs = keras.layers.Dense(1)(inputs) + model = keras.Model(inputs, outputs) + model.add_loss(math_ops.reduce_mean(outputs)) + dot_img_file = 'model_3.png' + try: + vis_utils.plot_model( + model, to_file=dot_img_file, show_shapes=True, expand_nested=True) + self.assertTrue(file_io.file_exists(dot_img_file)) + file_io.delete_file(dot_img_file) + except ImportError: + pass + + model = keras.Sequential([ + keras.Input(shape=(None, 3)), keras.layers.Dense(1)]) + model.add_loss(math_ops.reduce_mean(model.output)) + dot_img_file = 'model_4.png' + try: + vis_utils.plot_model( + model, to_file=dot_img_file, show_shapes=True, expand_nested=True) + self.assertTrue(file_io.file_exists(dot_img_file)) + file_io.delete_file(dot_img_file) + except ImportError: + pass + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 13f59b74baf..a04c874c9d6 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1,8 +1,11 @@ # Tests of TensorFlow kernels written using the Python API. -load("//tensorflow:tensorflow.bzl", "sycl_py_test", "tf_custom_op_library", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "sycl_py_test", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "cuda_py_test") +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "tf_py_test") + package( default_visibility = ["//tensorflow:internal"], licenses = ["notice"], # Apache 2.0 @@ -175,9 +178,9 @@ cuda_py_test( srcs = ["bincount_op_test.py"], tags = ["no_windows_gpu"], deps = [ + "//tensorflow/python:bincount_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", ], ) @@ -861,6 +864,7 @@ cuda_py_test( srcs = ["resource_variable_ops_test.py"], # TODO(b/128347673): Re-enable. tags = ["no_windows"], + tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index bea08ac70bf..9eb8bfcef41 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -42,6 +42,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -1994,5 +1995,32 @@ class RepeatTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertAllEqual(v_tf_fn, v_np) +@test_util.run_all_in_graph_and_eager_modes +class TileVariantTest(test_util.TensorFlowTestCase): + + def test_tile_tensor_list(self): + t = constant_op.constant(np.random.uniform(size=[2, 3, 4])) + handle = list_ops.tensor_list_from_tensor(t, element_shape=None) + with ops.device("CPU:0"): + tiled_handles = array_ops.tile(array_ops.reshape(handle, [1]), [2]) + tiled_tensor_0 = list_ops.tensor_list_stack(tiled_handles[0], t.dtype, 2, + [3, 4]) + tiled_tensor_1 = list_ops.tensor_list_stack(tiled_handles[1], t.dtype, 2, + [3, 4]) + self.assertAllEqual(t, tiled_tensor_0) + self.assertAllEqual(t, tiled_tensor_1) + # Now mutate some of the lists and make sure the changes are not reflected + # in the tiled handles. + with ops.control_dependencies([ + list_ops.tensor_list_scatter([t[0] + 1], [0], input_handle=handle), + list_ops.tensor_list_set_item(tiled_handles[0], 0, t[0] + 2)]): + tiled_tensor_0 = list_ops.tensor_list_stack(tiled_handles[0], t.dtype, 2, + [3, 4]) + tiled_tensor_1 = list_ops.tensor_list_stack(tiled_handles[1], t.dtype, 2, + [3, 4]) + self.assertAllEqual(t, tiled_tensor_0) + self.assertAllEqual(t, tiled_tensor_1) + + if __name__ == "__main__": test_lib.main() diff --git a/tensorflow/python/kernel_tests/attention_ops_test.py b/tensorflow/python/kernel_tests/attention_ops_test.py index 87e709fc69e..804a0b20cc9 100644 --- a/tensorflow/python/kernel_tests/attention_ops_test.py +++ b/tensorflow/python/kernel_tests/attention_ops_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import image_ops from tensorflow.python.platform import test @@ -196,6 +197,55 @@ class ExtractGlimpseTest(test.TestCase): expected_rows=[None, None, None, 1, 2, 3, 4], expected_cols=[56, 57, 58, 59, 60]) + def testGlimpseNoiseZeroV1Compatible(self): + # Note: The old versions of extract_glimpse was incorrect in implementation. + # This test is for compatibility so that graph save in old versions behave + # the same. Notice the API uses gen_image_ops.extract_glimpse() on purpose. + # + # Image: + # [ 0. 1. 2. 3. 4.] + # [ 5. 6. 7. 8. 9.] + # [ 10. 11. 12. 13. 14.] + # [ 15. 16. 17. 18. 19.] + # [ 20. 21. 22. 23. 24.] + img = constant_op.constant( + np.arange(25).reshape((1, 5, 5, 1)), dtype=dtypes.float32) + with self.test_session(): + # Result 1: + # [ 0. 0. 0.] + # [ 0. 0. 0.] + # [ 0. 0. 0.] + result1 = gen_image_ops.extract_glimpse( + img, [3, 3], [[-2, 2]], + centered=False, + normalized=False, + noise='zero', + uniform_noise=False) + self.assertAllEqual( + np.asarray([[0, 0, 0], [0, 0, 0], [0, 0, 0]]), + self.evaluate(result1)[0, :, :, 0]) + + # Result 2: + # [ 0. 0. 0. 0. 0. 0. 0.] + # [ 0. 0. 1. 2. 3. 4. 0.] + # [ 0. 5. 6. 7. 8. 9. 0.] + # [ 0. 10. 11. 12. 13. 14. 0.] + # [ 0. 15. 16. 17. 18. 19. 0.] + # [ 0. 20. 21. 22. 23. 24. 0.] + # [ 0. 0. 0. 0. 0. 0. 0.] + result2 = gen_image_ops.extract_glimpse( + img, [7, 7], [[0, 0]], + normalized=False, + noise='zero', + uniform_noise=False) + self.assertAllEqual( + np.asarray([[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 3, 4, 0], + [0, 5, 6, 7, 8, 9, 0], [0, 10, 11, 12, 13, 14, 0], + [0, 15, 16, 17, 18, 19, 0], [0, 20, 21, 22, 23, 24, 0], + [0, 0, 0, 0, 0, 0, 0]]), + self.evaluate(result2)[0, :, :, 0]) + + def testGlimpseNoiseZero(self): # Image: # [ 0. 1. 2. 3. 4.] @@ -211,7 +261,7 @@ class ExtractGlimpseTest(test.TestCase): # [ 0. 0. 0.] # [ 0. 0. 0.] result1 = image_ops.extract_glimpse_v2( - img, [3, 3], [[-2, 2]], + img, [3, 3], [[-2, -2]], centered=False, normalized=False, noise='zero') @@ -220,22 +270,37 @@ class ExtractGlimpseTest(test.TestCase): self.evaluate(result1)[0, :, :, 0]) # Result 2: + # [ 12. 13. 14. 0. 0. 0. 0.] + # [ 17. 18. 19. 0. 0. 0. 0.] + # [ 22. 23. 24. 0. 0. 0. 0.] + # [ 0. 0. 0. 0. 0. 0. 0.] + # [ 0. 0. 0. 0. 0. 0. 0.] # [ 0. 0. 0. 0. 0. 0. 0.] - # [ 0. 0. 1. 2. 3. 4. 0.] - # [ 0. 5. 6. 7. 8. 9. 0.] - # [ 0. 10. 11. 12. 13. 14. 0.] - # [ 0. 15. 16. 17. 18. 19. 0.] - # [ 0. 20. 21. 22. 23. 24. 0.] # [ 0. 0. 0. 0. 0. 0. 0.] result2 = image_ops.extract_glimpse_v2( img, [7, 7], [[0, 0]], normalized=False, noise='zero') self.assertAllEqual( - np.asarray([[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 3, 4, 0], - [0, 5, 6, 7, 8, 9, 0], [0, 10, 11, 12, 13, 14, 0], - [0, 15, 16, 17, 18, 19, 0], [0, 20, 21, 22, 23, 24, 0], + np.asarray([[12, 13, 14, 0, 0, 0, 0], [17, 18, 19, 0, 0, 0, 0], + [22, 23, 24, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]]), self.evaluate(result2)[0, :, :, 0]) + def testGlimpseNonNormalizedNonCentered(self): + img = constant_op.constant( + np.arange(25).reshape((1, 5, 5, 1)), dtype=dtypes.float32) + with self.test_session(): + result1 = image_ops.extract_glimpse_v2( + img, [3, 3], [[0, 0]], centered=False, normalized=False) + result2 = image_ops.extract_glimpse_v2( + img, [3, 3], [[1, 0]], centered=False, normalized=False) + self.assertAllEqual( + np.asarray([[0, 1, 2], [5, 6, 7], [10, 11, 12]]), + self.evaluate(result1)[0, :, :, 0]) + self.assertAllEqual( + np.asarray([[5, 6, 7], [10, 11, 12], [15, 16, 17]]), + self.evaluate(result2)[0, :, :, 0]) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/betainc_op_test.py b/tensorflow/python/kernel_tests/betainc_op_test.py index c4f70b5bc29..c564c822918 100644 --- a/tensorflow/python/kernel_tests/betainc_op_test.py +++ b/tensorflow/python/kernel_tests/betainc_op_test.py @@ -55,8 +55,8 @@ class BetaincTest(test.TestCase): # the scipy version of betainc uses a double-only implementation. # TODO(ebrevdo): identify reasons for (sometime) precision loss # with doubles - rtol = 1e-4 if dtype == dtypes.float32 else 5e-5 - atol = 9e-6 if dtype == dtypes.float32 else 3e-6 + rtol = 1e-4 + atol = 1e-5 self.assertAllCloseAccordingToType( scipy_out, tf_out, rtol=rtol, atol=atol) @@ -66,7 +66,8 @@ class BetaincTest(test.TestCase): with self.cached_session(): tf_comb = math_ops.betainc(a_comb, b_comb, x_comb).eval() scipy_comb = special.betainc(a_comb, b_comb, x_comb, dtype=np_dt) - self.assertAllCloseAccordingToType(scipy_comb, tf_comb) + self.assertAllCloseAccordingToType( + scipy_comb, tf_comb, rtol=rtol, atol=atol) # Test broadcasting between scalars and other shapes with self.cached_session(): diff --git a/tensorflow/python/kernel_tests/bincount_op_test.py b/tensorflow/python/kernel_tests/bincount_op_test.py index 222716dfdfa..22ac9f8e99d 100644 --- a/tensorflow/python/kernel_tests/bincount_op_test.py +++ b/tensorflow/python/kernel_tests/bincount_op_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for math_ops.bincount.""" +"""Tests for bincount_ops.bincount.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -25,8 +25,8 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bincount_ops from tensorflow.python.ops import gen_math_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor @@ -37,45 +37,50 @@ class BincountTest(test_util.TensorFlowTestCase): def test_empty(self): with self.session(use_gpu=True): - self.assertAllEqual(self.evaluate(math_ops.bincount([], minlength=5)), - [0, 0, 0, 0, 0]) - self.assertAllEqual(self.evaluate(math_ops.bincount([], minlength=1)), - [0]) - self.assertAllEqual(self.evaluate(math_ops.bincount([], minlength=0)), - []) - self.assertEqual(self.evaluate(math_ops.bincount([], minlength=0, - dtype=np.float32)).dtype, - np.float32) - self.assertEqual(self.evaluate(math_ops.bincount([], minlength=3, - dtype=np.float64)).dtype, - np.float64) + self.assertAllEqual( + self.evaluate(bincount_ops.bincount([], minlength=5)), + [0, 0, 0, 0, 0]) + self.assertAllEqual( + self.evaluate(bincount_ops.bincount([], minlength=1)), [0]) + self.assertAllEqual( + self.evaluate(bincount_ops.bincount([], minlength=0)), []) + self.assertEqual( + self.evaluate( + bincount_ops.bincount([], minlength=0, dtype=np.float32)).dtype, + np.float32) + self.assertEqual( + self.evaluate( + bincount_ops.bincount([], minlength=3, dtype=np.float64)).dtype, + np.float64) def test_values(self): with self.session(use_gpu=True): - self.assertAllEqual(self.evaluate(math_ops.bincount([1, 1, 1, 2, 2, 3])), - [0, 3, 2, 1]) + self.assertAllEqual( + self.evaluate(bincount_ops.bincount([1, 1, 1, 2, 2, 3])), + [0, 3, 2, 1]) arr = [1, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5] - self.assertAllEqual(self.evaluate(math_ops.bincount(arr)), - [0, 5, 4, 3, 2, 1]) + self.assertAllEqual( + self.evaluate(bincount_ops.bincount(arr)), [0, 5, 4, 3, 2, 1]) arr += [0, 0, 0, 0, 0, 0] - self.assertAllEqual(self.evaluate(math_ops.bincount(arr)), - [6, 5, 4, 3, 2, 1]) + self.assertAllEqual( + self.evaluate(bincount_ops.bincount(arr)), [6, 5, 4, 3, 2, 1]) - self.assertAllEqual(self.evaluate(math_ops.bincount([])), []) - self.assertAllEqual(self.evaluate(math_ops.bincount([0, 0, 0])), [3]) - self.assertAllEqual(self.evaluate(math_ops.bincount([5])), - [0, 0, 0, 0, 0, 1]) - self.assertAllEqual(self.evaluate(math_ops.bincount(np.arange(10000))), - np.ones(10000)) + self.assertAllEqual(self.evaluate(bincount_ops.bincount([])), []) + self.assertAllEqual(self.evaluate(bincount_ops.bincount([0, 0, 0])), [3]) + self.assertAllEqual( + self.evaluate(bincount_ops.bincount([5])), [0, 0, 0, 0, 0, 1]) + self.assertAllEqual( + self.evaluate(bincount_ops.bincount(np.arange(10000))), + np.ones(10000)) def test_maxlength(self): with self.session(use_gpu=True): - self.assertAllEqual(self.evaluate(math_ops.bincount([5], maxlength=3)), - [0, 0, 0]) - self.assertAllEqual(self.evaluate(math_ops.bincount([1], maxlength=3)), - [0, 1]) - self.assertAllEqual(self.evaluate(math_ops.bincount([], maxlength=3)), - []) + self.assertAllEqual( + self.evaluate(bincount_ops.bincount([5], maxlength=3)), [0, 0, 0]) + self.assertAllEqual( + self.evaluate(bincount_ops.bincount([1], maxlength=3)), [0, 1]) + self.assertAllEqual( + self.evaluate(bincount_ops.bincount([], maxlength=3)), []) def test_random_with_weights(self): num_samples = 10000 @@ -88,7 +93,7 @@ class BincountTest(test_util.TensorFlowTestCase): else: weights = np.random.random(num_samples) self.assertAllClose( - self.evaluate(math_ops.bincount(arr, weights)), + self.evaluate(bincount_ops.bincount(arr, weights)), np.bincount(arr, weights)) def test_random_without_weights(self): @@ -99,20 +104,20 @@ class BincountTest(test_util.TensorFlowTestCase): arr = np.random.randint(0, 1000, num_samples) weights = np.ones(num_samples).astype(dtype) self.assertAllClose( - self.evaluate(math_ops.bincount(arr, None)), + self.evaluate(bincount_ops.bincount(arr, None)), np.bincount(arr, weights)) def test_zero_weights(self): with self.session(use_gpu=True): self.assertAllEqual( - self.evaluate(math_ops.bincount(np.arange(1000), np.zeros(1000))), + self.evaluate(bincount_ops.bincount(np.arange(1000), np.zeros(1000))), np.zeros(1000)) def test_negative(self): # unsorted_segment_sum will only report InvalidArgumentError on CPU with self.cached_session(), ops.device("/CPU:0"): with self.assertRaises(errors.InvalidArgumentError): - self.evaluate(math_ops.bincount([1, 2, 3, -1, 6, 8])) + self.evaluate(bincount_ops.bincount([1, 2, 3, -1, 6, 8])) @test_util.run_deprecated_v1 def test_shape_function(self): diff --git a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py index fb44c33d602..7c3a382c955 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py @@ -82,7 +82,7 @@ class QuantileOpsTest(test_util.TensorFlowTestCase): self.eps = 0.01 self.max_elements = 1 << 16 - self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64) + self.num_quantiles = constant_op.constant(4, dtype=dtypes.int64) def testBasicQuantileBucketsSingleResource(self): with self.cached_session() as sess: @@ -183,7 +183,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase): with self.cached_session() as sess: accumulator = boosted_trees_ops.QuantileAccumulator( - num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0") + num_streams=2, + num_quantiles=self.num_quantiles, + epsilon=self.eps, + name="q0") save = saver.Saver() resources.initialize_resources(resources.shared_resources()).run() @@ -202,7 +205,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase): with self.session(graph=ops.Graph()) as sess: accumulator = boosted_trees_ops.QuantileAccumulator( - num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0") + num_streams=2, + num_quantiles=self.num_quantiles, + epsilon=self.eps, + name="q0") save = saver.Saver() save.restore(sess, save_path) buckets = accumulator.get_bucket_boundaries() @@ -215,7 +221,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase): with self.cached_session() as sess: accumulator = boosted_trees_ops.QuantileAccumulator( - num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0") + num_streams=2, + num_quantiles=self.num_quantiles, + epsilon=self.eps, + name="q0") save = saver.Saver() resources.initialize_resources(resources.shared_resources()).run() @@ -233,7 +242,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase): with self.session(graph=ops.Graph()) as sess: accumulator = boosted_trees_ops.QuantileAccumulator( - num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0") + num_streams=2, + num_quantiles=self.num_quantiles, + epsilon=self.eps, + name="q0") save = saver.Saver() save.restore(sess, save_path) buckets = accumulator.get_bucket_boundaries() diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py index 7d5f7715eb1..5dc334c897b 100644 --- a/tensorflow/python/kernel_tests/cholesky_op_test.py +++ b/tensorflow/python/kernel_tests/cholesky_op_test.py @@ -29,15 +29,14 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import gradient_checker_v2 from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops +from tensorflow.python.ops import stateless_random_ops from tensorflow.python.ops import variables from tensorflow.python.ops.linalg import linalg from tensorflow.python.platform import benchmark from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging # Different gradient implementations for benchmark purposes @@ -91,7 +90,7 @@ def TriAngInvCompositeGrad(l, grad): class CholeskyOpTest(test.TestCase): - def _verifyCholeskyBase(self, sess, x, chol, verification): + def _verifyCholeskyBase(self, x, chol, verification): chol_np, verification_np = self.evaluate([chol, verification]) self.assertAllClose(x, verification_np) self.assertShapeEqual(x, chol) @@ -106,11 +105,11 @@ class CholeskyOpTest(test.TestCase): def _verifyCholesky(self, x): # Verify that LL^T == x. - with self.cached_session(use_gpu=True) as sess: - chol = linalg_ops.cholesky(x) - verification = math_ops.matmul(chol, chol, adjoint_b=True) - self._verifyCholeskyBase(sess, x, chol, verification) + chol = linalg_ops.cholesky(x) + verification = math_ops.matmul(chol, chol, adjoint_b=True) + self._verifyCholeskyBase(x, chol, verification) + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testBasic(self): data = np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]) for dtype in (np.float32, np.float64): @@ -123,6 +122,7 @@ class CholeskyOpTest(test.TestCase): complex_data += data self._verifyCholesky(complex_data) + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testBatch(self): simple_array = np.array([[[1., 0.], [0., 5.]]]) # shape (1, 2, 2) self._verifyCholesky(simple_array) @@ -144,21 +144,21 @@ class CholeskyOpTest(test.TestCase): matrices[i] = np.dot(matrices[i].T.conj(), matrices[i]) self._verifyCholesky(matrices) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testNonSquareMatrix(self): - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): linalg_ops.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]])) - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): linalg_ops.cholesky( np.array([[[1., 2., 3.], [3., 4., 5.]], [[1., 2., 3.], [3., 4., 5.]] ])) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testWrongDimensions(self): tensor3 = constant_op.constant([1., 2.]) - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): linalg_ops.cholesky(tensor3) - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): linalg_ops.cholesky(tensor3) # The below invalid Cholesky call returns an error with TF Classic and just @@ -175,126 +175,126 @@ class CholeskyOpTest(test.TestCase): self._verifyCholesky( np.array([[1., -1., 0.], [-1., 1., -1.], [0., -1., 1.]])) + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testEmpty(self): self._verifyCholesky(np.empty([0, 2, 2])) self._verifyCholesky(np.empty([2, 0, 0])) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testConcurrentExecutesWithoutError(self): - with self.session(use_gpu=True) as sess: - matrix1 = random_ops.random_normal([5, 5], seed=42) - matrix2 = random_ops.random_normal([5, 5], seed=42) - matrix1 = math_ops.matmul(matrix1, matrix1, adjoint_a=True) - matrix2 = math_ops.matmul(matrix2, matrix2, adjoint_a=True) - c1 = linalg_ops.cholesky(matrix1) - c2 = linalg_ops.cholesky(matrix2) - c1_val, c2_val = self.evaluate([c1, c2]) - self.assertAllClose(c1_val, c2_val) + seed = [42, 24] + matrix_shape = [5, 5] + matrix1 = stateless_random_ops.stateless_random_normal(matrix_shape, seed) + matrix2 = stateless_random_ops.stateless_random_normal(matrix_shape, seed) + matrix1 = math_ops.matmul(matrix1, matrix1, adjoint_a=True) + matrix2 = math_ops.matmul(matrix2, matrix2, adjoint_a=True) + c1 = linalg_ops.cholesky(matrix1) + c2 = linalg_ops.cholesky(matrix2) + c1_val, c2_val = self.evaluate([c1, c2]) + self.assertAllClose(c1_val, c2_val) class CholeskyGradTest(test.TestCase): - _backprop_block_size = 32 + _backprop_block_size = 16 def getShapes(self, shapeList): return ((elem, int(np.floor(1.2 * elem))) for elem in shapeList) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testSmallMatrices(self): np.random.seed(0) shapes = self.getShapes([1, 2, 10]) self.runFiniteDifferences( shapes, dtypes=(dtypes_lib.float32, dtypes_lib.float64)) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testSmallMatricesComplex(self): np.random.seed(0) shapes = self.getShapes([1, 2, 10]) self.runFiniteDifferences( shapes, dtypes=(dtypes_lib.complex64, dtypes_lib.complex128)) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testOneBlockMatrices(self): np.random.seed(0) shapes = self.getShapes([self._backprop_block_size + 1]) self.runFiniteDifferences( shapes, dtypes=(dtypes_lib.float32, dtypes_lib.float64), - scalarTest=True) + scalar_test=True) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testTwoBlockMatrixFloat(self): np.random.seed(0) shapes = self.getShapes([2 * self._backprop_block_size + 1]) self.runFiniteDifferences( - shapes, dtypes=(dtypes_lib.float32,), scalarTest=True) + shapes, dtypes=(dtypes_lib.float32,), scalar_test=True) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testTwoBlockMatrixDouble(self): np.random.seed(0) shapes = self.getShapes([2 * self._backprop_block_size + 1]) self.runFiniteDifferences( - shapes, dtypes=(dtypes_lib.float64,), scalarTest=True) + shapes, dtypes=(dtypes_lib.float64,), scalar_test=True) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testTwoBlockMatrixComplexFloat(self): np.random.seed(0) shapes = self.getShapes([2 * self._backprop_block_size + 1]) self.runFiniteDifferences( - shapes, dtypes=(dtypes_lib.complex64,), scalarTest=True) + shapes, dtypes=(dtypes_lib.complex64,), scalar_test=True) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testTwoBlockMatrixComplexDouble(self): np.random.seed(0) shapes = self.getShapes([2 * self._backprop_block_size + 1]) self.runFiniteDifferences( - shapes, dtypes=(dtypes_lib.complex128,), scalarTest=True) + shapes, dtypes=(dtypes_lib.complex128,), scalar_test=True) + + def _runOneTest(self, shape, dtype, batch, scalar_test): + if dtype == dtypes_lib.float64: + tol = 1e-5 + elif dtype == dtypes_lib.complex128: + tol = 5e-5 + else: + tol = 5e-3 + epsilon = np.finfo(dtype.as_numpy_dtype).eps + delta = epsilon**(1.0 / 3.0) + + def RandomInput(): + a = np.random.randn(shape[0], shape[1]).astype(dtype.as_numpy_dtype) + if dtype.is_complex: + a += 1j * np.random.randn(shape[0], shape[1]).astype( + dtype.as_numpy_dtype) + return a + + def Compute(x): + # Turn the random matrix x into a Hermitian matrix by + # computing the quadratic form x * x^H. + a = math_ops.matmul(x, math_ops.conj( + array_ops.matrix_transpose(x))) / shape[0] + if batch: + a = array_ops.tile(array_ops.expand_dims(a, 0), [2, 1, 1]) + # Finally take the cholesky decomposition of the Hermitian matrix. + c = linalg_ops.cholesky(a) + if scalar_test: + # Reduce to a single scalar output to speed up test. + c = math_ops.reduce_mean(c) + return c + + theoretical, numerical = gradient_checker_v2.compute_gradient( + Compute, [RandomInput()], delta=delta) + self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol) def runFiniteDifferences(self, shapes, dtypes=(dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.complex64, dtypes_lib.complex128), - scalarTest=False): - with self.session(use_gpu=True): - for shape in shapes: - for batch in False, True: - for dtype in dtypes: - if not scalarTest: - data = np.random.randn(shape[0], shape[1]) - if dtype.is_complex: - data = data.astype(np.complex64) - data += 1j * np.random.randn(shape[0], shape[1]) - x = constant_op.constant(data, dtype) - tensor = math_ops.matmul( - x, math_ops.conj(array_ops.transpose(x))) / shape[0] - else: - # This is designed to be a faster test for larger matrices. - data = np.random.randn() - if dtype.is_complex: - data = np.complex64(data) - data += 1j * np.random.randn() - x = constant_op.constant(data, dtype) - R = constant_op.constant( - np.random.randn(shape[0], shape[1]), dtype) - e = math_ops.multiply(R, x) - tensor = math_ops.matmul( - e, math_ops.conj(array_ops.transpose(e))) / shape[0] - - # Inner-most matrices in tensor are positive definite. - if batch: - tensor = array_ops.tile( - array_ops.expand_dims(tensor, 0), [4, 1, 1]) - y = linalg_ops.cholesky(tensor) - if scalarTest: - y = math_ops.reduce_mean(y) - error = gradient_checker.compute_gradient_error( - x, x._shape_as_list(), y, y._shape_as_list()) - tf_logging.info("error = %f", error) - if dtype == dtypes_lib.float64: - self.assertLess(error, 1e-5) - elif dtype == dtypes_lib.complex128: - self.assertLess(error, 5e-5) - else: - self.assertLess(error, 5e-3) + scalar_test=False): + for shape_ in shapes: + for dtype_ in dtypes: + for batch_ in False, True: + self._runOneTest(shape_, dtype_, batch_, scalar_test) class CholeskyBenchmark(test.Benchmark): diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index 9192dc05ebc..e01abc8133d 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -431,6 +431,82 @@ class Conv2DTest(test.TestCase): padding="VALID", expected=expected_output) + @test_util.run_in_graph_and_eager_modes + def testConv2DExpandedBatch(self): + tensor_in_sizes_batch = [10, 2, 3, 3] + tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3] + filter_in_sizes = [1, 1, 3, 3] + filter_in = self._CreateNumpyTensor(filter_in_sizes) + x1 = self._CreateNumpyTensor(tensor_in_sizes_batch) + x2 = x1.reshape(tensor_in_sizes_expanded_batch) + conv1 = nn_ops.conv2d( + x1, + filter_in, + strides=[1, 1], + padding="VALID") + conv2 = nn_ops.conv2d( + x2, + filter_in, + strides=[1, 1], + padding="VALID") + self.assertEqual(conv1.shape, tensor_in_sizes_batch) + self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch) + self.assertAllEqual( + conv1, + self.evaluate(conv2).reshape(conv1.shape)) + + @test_util.run_in_graph_and_eager_modes + def testConvolutionClass2DExpandedBatch(self): + tensor_in_sizes_batch = [10, 2, 3, 3] + tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3] + filter_in_sizes = [1, 1, 3, 3] + filter_in = self._CreateNumpyTensor(filter_in_sizes) + x1 = self._CreateNumpyTensor(tensor_in_sizes_batch) + x2 = x1.reshape(tensor_in_sizes_expanded_batch) + convolver1 = nn_ops.Convolution( + input_shape=x1.shape, + filter_shape=filter_in.shape, + strides=[1, 1], + padding="VALID") + self.assertEqual(convolver1.num_batch_dims, 1) + convolver2 = nn_ops.Convolution( + input_shape=x2.shape, + filter_shape=filter_in.shape, + strides=[1, 1], + padding="VALID") + self.assertEqual(convolver2.num_batch_dims, 2) + conv1 = convolver1(x1, filter_in) + conv2 = convolver2(x2, filter_in) + self.assertEqual(conv1.shape, tensor_in_sizes_batch) + self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch) + self.assertAllEqual( + conv1, + self.evaluate(conv2).reshape(conv1.shape)) + + @test_util.run_in_graph_and_eager_modes + def testConvolutionWith2SpatialDimensionsAndExpandedBatch(self): + tensor_in_sizes_batch = [10, 2, 3, 3] + tensor_in_sizes_expanded_batch = [2, 5, 2, 3, 3] + filter_in_sizes = [1, 1, 3, 3] + filter_in = self._CreateNumpyTensor(filter_in_sizes) + x1 = self._CreateNumpyTensor(tensor_in_sizes_batch) + x2 = x1.reshape(tensor_in_sizes_expanded_batch) + conv1 = nn_ops.convolution( + x1, + filter_in, + strides=[1, 1], + padding="VALID") + conv2 = nn_ops.convolution( + x2, + filter_in, + strides=[1, 1], + padding="VALID") + self.assertEqual(conv1.shape, tensor_in_sizes_batch) + self.assertEqual(conv2.shape, tensor_in_sizes_expanded_batch) + self.assertAllEqual( + conv1, + self.evaluate(conv2).reshape(conv1.shape)) + @test_util.run_in_graph_and_eager_modes def testConv2D2x2Filter2x1Dilation(self): self._VerifyDilatedConvValues( diff --git a/tensorflow/python/kernel_tests/lu_op_test.py b/tensorflow/python/kernel_tests/lu_op_test.py index 7935b66f4af..de9d8c32cb5 100644 --- a/tensorflow/python/kernel_tests/lu_op_test.py +++ b/tensorflow/python/kernel_tests/lu_op_test.py @@ -30,7 +30,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops +from tensorflow.python.ops import stateless_random_ops from tensorflow.python.ops import variables from tensorflow.python.platform import benchmark from tensorflow.python.platform import test @@ -214,15 +214,20 @@ class LuOpTest(test.TestCase): data = np.random.rand(n, n) + 1j * np.random.rand(n, n) self._verifyLu(data) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testEmpty(self): self._verifyLu(np.empty([0, 2, 2])) self._verifyLu(np.empty([2, 0, 0])) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testConcurrentExecutesWithoutError(self): - matrix1 = random_ops.random_normal([5, 5], seed=42) - matrix2 = random_ops.random_normal([5, 5], seed=42) + matrix_shape = [5, 5] + seed = [42, 24] + matrix1 = stateless_random_ops.stateless_random_normal( + shape=matrix_shape, seed=seed) + matrix2 = stateless_random_ops.stateless_random_normal( + shape=matrix_shape, seed=seed) + self.assertAllEqual(matrix1, matrix2) lu1, p1 = linalg_ops.lu(matrix1) lu2, p2 = linalg_ops.lu(matrix2) lu1_val, p1_val, lu2_val, p2_val = self.evaluate([lu1, p1, lu2, p2]) diff --git a/tensorflow/python/kernel_tests/map_fn_test.py b/tensorflow/python/kernel_tests/map_fn_test.py index 1e10d689886..913c3a49cb0 100644 --- a/tensorflow/python/kernel_tests/map_fn_test.py +++ b/tensorflow/python/kernel_tests/map_fn_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import def_function from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -186,6 +187,25 @@ class MapFnTest(test.TestCase): self.assertAllEqual(-nums, received[1]) self.assertAllEqual(nums, received[2]) + @test_util.run_in_graph_and_eager_modes + def testMap_autograph_indirect(self): + + def test_function(x): + cond = constant_op.constant(-1) + if cond == 0: + result = x + else: + result = x + return result + + @def_function.function + def map_call(x): + return map_fn.map_fn(test_function, x) + + x = constant_op.constant([1]) + y = map_call(x) + self.assertAllEqual([1], self.evaluate(y)) + @test_util.run_in_graph_and_eager_modes def testMapShape(self): x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) diff --git a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py index fa466d975f8..8cc230d2806 100644 --- a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py @@ -23,12 +23,13 @@ import numpy as np from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops +from tensorflow.python.ops import stateless_random_ops from tensorflow.python.ops import variables from tensorflow.python.ops.linalg import linalg_impl from tensorflow.python.platform import benchmark @@ -57,7 +58,7 @@ class LogarithmOpTest(test.TestCase): matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1]) return matrix_batch - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testNonsymmetric(self): # 2x2 matrices matrix1 = np.array([[1., 2.], [3., 4.]]) @@ -71,7 +72,7 @@ class LogarithmOpTest(test.TestCase): # Complex batch self._verifyLogarithmComplex(self._makeBatch(matrix1, matrix2)) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testSymmetricPositiveDefinite(self): # 2x2 matrices matrix1 = np.array([[2., 1.], [1., 2.]]) @@ -85,27 +86,27 @@ class LogarithmOpTest(test.TestCase): # Complex batch self._verifyLogarithmComplex(self._makeBatch(matrix1, matrix2)) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testNonSquareMatrix(self): # When the logarithm of a non-square matrix is attempted we should return # an error - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): gen_linalg_ops.matrix_logarithm( np.array([[1., 2., 3.], [3., 4., 5.]], dtype=np.complex64)) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testWrongDimensions(self): # The input to the logarithm should be at least a 2-dimensional tensor. tensor3 = constant_op.constant([1., 2.], dtype=dtypes.complex64) - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): gen_linalg_ops.matrix_logarithm(tensor3) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testEmpty(self): self._verifyLogarithmComplex(np.empty([0, 2, 2], dtype=np.complex64)) self._verifyLogarithmComplex(np.empty([2, 0, 0], dtype=np.complex64)) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testRandomSmallAndLargeComplex64(self): np.random.seed(42) for batch_dims in [(), (1,), (3,), (2, 2)]: @@ -116,7 +117,7 @@ class LogarithmOpTest(test.TestCase): size=np.prod(shape)).reshape(shape).astype(np.complex64) self._verifyLogarithmComplex(matrix) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testRandomSmallAndLargeComplex128(self): np.random.seed(42) for batch_dims in [(), (1,), (3,), (2, 2)]: @@ -127,17 +128,21 @@ class LogarithmOpTest(test.TestCase): size=np.prod(shape)).reshape(shape).astype(np.complex128) self._verifyLogarithmComplex(matrix) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testConcurrentExecutesWithoutError(self): - with self.session(use_gpu=True) as sess: - matrix1 = math_ops.cast( - random_ops.random_normal([5, 5], seed=42), dtypes.complex64) - matrix2 = math_ops.cast( - random_ops.random_normal([5, 5], seed=42), dtypes.complex64) - logm1 = gen_linalg_ops.matrix_logarithm(matrix1) - logm2 = gen_linalg_ops.matrix_logarithm(matrix2) - logm = self.evaluate([logm1, logm2]) - self.assertAllEqual(logm[0], logm[1]) + matrix_shape = [5, 5] + seed = [42, 24] + matrix1 = math_ops.cast( + stateless_random_ops.stateless_random_normal(matrix_shape, seed=seed), + dtypes.complex64) + matrix2 = math_ops.cast( + stateless_random_ops.stateless_random_normal(matrix_shape, seed=seed), + dtypes.complex64) + self.assertAllEqual(matrix1, matrix2) + logm1 = gen_linalg_ops.matrix_logarithm(matrix1) + logm2 = gen_linalg_ops.matrix_logarithm(matrix2) + logm = self.evaluate([logm1, logm2]) + self.assertAllEqual(logm[0], logm[1]) class MatrixLogarithmBenchmark(test.Benchmark): diff --git a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py index b99c8f6d256..b7a159e2eff 100644 --- a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py @@ -20,10 +20,11 @@ from __future__ import print_function import numpy as np -from tensorflow.python import tf2 from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -89,6 +90,8 @@ class MatrixSolveLsOpTest(test_lib.TestCase): if not fast and l2_regularizer != 0: # The slow path does not support regularization. return + if use_placeholder and context.executing_eagerly(): + return maxdim = np.max(x.shape) if dtype == np.float32 or dtype == np.complex64: tol = maxdim * 5e-4 @@ -109,64 +112,70 @@ class MatrixSolveLsOpTest(test_lib.TestCase): b = np.tile(b, batch_shape + (1, 1)) np_ans = np.tile(np_ans, batch_shape + (1, 1)) np_r_norm = np.tile(np_r_norm, batch_shape) - with self.cached_session(use_gpu=fast) as sess: - if use_placeholder: - a_ph = array_ops.placeholder(dtypes.as_dtype(dtype)) - b_ph = array_ops.placeholder(dtypes.as_dtype(dtype)) - feed_dict = {a_ph: a, b_ph: b} - tf_ans = linalg_ops.matrix_solve_ls( - a_ph, b_ph, fast=fast, l2_regularizer=l2_regularizer) - else: - tf_ans = linalg_ops.matrix_solve_ls( - a, b, fast=fast, l2_regularizer=l2_regularizer) - feed_dict = {} - self.assertEqual(np_ans.shape, tf_ans.get_shape()) - if l2_regularizer == 0: - # The least squares solution should satisfy A^H * (b - A*x) = 0. - tf_r = b - math_ops.matmul(a, tf_ans) - tf_r = math_ops.matmul(a, tf_r, adjoint_a=True) - tf_r_norm = linalg_ops.norm(tf_r, ord="fro", axis=[-2, -1]) - tf_ans_val, tf_r_norm_val = sess.run( - [tf_ans, tf_r_norm], feed_dict=feed_dict) - self.assertAllClose(np_r_norm, tf_r_norm_val, atol=tol, rtol=tol) - else: + if use_placeholder: + a_ph = array_ops.placeholder(dtypes.as_dtype(dtype)) + b_ph = array_ops.placeholder(dtypes.as_dtype(dtype)) + feed_dict = {a_ph: a, b_ph: b} + tf_ans = linalg_ops.matrix_solve_ls( + a_ph, b_ph, fast=fast, l2_regularizer=l2_regularizer) + else: + tf_ans = linalg_ops.matrix_solve_ls( + a, b, fast=fast, l2_regularizer=l2_regularizer) + feed_dict = None + self.assertEqual(np_ans.shape, tf_ans.get_shape()) + if feed_dict: + with self.session(use_gpu=True) as sess: tf_ans_val = sess.run(tf_ans, feed_dict=feed_dict) - + else: + tf_ans_val = self.evaluate(tf_ans) self.assertEqual(np_ans.shape, tf_ans_val.shape) self.assertAllClose(np_ans, tf_ans_val, atol=2 * tol, rtol=2 * tol) - @test_util.run_v1_only("b/120545219") + if l2_regularizer == 0: + # The least squares solution should satisfy A^H * (b - A*x) = 0. + tf_r = b - math_ops.matmul(a, tf_ans) + tf_r = math_ops.matmul(a, tf_r, adjoint_a=True) + tf_r_norm = linalg_ops.norm(tf_r, ord="fro", axis=[-2, -1]) + if feed_dict: + with self.session(use_gpu=True) as sess: + tf_ans_val, tf_r_norm_val = sess.run([tf_ans, tf_r_norm], + feed_dict=feed_dict) + else: + tf_ans_val, tf_r_norm_val = self.evaluate([tf_ans, tf_r_norm]) + self.assertAllClose(np_r_norm, tf_r_norm_val, atol=tol, rtol=tol) + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testWrongDimensions(self): # The matrix and right-hand sides should have the same number of rows. with self.session(use_gpu=True): matrix = constant_op.constant([[1., 0.], [0., 1.]]) rhs = constant_op.constant([[1., 0.]]) - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): linalg_ops.matrix_solve_ls(matrix, rhs) + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testEmpty(self): full = np.array([[1., 2.], [3., 4.], [5., 6.]]) empty0 = np.empty([3, 0]) empty1 = np.empty([0, 2]) for fast in [True, False]: - with self.cached_session(use_gpu=True): - tf_ans = self.evaluate( - linalg_ops.matrix_solve_ls(empty0, empty0, fast=fast)) - self.assertEqual(tf_ans.shape, (0, 0)) - tf_ans = self.evaluate( - linalg_ops.matrix_solve_ls(empty0, full, fast=fast)) - self.assertEqual(tf_ans.shape, (0, 2)) - tf_ans = self.evaluate( - linalg_ops.matrix_solve_ls(full, empty0, fast=fast)) - self.assertEqual(tf_ans.shape, (2, 0)) - tf_ans = self.evaluate( - linalg_ops.matrix_solve_ls(empty1, empty1, fast=fast)) - self.assertEqual(tf_ans.shape, (2, 2)) + tf_ans = self.evaluate( + linalg_ops.matrix_solve_ls(empty0, empty0, fast=fast)) + self.assertEqual(tf_ans.shape, (0, 0)) + tf_ans = self.evaluate( + linalg_ops.matrix_solve_ls(empty0, full, fast=fast)) + self.assertEqual(tf_ans.shape, (0, 2)) + tf_ans = self.evaluate( + linalg_ops.matrix_solve_ls(full, empty0, fast=fast)) + self.assertEqual(tf_ans.shape, (2, 0)) + tf_ans = self.evaluate( + linalg_ops.matrix_solve_ls(empty1, empty1, fast=fast)) + self.assertEqual(tf_ans.shape, (2, 2)) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testBatchResultSize(self): # 3x3x3 matrices, 3x3x1 right-hand sides. - matrix = np.array([1., 2., 3., 4., 5., 6., 7., 8., 9.] * 3).reshape(3, 3, 3) + matrix = np.array([1., 0., 0., 0., 1., 0., 0., 0., 1.] * 3).reshape(3, 3, 3) rhs = np.array([1., 2., 3.] * 3).reshape(3, 3, 1) answer = linalg_ops.matrix_solve(matrix, rhs) ls_answer = linalg_ops.matrix_solve_ls(matrix, rhs) @@ -358,8 +367,7 @@ if __name__ == "__main__": # ROCm does not support BLAS operations for complex types dtypes_to_test += [np.complex64, np.complex128] for dtype_ in dtypes_to_test: - # TF2 does not support placeholders under eager so we skip it - for use_placeholder_ in set([False, not tf2.enabled()]): + for use_placeholder_ in set([False, True]): for fast_ in [True, False]: l2_regularizers = [0] if dtype_ == np.complex128 else [0, 0.1] for l2_regularizer_ in l2_regularizers: diff --git a/tensorflow/python/kernel_tests/matrix_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_op_test.py index 0b6b403210c..bbd909c8e58 100644 --- a/tensorflow/python/kernel_tests/matrix_solve_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_solve_op_test.py @@ -21,14 +21,16 @@ from __future__ import print_function import numpy as np from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import random_ops +from tensorflow.python.ops import stateless_random_ops from tensorflow.python.ops import variables from tensorflow.python.platform import benchmark from tensorflow.python.platform import test @@ -56,19 +58,19 @@ class MatrixSolveOpTest(test.TestCase): a_np = np.tile(a_np, batch_dims + [1, 1]) b = np.tile(b, batch_dims + [1, 1]) np_ans = np.linalg.solve(a_np, b) - for use_placeholder in False, True: - with self.cached_session(use_gpu=True) as sess: - if use_placeholder: - a_ph = array_ops.placeholder(dtypes.as_dtype(np_type)) - b_ph = array_ops.placeholder(dtypes.as_dtype(np_type)) - tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint) + for use_placeholder in set((False, not context.executing_eagerly())): + if use_placeholder: + a_ph = array_ops.placeholder(dtypes.as_dtype(np_type)) + b_ph = array_ops.placeholder(dtypes.as_dtype(np_type)) + tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint) + with self.cached_session(use_gpu=True) as sess: out = sess.run(tf_ans, {a_ph: a, b_ph: b}) - else: - tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint) - out = self.evaluate(tf_ans) - self.assertEqual(tf_ans.get_shape(), out.shape) - self.assertEqual(np_ans.shape, out.shape) - self.assertAllClose(np_ans, out, atol=tol, rtol=tol) + else: + tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint) + out = self.evaluate(tf_ans) + self.assertEqual(tf_ans.get_shape(), out.shape) + self.assertEqual(np_ans.shape, out.shape) + self.assertAllClose(np_ans, out, atol=tol, rtol=tol) def _generateMatrix(self, m, n): matrix = (np.random.normal(-5, 5, @@ -77,7 +79,7 @@ class MatrixSolveOpTest(test.TestCase): [m, n])) return matrix - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testSolve(self): for n in 1, 2, 4, 9: matrix = self._generateMatrix(n, n) @@ -85,7 +87,7 @@ class MatrixSolveOpTest(test.TestCase): rhs = self._generateMatrix(n, nrhs) self._verifySolve(matrix, rhs) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testSolveBatch(self): for n in 2, 5: matrix = self._generateMatrix(n, n) @@ -94,48 +96,50 @@ class MatrixSolveOpTest(test.TestCase): for batch_dims in [[2], [2, 2], [7, 4]]: self._verifySolve(matrix, rhs, batch_dims=batch_dims) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testNonSquareMatrix(self): # When the solve of a non-square matrix is attempted we should return # an error - with self.session(use_gpu=True): - with self.assertRaises(ValueError): - matrix = constant_op.constant([[1., 2., 3.], [3., 4., 5.]]) - linalg_ops.matrix_solve(matrix, matrix) + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): + matrix = constant_op.constant([[1., 2., 3.], [3., 4., 5.]]) + self.evaluate(linalg_ops.matrix_solve(matrix, matrix)) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testWrongDimensions(self): # The matrix and right-hand sides should have the same number of rows. - with self.session(use_gpu=True): - matrix = constant_op.constant([[1., 0.], [0., 1.]]) - rhs = constant_op.constant([[1., 0.]]) - with self.assertRaises(ValueError): - linalg_ops.matrix_solve(matrix, rhs) + matrix = constant_op.constant([[1., 0.], [0., 1.]]) + rhs = constant_op.constant([[1., 0.]]) + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): + self.evaluate(linalg_ops.matrix_solve(matrix, rhs)) def testNotInvertible(self): # The input should be invertible. - with self.session(use_gpu=True): - with self.assertRaisesOpError("Input matrix is not invertible."): - # All rows of the matrix below add to zero - matrix = constant_op.constant([[1., 0., -1.], [-1., 1., 0.], - [0., -1., 1.]]) - linalg_ops.matrix_solve(matrix, matrix).eval() + with self.assertRaisesOpError("Input matrix is not invertible."): + # All rows of the matrix below add to zero + matrix = constant_op.constant([[1., 0., -1.], [-1., 1., 0.], + [0., -1., 1.]]) + self.evaluate(linalg_ops.matrix_solve(matrix, matrix)) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testConcurrent(self): - with self.session(use_gpu=True) as sess: - all_ops = [] - for adjoint_ in False, True: - lhs1 = random_ops.random_normal([3, 3], seed=42) - lhs2 = random_ops.random_normal([3, 3], seed=42) - rhs1 = random_ops.random_normal([3, 3], seed=42) - rhs2 = random_ops.random_normal([3, 3], seed=42) - s1 = linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint_) - s2 = linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint_) - all_ops += [s1, s2] - val = self.evaluate(all_ops) - self.assertAllEqual(val[0], val[1]) - self.assertAllEqual(val[2], val[3]) + seed = [42, 24] + matrix_shape = [3, 3] + all_ops = [] + for adjoint_ in False, True: + lhs1 = stateless_random_ops.stateless_random_normal( + matrix_shape, seed=seed) + lhs2 = stateless_random_ops.stateless_random_normal( + matrix_shape, seed=seed) + rhs1 = stateless_random_ops.stateless_random_normal( + matrix_shape, seed=seed) + rhs2 = stateless_random_ops.stateless_random_normal( + matrix_shape, seed=seed) + s1 = linalg_ops.matrix_solve(lhs1, rhs1, adjoint=adjoint_) + s2 = linalg_ops.matrix_solve(lhs2, rhs2, adjoint=adjoint_) + all_ops += [s1, s2] + val = self.evaluate(all_ops) + for i in range(0, len(all_ops), 2): + self.assertAllEqual(val[i], val[i + 1]) class MatrixSolveBenchmark(test.Benchmark): diff --git a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py index c36d83e2530..6cf330ed981 100644 --- a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py @@ -21,10 +21,11 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops +from tensorflow.python.ops import stateless_random_ops from tensorflow.python.platform import test @@ -89,31 +90,35 @@ class SquareRootOpTest(test.TestCase): self._verifySquareRootReal(np.empty([0, 2, 2])) self._verifySquareRootReal(np.empty([2, 0, 0])) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testWrongDimensions(self): # The input to the square root should be at least a 2-dimensional tensor. tensor = constant_op.constant([1., 2.]) - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): gen_linalg_ops.matrix_square_root(tensor) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testNotSquare(self): - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): tensor = constant_op.constant([[1., 0., -1.], [-1., 1., 0.]]) self.evaluate(gen_linalg_ops.matrix_square_root(tensor)) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testConcurrentExecutesWithoutError(self): - with test_util.use_gpu(): - matrix1 = random_ops.random_normal([5, 5], seed=42) - matrix2 = random_ops.random_normal([5, 5], seed=42) - square1 = math_ops.matmul(matrix1, matrix1) - square2 = math_ops.matmul(matrix2, matrix2) - sqrt1 = gen_linalg_ops.matrix_square_root(square1) - sqrt2 = gen_linalg_ops.matrix_square_root(square2) - all_ops = [sqrt1, sqrt2] - sqrt = self.evaluate(all_ops) - self.assertAllClose(sqrt[0], sqrt[1]) + matrix_shape = [5, 5] + seed = [42, 24] + matrix1 = stateless_random_ops.stateless_random_normal( + shape=matrix_shape, seed=seed) + matrix2 = stateless_random_ops.stateless_random_normal( + shape=matrix_shape, seed=seed) + self.assertAllEqual(matrix1, matrix2) + square1 = math_ops.matmul(matrix1, matrix1) + square2 = math_ops.matmul(matrix2, matrix2) + sqrt1 = gen_linalg_ops.matrix_square_root(square1) + sqrt2 = gen_linalg_ops.matrix_square_root(square2) + all_ops = [sqrt1, sqrt2] + sqrt = self.evaluate(all_ops) + self.assertAllClose(sqrt[0], sqrt[1]) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/numerics_test.py b/tensorflow/python/kernel_tests/numerics_test.py index 4d31cd45289..475badb6efe 100644 --- a/tensorflow/python/kernel_tests/numerics_test.py +++ b/tensorflow/python/kernel_tests/numerics_test.py @@ -24,7 +24,6 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -67,7 +66,7 @@ class VerifyTensorAllFiniteTest(test.TestCase): self.evaluate(t_verified) -@test_util.run_v1_only("b/120545219") +@test_util.run_v1_only("add_check_numerics_op() is meant to be a v1-only API") class NumericsTest(test.TestCase): def testInf(self): @@ -132,51 +131,6 @@ class NumericsTest(test.TestCase): r"or `tf.while_loop\(\)`\."): numerics.add_check_numerics_ops() - def testCheckNumericsV2OpNegativeAndPositiveInf(self): - """Test that CheckNumericsV2 op distinguishes negative and positive infs.""" - with self.session(graph=ops.Graph()): - t1 = constant_op.constant([-1.0, 1.0]) - t2 = constant_op.constant([0.0, 0.0]) - checked = array_ops.check_numerics_v2( - t1 / t2, message="pass through test") - caught = None - try: - self.evaluate(checked) - except errors.InvalidArgumentError as error: - caught = error - self.assertIn("had -Inf and +Inf values", caught.message) - self.assertIn("pass through test", caught.message) - - def testCheckNumericsV2OpNegativeAndPositiveInfAndNaN(self): - """CheckNumericsV2 op distinguishes - & + infs when nan is present.""" - with self.session(graph=ops.Graph()): - t1 = constant_op.constant([-1.0, 1.0, 0.0]) - t2 = constant_op.constant([0.0, 0.0, 0.0]) - checked = array_ops.check_numerics_v2( - t1 / t2, message="pass through test") - caught = None - try: - self.evaluate(checked) - except errors.InvalidArgumentError as error: - caught = error - self.assertIn("had -Inf, +Inf, and NaN values", caught.message) - self.assertIn("pass through test", caught.message) - - def testCheckNumericsV2PositiveInfAndNaN(self): - """Test that CheckNumericsV2 op shows sign of inf when nan is present.""" - with self.session(graph=ops.Graph()): - t1 = constant_op.constant([0.0, 1.0]) - t2 = constant_op.constant([0.0, 0.0]) - checked = array_ops.check_numerics_v2( - t1 / t2, message="pass through test") - caught = None - try: - self.evaluate(checked) - except errors.InvalidArgumentError as error: - caught = error - self.assertIn("had +Inf and NaN values", caught.message) - self.assertIn("pass through test", caught.message) - if __name__ == "__main__": # TODO(b/130689556): XLA CPU does not honor inf/nan which causes problems diff --git a/tensorflow/python/kernel_tests/pooling_ops_3d_test.py b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py index d5331dcb3e9..051f7e1168a 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_3d_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py @@ -205,14 +205,14 @@ class PoolingTest(test.TestCase): padding="VALID", expected=[29.5, 32.5, 50.5, 53.5, 176.5, 179.5, 197.5, 200.5]) - def _MaxPool3DEmptyTensorOutputShape(self): + def testMaxPool3DEmptyTensorOutputShape(self): """Verifies the output shape of the max pooling function when tensor is empty. Args: none """ input_sizes = [0, 112, 112, 112, 64] - input_data = 1 + input_data = 1. input_tensor = constant_op.constant( input_data, shape=input_sizes, name="input") max_pool_3d = nn_ops.max_pool3d( diff --git a/tensorflow/python/kernel_tests/proto/BUILD b/tensorflow/python/kernel_tests/proto/BUILD index d9643f3d125..0e935dfe8c4 100644 --- a/tensorflow/python/kernel_tests/proto/BUILD +++ b/tensorflow/python/kernel_tests/proto/BUILD @@ -1,7 +1,7 @@ # Tests of tf.io.*proto. -load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") +load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow/core/platform:build_config_root.bzl", "if_static") load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library") diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py index 4e0af934053..d5337c183a6 100644 --- a/tensorflow/python/kernel_tests/qr_op_test.py +++ b/tensorflow/python/kernel_tests/qr_op_test.py @@ -20,17 +20,18 @@ from __future__ import print_function import numpy as np -from tensorflow.python import tf2 from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import gradient_checker_v2 from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops +from tensorflow.python.ops import stateless_random_ops from tensorflow.python.ops import variables from tensorflow.python.platform import benchmark from tensorflow.python.platform import test @@ -45,35 +46,37 @@ def _AddTest(test_class, op_name, testcase_name, fn): class QrOpTest(test.TestCase): - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testWrongDimensions(self): - # The input to qr should be a tensor of at least rank 2. + # The input to svd should be a tensor of at least rank 2. scalar = constant_op.constant(1.) - with self.assertRaisesRegexp(ValueError, - "Shape must be at least rank 2 but is rank 0"): + with self.assertRaisesRegexp((ValueError, errors_impl.InvalidArgumentError), + "rank.* 2.*0"): linalg_ops.qr(scalar) vector = constant_op.constant([1., 2.]) - with self.assertRaisesRegexp(ValueError, - "Shape must be at least rank 2 but is rank 1"): + with self.assertRaisesRegexp((ValueError, errors_impl.InvalidArgumentError), + "rank.* 2.*1"): linalg_ops.qr(vector) - @test_util.run_deprecated_v1 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testConcurrentExecutesWithoutError(self): - with self.session(use_gpu=True) as sess: - all_ops = [] - for full_matrices_ in True, False: - for rows_ in 4, 5: - for cols_ in 4, 5: - matrix1 = random_ops.random_normal([rows_, cols_], seed=42) - matrix2 = random_ops.random_normal([rows_, cols_], seed=42) - q1, r1 = linalg_ops.qr(matrix1, full_matrices=full_matrices_) - q2, r2 = linalg_ops.qr(matrix2, full_matrices=full_matrices_) - all_ops += [q1, r1, q2, r2] - val = self.evaluate(all_ops) - for i in range(8): - q = 4 * i - self.assertAllClose(val[q], val[q + 2]) # q1 == q2 - self.assertAllClose(val[q + 1], val[q + 3]) # r1 == r2 + seed = [42, 24] + all_ops = [] + for full_matrices_ in True, False: + for rows_ in 4, 5: + for cols_ in 4, 5: + matrix_shape = [rows_, cols_] + matrix1 = stateless_random_ops.stateless_random_normal( + matrix_shape, seed) + matrix2 = stateless_random_ops.stateless_random_normal( + matrix_shape, seed) + self.assertAllEqual(matrix1, matrix2) + q1, r1 = linalg_ops.qr(matrix1, full_matrices=full_matrices_) + q2, r2 = linalg_ops.qr(matrix2, full_matrices=full_matrices_) + all_ops += [q1, q2, r1, r2] + val = self.evaluate(all_ops) + for i in range(0, len(val), 2): + self.assertAllClose(val[i], val[i + 1]) def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_): @@ -121,8 +124,10 @@ def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_): tol = 1e-14 self.assertAllClose(identity, xx, atol=tol) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def Test(self): + if not use_static_shape_ and context.executing_eagerly(): + return np.random.seed(1) x_np = np.random.uniform( low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_) @@ -131,7 +136,6 @@ def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_): low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_) - with self.session(use_gpu=True) as sess: if use_static_shape_: x_tf = constant_op.constant(x_np) else: @@ -141,7 +145,8 @@ def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_): if use_static_shape_: q_tf_val, r_tf_val = self.evaluate([q_tf, r_tf]) else: - q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np}) + with self.session(use_gpu=True) as sess: + q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np}) q_dims = q_tf_val.shape np_q = np.ndarray(q_dims, dtype_) @@ -170,13 +175,16 @@ class QrGradOpTest(test.TestCase): def _GetQrGradOpTest(dtype_, shape_, full_matrices_): - @test_util.run_v1_only("b/120545219") - def Test(self): - np.random.seed(42) + def RandomInput(): a = np.random.uniform(low=-1.0, high=1.0, size=shape_).astype(dtype_) if dtype_ in [np.complex64, np.complex128]: a += 1j * np.random.uniform( low=-1.0, high=1.0, size=shape_).astype(dtype_) + return a + + @test_util.run_in_graph_and_eager_modes(use_gpu=True) + def Test(self): + np.random.seed(42) # Optimal stepsize for central difference is O(epsilon^{1/3}). epsilon = np.finfo(dtype_).eps delta = 0.1 * epsilon**(1.0 / 3.0) @@ -184,23 +192,16 @@ def _GetQrGradOpTest(dtype_, shape_, full_matrices_): tol = 3e-2 else: tol = 1e-6 - with self.session(use_gpu=True): - tf_a = constant_op.constant(a) - tf_b = linalg_ops.qr(tf_a, full_matrices=full_matrices_) - for b in tf_b: - x_init = np.random.uniform( - low=-1.0, high=1.0, size=shape_).astype(dtype_) - if dtype_ in [np.complex64, np.complex128]: - x_init += 1j * np.random.uniform( - low=-1.0, high=1.0, size=shape_).astype(dtype_) - theoretical, numerical = gradient_checker.compute_gradient( - tf_a, - tf_a.get_shape().as_list(), - b, - b.get_shape().as_list(), - x_init_value=x_init, - delta=delta) - self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol) + # TODO(b/157171666): Sadly we have to double the computation because + # gradient_checker_v2.compute_gradient expects a list of functions. + funcs = [ + lambda a: linalg_ops.qr(a, full_matrices=full_matrices_)[0], + lambda a: linalg_ops.qr(a, full_matrices=full_matrices_)[1] + ] + for f in funcs: + theoretical, numerical = gradient_checker_v2.compute_gradient( + f, [RandomInput()], delta=delta) + self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol) return Test @@ -266,7 +267,7 @@ if __name__ == "__main__": for full_matrices in False, True: for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): # TF2 does not support placeholders under eager so we skip it - for use_static_shape in set([True, tf2.enabled()]): + for use_static_shape in [True, False]: shape = batch_dims + (rows, cols) name = "%s_%s_full_%s_static_%s" % (dtype.__name__, "_".join(map(str, shape)), diff --git a/tensorflow/python/kernel_tests/random/BUILD b/tensorflow/python/kernel_tests/random/BUILD index c3335cbc546..b5d291d2973 100644 --- a/tensorflow/python/kernel_tests/random/BUILD +++ b/tensorflow/python/kernel_tests/random/BUILD @@ -87,6 +87,7 @@ cuda_py_test( name = "random_ops_test", size = "medium", srcs = ["random_ops_test.py"], + tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -101,6 +102,7 @@ cuda_py_test( size = "medium", srcs = ["stateless_random_ops_test.py"], shard_count = 2, + tfrt_enabled = True, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/kernel_tests/random/random_ops_test.py b/tensorflow/python/kernel_tests/random/random_ops_test.py index 4dbbb7c7f1e..73c8bd09db0 100644 --- a/tensorflow/python/kernel_tests/random/random_ops_test.py +++ b/tensorflow/python/kernel_tests/random/random_ops_test.py @@ -336,6 +336,8 @@ class RandomUniformTest(RandomOpTestCommon): self.assertLess(error.max(), 5 * std) # Check that minval = maxval is fine iff we're producing no numbers + @test_util.disable_tfrt( + "TFE_TensorHandleToNumpy not implemented yet. b/156191611") def testUniformIntsDegenerate(self): for dt in dtypes.int32, dtypes.int64: def sample(n): diff --git a/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py b/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py index 0b9fbab716c..d7e50083deb 100644 --- a/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py +++ b/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py @@ -154,44 +154,54 @@ class StatelessOpsTest(test.TestCase, parameterized.TestCase): **kwds), functools.partial(random_ops.random_poisson, shape=(10,), **kwds)) + @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396') @test_util.run_deprecated_v1 def testMatchFloat(self): self._test_match(self._float_cases()) + @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396') @test_util.run_deprecated_v1 def testMatchInt(self): self._test_match(self._int_cases()) + @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396') @test_util.run_deprecated_v1 def testMatchMultinomial(self): self._test_match(self._multinomial_cases()) + @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396') @test_util.run_deprecated_v1 def testMatchGamma(self): self._test_match(self._gamma_cases()) + @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396') @test_util.run_deprecated_v1 def testMatchPoisson(self): self._test_match(self._poisson_cases()) + @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396') @test_util.run_deprecated_v1 def testDeterminismFloat(self): self._test_determinism( self._float_cases(shape_dtypes=(dtypes.int32, dtypes.int64))) + @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396') @test_util.run_deprecated_v1 def testDeterminismInt(self): self._test_determinism( self._int_cases(shape_dtypes=(dtypes.int32, dtypes.int64))) + @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396') @test_util.run_deprecated_v1 def testDeterminismMultinomial(self): self._test_determinism(self._multinomial_cases()) + @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396') @test_util.run_deprecated_v1 def testDeterminismGamma(self): self._test_determinism(self._gamma_cases()) + @test_util.disable_tfrt('tensorflow::DirectSession::Run crashes. b/156187396') @test_util.run_deprecated_v1 def testDeterminismPoisson(self): self._test_determinism(self._poisson_cases()) diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 41ce9eb8a57..bf229943fd4 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -57,6 +57,8 @@ from tensorflow.python.training import training_util from tensorflow.python.util import compat +@test_util.disable_tfrt( + "Trying to assign variable with wrong dtype. b/156200342") @test_util.with_control_flow_v2 class ResourceVariableOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): @@ -332,6 +334,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, g = gradients_impl.gradients(c, [b], unconnected_gradients="zero")[0] self.assertAllEqual(g.shape.as_list(), [1, 2]) + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") @test_util.run_deprecated_v1 def testGradientCondInWhileLoop(self): v = resource_variable_ops.ResourceVariable(initial_value=1.0) @@ -965,6 +968,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, assign = var.assign(np.zeros(shape=[2, 2])) self.evaluate(assign) + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") @test_util.disable_xla("XLA doesn't allow changing shape at assignment, as " "dictated by tf2xla/xla_resource.cc:SetTypeAndShape") @test_util.run_in_graph_and_eager_modes @@ -1327,6 +1331,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, # TODO(ebrevdo): Add run_in_graph_and_eager_modes once we can create # EagerTensor constants with TensorProto inputs. + @test_util.disable_tfrt("Graph is not supported yet. b/156187905") @test_util.run_in_graph_and_eager_modes() def testVariantInitializer(self): variant_shape_and_type_data = self.create_variant_shape_and_type_data() @@ -1520,6 +1525,7 @@ class PerReplicaResourceHandleTest(test_util.TensorFlowTestCase): context.LogicalDeviceConfiguration(), ]) + @test_util.disable_tfrt("Multiple device support. b/154956430") def testAllowedDevices(self): device0 = "/job:localhost/replica:0/task:0/device:CPU:0" device1 = "/job:localhost/replica:0/task:0/device:CPU:1" diff --git a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py index 05307c9834a..267decff38b 100644 --- a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py +++ b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py @@ -19,10 +19,11 @@ from __future__ import division from __future__ import print_function import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker @@ -135,56 +136,52 @@ class ReverseSequenceTest(test.TestCase): print("ReverseSequence gradient error = %g" % err) self.assertLess(err, 1e-8) - @test_util.run_deprecated_v1 def testShapeFunctionEdgeCases(self): - t = array_ops.reverse_sequence( - array_ops.placeholder( - dtypes.float32, shape=None), - seq_lengths=array_ops.placeholder( - dtypes.int64, shape=(32,)), - batch_axis=0, - seq_axis=1) - self.assertIs(t.get_shape().ndims, None) + # Enter graph mode since we want to test partial shapes + with context.graph_mode(): + t = array_ops.reverse_sequence( + array_ops.placeholder(dtypes.float32, shape=None), + seq_lengths=array_ops.placeholder(dtypes.int64, shape=(32,)), + batch_axis=0, + seq_axis=1) + self.assertIs(t.get_shape().ndims, None) + def testInvalidArguments(self): # Batch size mismatched between input and seq_lengths. - with self.assertRaises(ValueError): - array_ops.reverse_sequence( - array_ops.placeholder( - dtypes.float32, shape=(32, 2, 3)), - seq_lengths=array_ops.placeholder( - dtypes.int64, shape=(33,)), - seq_axis=3) + # seq_length too long + with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError), + (r"Dimensions must be equal|" + r"Length of seq_lengths != input.dims\(0\)")): + array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2, 2], seq_axis=1) + + # seq_length too short + with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError), + (r"Dimensions must be equal|" + r"Length of seq_lengths != input.dims\(0\)")): + array_ops.reverse_sequence([[1, 2], [3, 4]], [2], seq_axis=1) + + # Invalid seq_length shape + with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError), + ("Shape must be rank 1 but is rank 2|" + "seq_lengths must be 1-dim")): + array_ops.reverse_sequence([[1, 2], [3, 4]], [[2, 2]], seq_axis=1) # seq_axis out of bounds. - with self.assertRaisesRegexp(ValueError, "seq_dim must be < input rank"): - array_ops.reverse_sequence( - array_ops.placeholder( - dtypes.float32, shape=(32, 2, 3)), - seq_lengths=array_ops.placeholder( - dtypes.int64, shape=(32,)), - seq_axis=3) + with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError), + "seq_dim must be < input rank"): + array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2], seq_axis=2) # batch_axis out of bounds. - with self.assertRaisesRegexp(ValueError, "batch_dim must be < input rank"): - array_ops.reverse_sequence( - array_ops.placeholder( - dtypes.float32, shape=(32, 2, 3)), - seq_lengths=array_ops.placeholder( - dtypes.int64, shape=(32,)), - seq_axis=0, - batch_axis=3) + with self.assertRaisesRegexp((ValueError, errors.InvalidArgumentError), + "batch_dim must be < input rank"): + array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2], + seq_axis=1, + batch_axis=3) - with self.cached_session(): - inputs = array_ops.placeholder(dtypes.float32, shape=(32, 2, 3)) - seq_lengths = array_ops.placeholder(dtypes.int64, shape=(32,)) - output = array_ops.reverse_sequence( - inputs, seq_lengths=seq_lengths, - seq_axis=0) # batch_axis default is 0 - with self.assertRaisesOpError("batch_dim == seq_dim"): - output.eval(feed_dict={ - inputs: np.random.rand(32, 2, 3), - seq_lengths: xrange(32) - }) + with self.assertRaisesRegexp((errors.OpError, errors.InvalidArgumentError), + "batch_dim == seq_dim == 0"): + output = array_ops.reverse_sequence([[1, 2], [3, 4]], [2, 2], seq_axis=0) + self.evaluate(output) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py index 7dde89c9818..6c2f2e236f2 100644 --- a/tensorflow/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/python/kernel_tests/shape_ops_test.py @@ -500,6 +500,8 @@ class TileTest(test.TestCase, parameterized.TestCase): "int16": (dtypes.int16, int), "int32": (dtypes.int32, int), "int64": (dtypes.int64, int), + "uint32": (dtypes.uint32, int), + "uint64": (dtypes.uint64, int), bytes: (dtypes.string, bytes) } for dtype_np, (dtype_tf, cast) in types_to_test.items(): diff --git a/tensorflow/python/kernel_tests/signal/BUILD b/tensorflow/python/kernel_tests/signal/BUILD index adb12a5e850..bd893184570 100644 --- a/tensorflow/python/kernel_tests/signal/BUILD +++ b/tensorflow/python/kernel_tests/signal/BUILD @@ -149,7 +149,6 @@ cuda_py_tests( python_version = "PY3", shard_count = 4, tags = [ - "no_oss_py38", #TODO(b/151631881) "no_windows_gpu", ], deps = [ diff --git a/tensorflow/python/kernel_tests/signal/test_util.py b/tensorflow/python/kernel_tests/signal/test_util.py index 1e95fe4b28f..e8d477a843b 100644 --- a/tensorflow/python/kernel_tests/signal/test_util.py +++ b/tensorflow/python/kernel_tests/signal/test_util.py @@ -50,7 +50,7 @@ def grappler_optimize(graph, fetches=None, config_proto=None): return tf_optimizer.OptimizeGraph(config_proto, metagraph) -def tflite_convert(fn, input_templates, use_mlir=False): +def tflite_convert(fn, input_templates): """Converts the provided fn to tf.lite model. Args: @@ -59,7 +59,6 @@ def tflite_convert(fn, input_templates, use_mlir=False): input_templates: A list of Tensors, ndarrays or TensorSpecs describing the inputs that fn expects. The actual values of the Tensors or ndarrays are unused. - use_mlir: Experimental. Whether to use the tf.lite MLIR converter. Returns: The serialized tf.lite model. @@ -67,7 +66,6 @@ def tflite_convert(fn, input_templates, use_mlir=False): fn = def_function.function(fn) concrete_func = fn.get_concrete_function(*input_templates) converter = lite.TFLiteConverterV2([concrete_func]) - converter.experimental_new_converter = use_mlir return converter.convert() diff --git a/tensorflow/python/kernel_tests/signal/window_ops_test.py b/tensorflow/python/kernel_tests/signal/window_ops_test.py index 9f5fe6f64c7..9432e70c7f2 100644 --- a/tensorflow/python/kernel_tests/signal/window_ops_test.py +++ b/tensorflow/python/kernel_tests/signal/window_ops_test.py @@ -156,15 +156,14 @@ class WindowOpsTest(test.TestCase, parameterized.TestCase): self.assertLen(rewritten_graph.node, 1) @parameterized.parameters( - # Due to control flow, only MLIR is supported. # Only float32 is supported. - (window_ops.hann_window, 10, False, dtypes.float32, True), - (window_ops.hann_window, 10, True, dtypes.float32, True), - (window_ops.hamming_window, 10, False, dtypes.float32, True), - (window_ops.hamming_window, 10, True, dtypes.float32, True), - (window_ops.vorbis_window, 12, None, dtypes.float32, True)) - def test_tflite_convert(self, window_fn, window_length, periodic, dtype, - use_mlir): + (window_ops.hann_window, 10, False, dtypes.float32), + (window_ops.hann_window, 10, True, dtypes.float32), + (window_ops.hamming_window, 10, False, dtypes.float32), + (window_ops.hamming_window, 10, True, dtypes.float32), + (window_ops.vorbis_window, 12, None, dtypes.float32)) + def test_tflite_convert(self, window_fn, window_length, periodic, dtype): + def fn(window_length): try: return window_fn(window_length, periodic=periodic, dtype=dtype) @@ -172,7 +171,7 @@ class WindowOpsTest(test.TestCase, parameterized.TestCase): return window_fn(window_length, dtype=dtype) tflite_model = test_util.tflite_convert( - fn, [tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32)], use_mlir) + fn, [tensor_spec.TensorSpec(shape=[], dtype=dtypes.int32)]) window_length = np.array(window_length).astype(np.int32) actual_output, = test_util.evaluate_tflite_model( tflite_model, [window_length]) diff --git a/tensorflow/python/kernel_tests/sparse_cross_op_test.py b/tensorflow/python/kernel_tests/sparse_cross_op_test.py index 5037f82af72..b352c1a080f 100644 --- a/tensorflow/python/kernel_tests/sparse_cross_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_cross_op_test.py @@ -27,10 +27,55 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_sparse_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test +class BaseSparseCrossOpTest(test.TestCase): + + def _sparse_tensor(self, data, batch_size=-1): + """Generates a SparseTensor. + + Args: + data: Should be a list of list of strings or int64. Each item of the outer + list represents a batch. Each item of the batch is a feature of a + specific feature column. + batch_size: optional batch size, especially for cases when data has no + entry for some batches. + + Returns: + A SparseTensor. + """ + indices = [] + values = [] + max_col_count = 0 + for batch, batch_ix in zip(data, range(len(data))): + for column, column_ix in zip(batch, range(len(batch))): + indices.append([batch_ix, column_ix]) + values.append(column) + max_col_count = max(max_col_count, column_ix + 1) + shape = [batch_size if batch_size != -1 else len(data), max_col_count] + value_type = ( + dtypes.string + if not values or isinstance(values[0], str) else dtypes.int64) + return sparse_tensor.SparseTensor( + constant_op.constant(indices, dtypes.int64, [len(indices), 2]), + constant_op.constant(values, value_type, [len(indices)]), + constant_op.constant(shape, dtypes.int64)) + + def _assert_sparse_tensor_equals(self, sp1, sp2): + self.assertAllEqual(sp1.indices.eval(), sp2.indices) + self.assertAllEqual(sp1.values.eval(), sp2.values) + self.assertAllEqual(sp1.dense_shape.eval(), sp2.dense_shape) + + def _assert_sparse_tensor_empty(self, sp): + self.assertEqual(0, sp.indices.size) + self.assertEqual(0, sp.values.size) + # TODO(zakaria): check if we can ignore the first dim of the shape. + self.assertEqual(0, sp.dense_shape[1]) + + class SparseCrossOpTest(test.TestCase): @test_util.run_deprecated_v1 @@ -459,5 +504,552 @@ class SparseCrossOpTest(test.TestCase): self.evaluate(sparse_ops.sparse_cross([st1, st2])) +class SparseCrossV2OpTest(BaseSparseCrossOpTest): + + @test_util.run_deprecated_v1 + def test_sparse(self): + """Tests a simple scenario.""" + sp_inp_1 = self._sparse_tensor([['batch1-FC1-F1'], + ['batch2-FC1-F1', 'batch2-FC1-F2']]) + sp_inp_2 = self._sparse_tensor([['batch1-FC2-F1'], + ['batch2-FC2-F1', 'batch2-FC2-F2']]) + inds, vals, shapes = gen_sparse_ops.sparse_cross_v2( + indices=[sp_inp_1.indices, sp_inp_2.indices], + values=[sp_inp_1.values, sp_inp_2.values], + shapes=[sp_inp_1.dense_shape, sp_inp_2.dense_shape], + dense_inputs=[], + sep='_X_') + out = sparse_tensor.SparseTensor(inds, vals, shapes) + # pyformat: disable + expected_out = self._sparse_tensor([ + ['batch1-FC1-F1_X_batch1-FC2-F1'], + ['batch2-FC1-F1_X_batch2-FC2-F1', + 'batch2-FC1-F1_X_batch2-FC2-F2', + 'batch2-FC1-F2_X_batch2-FC2-F1', + 'batch2-FC1-F2_X_batch2-FC2-F2' + ]]) + # pyformat: enable + with self.cached_session(): + self._assert_sparse_tensor_equals(expected_out, self.evaluate(out)) + + @test_util.run_deprecated_v1 + def test_sparse_sep(self): + """Tests a simple scenario.""" + sp_inp_1 = self._sparse_tensor([['batch1-FC1-F1'], + ['batch2-FC1-F1', 'batch2-FC1-F2']]) + sp_inp_2 = self._sparse_tensor([['batch1-FC2-F1'], + ['batch2-FC2-F1', 'batch2-FC2-F2']]) + inds, vals, shapes = gen_sparse_ops.sparse_cross_v2( + indices=[sp_inp_1.indices, sp_inp_2.indices], + values=[sp_inp_1.values, sp_inp_2.values], + shapes=[sp_inp_1.dense_shape, sp_inp_2.dense_shape], + dense_inputs=[], + sep='_Y_') + out = sparse_tensor.SparseTensor(inds, vals, shapes) + # pyformat: disable + expected_out = self._sparse_tensor([ + ['batch1-FC1-F1_Y_batch1-FC2-F1'], + ['batch2-FC1-F1_Y_batch2-FC2-F1', + 'batch2-FC1-F1_Y_batch2-FC2-F2', + 'batch2-FC1-F2_Y_batch2-FC2-F1', + 'batch2-FC1-F2_Y_batch2-FC2-F2' + ]]) + # pyformat: enable + with self.cached_session(): + self._assert_sparse_tensor_equals(expected_out, self.evaluate(out)) + + @test_util.run_deprecated_v1 + def test_dense(self): + """Tests only dense inputs.""" + dense_inp_1 = constant_op.constant([['batch1-FC1-F1', 'batch1-FC1-F2'], + ['batch2-FC1-F1', 'batch2-FC1-F2']], + dtypes.string) + dense_inp_2 = constant_op.constant([['batch1-FC2-F1', 'batch1-FC2-F2'], + ['batch2-FC2-F1', 'batch2-FC2-F2']], + dtypes.string) + inds, vals, shapes = gen_sparse_ops.sparse_cross_v2( + indices=[], + values=[], + shapes=[], + dense_inputs=[dense_inp_1, dense_inp_2], + sep='_X_') + out = sparse_tensor.SparseTensor(inds, vals, shapes) + # pyformat: disable + expected_out = self._sparse_tensor([ + ['batch1-FC1-F1_X_batch1-FC2-F1', 'batch1-FC1-F1_X_batch1-FC2-F2', + 'batch1-FC1-F2_X_batch1-FC2-F1', 'batch1-FC1-F2_X_batch1-FC2-F2' + ], + ['batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2', + 'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2' + ]]) + # pyformat: enable + with self.cached_session(): + self._assert_sparse_tensor_equals(expected_out, self.evaluate(out)) + + @test_util.run_deprecated_v1 + def test_dense_sep(self): + """Tests only dense inputs.""" + dense_inp_1 = constant_op.constant([['batch1-FC1-F1', 'batch1-FC1-F2'], + ['batch2-FC1-F1', 'batch2-FC1-F2']], + dtypes.string) + dense_inp_2 = constant_op.constant([['batch1-FC2-F1', 'batch1-FC2-F2'], + ['batch2-FC2-F1', 'batch2-FC2-F2']], + dtypes.string) + inds, vals, shapes = gen_sparse_ops.sparse_cross_v2( + indices=[], + values=[], + shapes=[], + dense_inputs=[dense_inp_1, dense_inp_2], + sep='_') + out = sparse_tensor.SparseTensor(inds, vals, shapes) + # pyformat: disable + expected_out = self._sparse_tensor([ + ['batch1-FC1-F1_batch1-FC2-F1', 'batch1-FC1-F1_batch1-FC2-F2', + 'batch1-FC1-F2_batch1-FC2-F1', 'batch1-FC1-F2_batch1-FC2-F2' + ], + ['batch2-FC1-F1_batch2-FC2-F1', 'batch2-FC1-F1_batch2-FC2-F2', + 'batch2-FC1-F2_batch2-FC2-F1', 'batch2-FC1-F2_batch2-FC2-F2' + ]]) + # pyformat: enable + with self.cached_session(): + self._assert_sparse_tensor_equals(expected_out, self.evaluate(out)) + + @test_util.run_deprecated_v1 + def test_integer_mixed_string_sparse(self): + """Tests mixed type.""" + sp_inp_1 = self._sparse_tensor([[11], [333, 55555]]) + sp_inp_2 = self._sparse_tensor([['batch1-FC2-F1'], + ['batch2-FC2-F1', 'batch2-FC2-F2']]) + inds, vals, shapes = gen_sparse_ops.sparse_cross_v2( + indices=[sp_inp_1.indices, sp_inp_2.indices], + values=[sp_inp_1.values, sp_inp_2.values], + shapes=[sp_inp_1.dense_shape, sp_inp_2.dense_shape], + dense_inputs=[], + sep='_X_') + out = sparse_tensor.SparseTensor(inds, vals, shapes) + # pyformat: disable + expected_out = self._sparse_tensor([ + ['11_X_batch1-FC2-F1'], + ['333_X_batch2-FC2-F1', '333_X_batch2-FC2-F2', + '55555_X_batch2-FC2-F1', '55555_X_batch2-FC2-F2' + ]]) + # pyformat: enable + with self.cached_session(): + self._assert_sparse_tensor_equals(expected_out, self.evaluate(out)) + + @test_util.run_deprecated_v1 + def test_integer_mixed_string_dense(self): + """Tests mixed dense inputs.""" + dense_inp_1 = constant_op.constant([[11, 333], [55555, 999999]], + dtypes.int64) + dense_inp_2 = constant_op.constant([['batch1-FC2-F1', 'batch1-FC2-F2'], + ['batch2-FC2-F1', 'batch2-FC2-F2']], + dtypes.string) + inds, vals, shapes = gen_sparse_ops.sparse_cross_v2( + indices=[], + values=[], + shapes=[], + dense_inputs=[dense_inp_1, dense_inp_2], + sep='_X_') + out = sparse_tensor.SparseTensor(inds, vals, shapes) + # pyformat: disable + expected_out = self._sparse_tensor([ + ['11_X_batch1-FC2-F1', '11_X_batch1-FC2-F2', + '333_X_batch1-FC2-F1', '333_X_batch1-FC2-F2' + ], + ['55555_X_batch2-FC2-F1', '55555_X_batch2-FC2-F2', + '999999_X_batch2-FC2-F1', '999999_X_batch2-FC2-F2' + ]]) + # pyformat: enable + with self.cached_session(): + self._assert_sparse_tensor_equals(expected_out, self.evaluate(out)) + + @test_util.run_deprecated_v1 + def test_sparse_cross_dense(self): + """Tests sparse and dense inputs.""" + sp_inp = self._sparse_tensor([['batch1-FC1-F1'], + ['batch2-FC1-F1', 'batch2-FC1-F2']]) + dense_inp = constant_op.constant([['batch1-FC2-F1', 'batch1-FC2-F2'], + ['batch2-FC2-F1', 'batch2-FC2-F2']], + dtypes.string) + inds, vals, shapes = gen_sparse_ops.sparse_cross_v2( + indices=[sp_inp.indices], + values=[sp_inp.values], + shapes=[sp_inp.dense_shape], + dense_inputs=[dense_inp], + sep='_X_') + expected_out = self._sparse_tensor( + [['batch1-FC1-F1_X_batch1-FC2-F1', 'batch1-FC1-F1_X_batch1-FC2-F2'], + [ + 'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2', + 'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2' + ]]) + out = sparse_tensor.SparseTensor(inds, vals, shapes) + with self.cached_session(): + self._assert_sparse_tensor_equals(expected_out, self.evaluate(out)) + + @test_util.run_deprecated_v1 + def test_permutation_3x3x3(self): + """Tests 3x3x3 permutation.""" + sp_inp_1 = self._sparse_tensor( + [['batch1-FC1-F1', 'batch1-FC1-F2', 'batch1-FC1-F3']]) + sp_inp_2 = self._sparse_tensor( + [['batch1-FC2-F1', 'batch1-FC2-F2', 'batch1-FC2-F3']]) + sp_inp_3 = self._sparse_tensor( + [['batch1-FC3-F1', 'batch1-FC3-F2', 'batch1-FC3-F3']]) + inds, vals, shapes = gen_sparse_ops.sparse_cross_v2( + indices=[sp_inp_1.indices, sp_inp_2.indices, sp_inp_3.indices], + values=[sp_inp_1.values, sp_inp_2.values, sp_inp_3.values], + shapes=[ + sp_inp_1.dense_shape, sp_inp_2.dense_shape, sp_inp_3.dense_shape + ], + dense_inputs=[], + sep='_X_') + expected_out = self._sparse_tensor([[ + 'batch1-FC1-F1_X_batch1-FC2-F1_X_batch1-FC3-F1', + 'batch1-FC1-F1_X_batch1-FC2-F1_X_batch1-FC3-F2', + 'batch1-FC1-F1_X_batch1-FC2-F1_X_batch1-FC3-F3', + 'batch1-FC1-F1_X_batch1-FC2-F2_X_batch1-FC3-F1', + 'batch1-FC1-F1_X_batch1-FC2-F2_X_batch1-FC3-F2', + 'batch1-FC1-F1_X_batch1-FC2-F2_X_batch1-FC3-F3', + 'batch1-FC1-F1_X_batch1-FC2-F3_X_batch1-FC3-F1', + 'batch1-FC1-F1_X_batch1-FC2-F3_X_batch1-FC3-F2', + 'batch1-FC1-F1_X_batch1-FC2-F3_X_batch1-FC3-F3', + 'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F1', + 'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F2', + 'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F3', + 'batch1-FC1-F2_X_batch1-FC2-F2_X_batch1-FC3-F1', + 'batch1-FC1-F2_X_batch1-FC2-F2_X_batch1-FC3-F2', + 'batch1-FC1-F2_X_batch1-FC2-F2_X_batch1-FC3-F3', + 'batch1-FC1-F2_X_batch1-FC2-F3_X_batch1-FC3-F1', + 'batch1-FC1-F2_X_batch1-FC2-F3_X_batch1-FC3-F2', + 'batch1-FC1-F2_X_batch1-FC2-F3_X_batch1-FC3-F3', + 'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F1', + 'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F2', + 'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F3', + 'batch1-FC1-F3_X_batch1-FC2-F2_X_batch1-FC3-F1', + 'batch1-FC1-F3_X_batch1-FC2-F2_X_batch1-FC3-F2', + 'batch1-FC1-F3_X_batch1-FC2-F2_X_batch1-FC3-F3', + 'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F1', + 'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F2', + 'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F3' + ]]) + out = sparse_tensor.SparseTensor(inds, vals, shapes) + with self.cached_session(): + self._assert_sparse_tensor_equals(expected_out, self.evaluate(out)) + + @test_util.run_deprecated_v1 + def test_permutation_3x1x2(self): + """Tests 3x1x2 permutation.""" + sp_inp_1 = self._sparse_tensor( + [['batch1-FC1-F1', 'batch1-FC1-F2', 'batch1-FC1-F3']]) + sp_inp_2 = self._sparse_tensor([['batch1-FC2-F1']]) + sp_inp_3 = self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']]) + inds, vals, shapes = gen_sparse_ops.sparse_cross_v2( + indices=[sp_inp_1.indices, sp_inp_2.indices, sp_inp_3.indices], + values=[sp_inp_1.values, sp_inp_2.values, sp_inp_3.values], + shapes=[ + sp_inp_1.dense_shape, sp_inp_2.dense_shape, sp_inp_3.dense_shape + ], + dense_inputs=[], + sep='_X_') + expected_out = self._sparse_tensor([[ + 'batch1-FC1-F1_X_batch1-FC2-F1_X_batch1-FC3-F1', + 'batch1-FC1-F1_X_batch1-FC2-F1_X_batch1-FC3-F2', + 'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F1', + 'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F2', + 'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F1', + 'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F2' + ]]) + out = sparse_tensor.SparseTensor(inds, vals, shapes) + with self.cached_session(): + self._assert_sparse_tensor_equals(expected_out, self.evaluate(out)) + + @test_util.run_deprecated_v1 + def test_large_batch(self): + """Tests with large batch size to force multithreading.""" + batch_size = 5000 + col1 = [] + col2 = [] + col3 = [] + for b in range(batch_size): + col1.append( + ['batch%d-FC1-F1' % b, + 'batch%d-FC1-F2' % b, + 'batch%d-FC1-F3' % b]) + col2.append(['batch%d-FC2-F1' % b]) + col3.append(['batch%d-FC3-F1' % b, 'batch%d-FC3-F2' % b]) + sp_inp_1 = self._sparse_tensor(col1) + sp_inp_2 = self._sparse_tensor(col2) + sp_inp_3 = self._sparse_tensor(col3) + + inds, vals, shapes = gen_sparse_ops.sparse_cross_v2( + indices=[sp_inp_1.indices, sp_inp_2.indices, sp_inp_3.indices], + values=[sp_inp_1.values, sp_inp_2.values, sp_inp_3.values], + shapes=[ + sp_inp_1.dense_shape, sp_inp_2.dense_shape, sp_inp_3.dense_shape + ], + dense_inputs=[], + sep='_X_') + + col_out = [] + for b in range(batch_size): + col_out.append([ + 'batch%d-FC1-F1_X_batch%d-FC2-F1_X_batch%d-FC3-F1' % (b, b, b), + 'batch%d-FC1-F1_X_batch%d-FC2-F1_X_batch%d-FC3-F2' % (b, b, b), + 'batch%d-FC1-F2_X_batch%d-FC2-F1_X_batch%d-FC3-F1' % (b, b, b), + 'batch%d-FC1-F2_X_batch%d-FC2-F1_X_batch%d-FC3-F2' % (b, b, b), + 'batch%d-FC1-F3_X_batch%d-FC2-F1_X_batch%d-FC3-F1' % (b, b, b), + 'batch%d-FC1-F3_X_batch%d-FC2-F1_X_batch%d-FC3-F2' % (b, b, b) + ]) + + expected_out = self._sparse_tensor(col_out) + out = sparse_tensor.SparseTensor(inds, vals, shapes) + with self.cached_session(): + self._assert_sparse_tensor_equals(expected_out, self.evaluate(out)) + + @test_util.run_deprecated_v1 + def test_one_column_empty(self): + """Tests when one column is empty. + + The crossed tensor should be empty. + """ + sp_inp_1 = self._sparse_tensor([['batch1-FC1-F1', 'batch1-FC1-F2']]) + sp_inp_2 = self._sparse_tensor([], 1) + sp_inp_3 = self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']]) + inds, vals, shapes = gen_sparse_ops.sparse_cross_v2( + indices=[sp_inp_1.indices, sp_inp_2.indices, sp_inp_3.indices], + values=[sp_inp_1.values, sp_inp_2.values, sp_inp_3.values], + shapes=[ + sp_inp_1.dense_shape, sp_inp_2.dense_shape, sp_inp_3.dense_shape + ], + dense_inputs=[], + sep='_X_') + out = sparse_tensor.SparseTensor(inds, vals, shapes) + with self.cached_session(): + self._assert_sparse_tensor_empty(self.evaluate(out)) + + @test_util.run_deprecated_v1 + def test_some_columns_empty(self): + """Tests when more than one columns are empty. + + Cross for the corresponding batch should be empty. + """ + sp_inp_1 = self._sparse_tensor([['batch1-FC1-F1', 'batch1-FC1-F2']], 2) + sp_inp_2 = self._sparse_tensor([['batch1-FC2-F1'], ['batch2-FC2-F1']], 2) + sp_inp_3 = self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']], 2) + inds, vals, shapes = gen_sparse_ops.sparse_cross_v2( + indices=[sp_inp_1.indices, sp_inp_2.indices, sp_inp_3.indices], + values=[sp_inp_1.values, sp_inp_2.values, sp_inp_3.values], + shapes=[ + sp_inp_1.dense_shape, sp_inp_2.dense_shape, sp_inp_3.dense_shape + ], + dense_inputs=[], + sep='_X_') + expected_out = self._sparse_tensor([[ + 'batch1-FC1-F1_X_batch1-FC2-F1_X_batch1-FC3-F1', + 'batch1-FC1-F1_X_batch1-FC2-F1_X_batch1-FC3-F2', + 'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F1', + 'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F2' + ]], 2) + out = sparse_tensor.SparseTensor(inds, vals, shapes) + with self.cached_session(): + self._assert_sparse_tensor_equals(expected_out, self.evaluate(out)) + + @test_util.run_deprecated_v1 + def test_all_columns_empty(self): + """Tests when all columns are empty. + + The crossed tensor should be empty. + """ + sp_inp_1 = self._sparse_tensor([]) + sp_inp_2 = self._sparse_tensor([]) + sp_inp_3 = self._sparse_tensor([]) + inds, vals, shapes = gen_sparse_ops.sparse_cross_v2( + indices=[sp_inp_1.indices, sp_inp_2.indices, sp_inp_3.indices], + values=[sp_inp_1.values, sp_inp_2.values, sp_inp_3.values], + shapes=[ + sp_inp_1.dense_shape, sp_inp_2.dense_shape, sp_inp_3.dense_shape + ], + dense_inputs=[], + sep='_X_') + out = sparse_tensor.SparseTensor(inds, vals, shapes) + with self.cached_session(): + self._assert_sparse_tensor_empty(self.evaluate(out)) + + +class SparseCrossHashedOpTest(BaseSparseCrossOpTest): + + @test_util.run_deprecated_v1 + def test_hashed_zero_bucket_no_hash_key(self): + sp_inp_1 = self._sparse_tensor([['batch1-FC1-F1']]) + sp_inp_2 = self._sparse_tensor([['batch1-FC2-F1']]) + sp_inp_3 = self._sparse_tensor([['batch1-FC3-F1']]) + inds, vals, shapes = gen_sparse_ops.sparse_cross_hashed( + indices=[sp_inp_1.indices, sp_inp_2.indices, sp_inp_3.indices], + values=[sp_inp_1.values, sp_inp_2.values, sp_inp_3.values], + shapes=[ + sp_inp_1.dense_shape, sp_inp_2.dense_shape, sp_inp_3.dense_shape + ], + dense_inputs=[], + num_buckets=0, + salt=[1, 1], + strong_hash=False) + # Check actual hashed output to prevent unintentional hashing changes. + expected_out = self._sparse_tensor([[9186962005966787372]]) + out = sparse_tensor.SparseTensor(inds, vals, shapes) + with self.cached_session(): + self._assert_sparse_tensor_equals(expected_out, self.evaluate(out)) + + # salt is not being used when `strong_hash` is False. + inds_2, vals_2, shapes_2 = gen_sparse_ops.sparse_cross_hashed( + indices=[sp_inp_1.indices, sp_inp_2.indices, sp_inp_3.indices], + values=[sp_inp_1.values, sp_inp_2.values, sp_inp_3.values], + shapes=[ + sp_inp_1.dense_shape, sp_inp_2.dense_shape, sp_inp_3.dense_shape + ], + dense_inputs=[], + num_buckets=0, + salt=[137, 173], + strong_hash=False) + out_2 = sparse_tensor.SparseTensor(inds_2, vals_2, shapes_2) + with self.cached_session(): + self._assert_sparse_tensor_equals(expected_out, self.evaluate(out_2)) + + @test_util.run_deprecated_v1 + def test_hashed_output(self): + sp_inp_1 = self._sparse_tensor([['batch1-FC1-F1']]) + sp_inp_2 = self._sparse_tensor([['batch1-FC2-F1']]) + sp_inp_3 = self._sparse_tensor([['batch1-FC3-F1']]) + inds, vals, shapes = gen_sparse_ops.sparse_cross_hashed( + indices=[sp_inp_1.indices, sp_inp_2.indices, sp_inp_3.indices], + values=[sp_inp_1.values, sp_inp_2.values, sp_inp_3.values], + shapes=[ + sp_inp_1.dense_shape, sp_inp_2.dense_shape, sp_inp_3.dense_shape + ], + dense_inputs=[], + num_buckets=100, + salt=[137, 173], + strong_hash=False) + # Check actual hashed output to prevent unintentional hashing changes. + expected_out = self._sparse_tensor([[79]]) + out = sparse_tensor.SparseTensor(inds, vals, shapes) + with self.cached_session(): + self._assert_sparse_tensor_equals(expected_out, self.evaluate(out)) + + @test_util.run_deprecated_v1 + def test_hashed_has_no_collision(self): + """Tests that fingerprint concatenation has no collisions.""" + # Although the last 10 bits of 359 and 1024+359 are identical. + # As a result, all the crosses shouldn't collide. + t1 = constant_op.constant([[359], [359 + 1024]], dtype=dtypes.int64) + t2 = constant_op.constant( + [list(range(10)), list(range(10))], dtype=dtypes.int64) + inds, vals, shapes = gen_sparse_ops.sparse_cross_hashed( + indices=[], + values=[], + shapes=[], + dense_inputs=[t2, t1], + num_buckets=1024, + salt=[137, 173], + strong_hash=False) + cross = sparse_tensor.SparseTensor(inds, vals, shapes) + cross_dense = sparse_ops.sparse_tensor_to_dense(cross) + with session.Session(): + values = self.evaluate(cross_dense) + self.assertTrue(numpy.not_equal(values[0], values[1]).all()) + + def test_hashed_3x1x2(self): + """Tests 3x1x2 permutation with hashed output.""" + sp_inp_1 = self._sparse_tensor( + [['batch1-FC1-F1', 'batch1-FC1-F2', 'batch1-FC1-F3']]) + sp_inp_2 = self._sparse_tensor([['batch1-FC2-F1']]) + sp_inp_3 = self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']]) + inds, vals, shapes = gen_sparse_ops.sparse_cross_hashed( + indices=[sp_inp_1.indices, sp_inp_2.indices, sp_inp_3.indices], + values=[sp_inp_1.values, sp_inp_2.values, sp_inp_3.values], + shapes=[ + sp_inp_1.dense_shape, sp_inp_2.dense_shape, sp_inp_3.dense_shape + ], + dense_inputs=[], + num_buckets=1000, + salt=[137, 173], + strong_hash=False) + output = sparse_tensor.SparseTensor(inds, vals, shapes) + with self.cached_session(): + out = self.evaluate(output) + self.assertEqual(6, len(out.values)) + self.assertAllEqual([[0, i] for i in range(6)], out.indices) + self.assertTrue(all(x < 1000 and x >= 0 for x in out.values)) + all_values_are_different = len(out.values) == len(set(out.values)) + self.assertTrue(all_values_are_different) + + def test_hashed_different_salt(self): + sp_inp_1 = self._sparse_tensor( + [['batch1-FC1-F1', 'batch1-FC1-F2', 'batch1-FC1-F3']]) + sp_inp_2 = self._sparse_tensor([['batch1-FC2-F1']]) + sp_inp_3 = self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']]) + inds, vals, shapes = gen_sparse_ops.sparse_cross_hashed( + indices=[sp_inp_1.indices, sp_inp_2.indices, sp_inp_3.indices], + values=[sp_inp_1.values, sp_inp_2.values, sp_inp_3.values], + shapes=[ + sp_inp_1.dense_shape, sp_inp_2.dense_shape, sp_inp_3.dense_shape + ], + dense_inputs=[], + strong_hash=False, + num_buckets=1000, + salt=[137, 173]) + output = sparse_tensor.SparseTensor(inds, vals, shapes) + inds_2, vals_2, shapes_2 = gen_sparse_ops.sparse_cross_hashed( + indices=[sp_inp_1.indices, sp_inp_2.indices, sp_inp_3.indices], + values=[sp_inp_1.values, sp_inp_2.values, sp_inp_3.values], + shapes=[ + sp_inp_1.dense_shape, sp_inp_2.dense_shape, sp_inp_3.dense_shape + ], + dense_inputs=[], + strong_hash=True, + num_buckets=1000, + salt=[137, 1]) + output_2 = sparse_tensor.SparseTensor(inds_2, vals_2, shapes_2) + with self.cached_session(): + out = self.evaluate(output) + out_2 = self.evaluate(output_2) + self.assertAllEqual(out.indices, out_2.indices) + self.assertNotAllEqual(out.values, out_2.values) + + def test_sep_ignored_in_hashed_out(self): + sp_inp_1 = self._sparse_tensor( + [['batch1-FC1-F1', 'batch1-FC1-F2', 'batch1-FC1-F3']]) + sp_inp_2 = self._sparse_tensor([['batch1-FC2-F1']]) + sp_inp_3 = self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']]) + inds, vals, shapes = gen_sparse_ops.sparse_cross_hashed( + indices=[sp_inp_1.indices, sp_inp_2.indices, sp_inp_3.indices], + values=[sp_inp_1.values, sp_inp_2.values, sp_inp_3.values], + shapes=[ + sp_inp_1.dense_shape, sp_inp_2.dense_shape, sp_inp_3.dense_shape + ], + dense_inputs=[], + strong_hash=True, + num_buckets=1000, + salt=[137, 173]) + output = sparse_tensor.SparseTensor(inds, vals, shapes) + inds_2, vals_2, shapes_2 = gen_sparse_ops.sparse_cross_hashed( + indices=[sp_inp_1.indices, sp_inp_2.indices, sp_inp_3.indices], + values=[sp_inp_1.values, sp_inp_2.values, sp_inp_3.values], + shapes=[ + sp_inp_1.dense_shape, sp_inp_2.dense_shape, sp_inp_3.dense_shape + ], + dense_inputs=[], + strong_hash=True, + num_buckets=1000, + salt=[137, 173]) + output_2 = sparse_tensor.SparseTensor(inds_2, vals_2, shapes_2) + with self.cached_session(): + out = self.evaluate(output) + out_2 = self.evaluate(output_2) + self.assertAllEqual(out.indices, out_2.indices) + self.assertAllEqual(out.values, out_2.values) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/svd_op_test.py b/tensorflow/python/kernel_tests/svd_op_test.py index 6c2199cc591..cad131dda74 100644 --- a/tensorflow/python/kernel_tests/svd_op_test.py +++ b/tensorflow/python/kernel_tests/svd_op_test.py @@ -23,11 +23,13 @@ import numpy as np from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import gradient_checker_v2 from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops @@ -46,16 +48,16 @@ def _AddTest(test_class, op_name, testcase_name, fn): class SvdOpTest(test.TestCase): - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testWrongDimensions(self): # The input to svd should be a tensor of at least rank 2. scalar = constant_op.constant(1.) - with self.assertRaisesRegexp(ValueError, - "Shape must be at least rank 2 but is rank 0"): + with self.assertRaisesRegexp((ValueError, errors_impl.InvalidArgumentError), + "rank.* 2.*0"): linalg_ops.svd(scalar) vector = constant_op.constant([1., 2.]) - with self.assertRaisesRegexp(ValueError, - "Shape must be at least rank 2 but is rank 1"): + with self.assertRaisesRegexp((ValueError, errors_impl.InvalidArgumentError), + "rank.* 2.*1"): linalg_ops.svd(vector) @test_util.run_in_graph_and_eager_modes(use_gpu=True) @@ -224,45 +226,41 @@ def _NormalizingSvd(tf_a, full_matrices_): def _GetSvdGradOpTest(dtype_, shape_, compute_uv_, full_matrices_): - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def Test(self): - np.random.seed(42) - a = np.random.uniform(low=-1.0, high=1.0, size=shape_).astype(dtype_) - if dtype_ in [np.complex64, np.complex128]: - a += 1j * np.random.uniform( - low=-1.0, high=1.0, size=shape_).astype(dtype_) + + def RandomInput(): + np.random.seed(42) + a = np.random.uniform(low=-1.0, high=1.0, size=shape_).astype(dtype_) + if dtype_ in [np.complex64, np.complex128]: + a += 1j * np.random.uniform( + low=-1.0, high=1.0, size=shape_).astype(dtype_) + return a + # Optimal stepsize for central difference is O(epsilon^{1/3}). # See Equation (21) in: # http://www.karenkopecky.net/Teaching/eco613614/Notes_NumericalDifferentiation.pdf # TODO(rmlarsen): Move step size control to gradient checker. epsilon = np.finfo(dtype_).eps - delta = 0.1 * epsilon**(1.0 / 3.0) + delta = 0.25 * epsilon**(1.0 / 3.0) if dtype_ in [np.float32, np.complex64]: tol = 3e-2 else: tol = 1e-6 - with self.session(use_gpu=True): - tf_a = constant_op.constant(a) - if compute_uv_: - tf_s, tf_u, tf_v = _NormalizingSvd(tf_a, full_matrices_) - outputs = [tf_s, tf_u, tf_v] - else: - tf_s = linalg_ops.svd(tf_a, compute_uv=False) - outputs = [tf_s] - for b in outputs: - x_init = np.random.uniform( - low=-1.0, high=1.0, size=shape_).astype(dtype_) - if dtype_ in [np.complex64, np.complex128]: - x_init += 1j * np.random.uniform( - low=-1.0, high=1.0, size=shape_).astype(dtype_) - theoretical, numerical = gradient_checker.compute_gradient( - tf_a, - tf_a.get_shape().as_list(), - b, - b.get_shape().as_list(), - x_init_value=x_init, - delta=delta) - self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol) + if compute_uv_: + funcs = [ + lambda a: _NormalizingSvd(a, full_matrices_)[0], + lambda a: _NormalizingSvd(a, full_matrices_)[1], + lambda a: _NormalizingSvd(a, full_matrices_)[2] + ] + else: + funcs = [lambda a: linalg_ops.svd(a, compute_uv=False)] + + for f in funcs: + theoretical, numerical = gradient_checker_v2.compute_gradient( + f, [RandomInput()], delta=delta) + self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol) + return Test diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py index 71e448f7855..7f8c5e9781b 100644 --- a/tensorflow/python/kernel_tests/tensordot_op_test.py +++ b/tensorflow/python/kernel_tests/tensordot_op_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python import tf2 +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -41,16 +41,19 @@ def _add_test(test, test_name, fn): class TensordotTest(test_lib.TestCase): - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def test_invalid_shape(self): a = [[1, 2], [3, 4]] b = [[1, 2], [3, 4], [5, 6]] a_axes = [1] b_axes = [0] # Invalid static shapes. - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): math_ops.tensordot(a, b, (a_axes, b_axes)) + # Invalid dynamic shapes. + if context.executing_eagerly(): + return with self.cached_session() as sess: with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, "Matrix size-incompatible"): @@ -65,7 +68,7 @@ class TensordotTest(test_lib.TestCase): axes_ph: (a_axes, b_axes) }) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def test_invalid_axes(self): a = [[1, 2], [3, 4]] b = [[1, 2], [3, 4]] @@ -77,6 +80,8 @@ class TensordotTest(test_lib.TestCase): with self.assertRaises(IndexError): math_ops.tensordot(a, b, [[0], [7]]) + if context.executing_eagerly(): + return # Invalid dynamic axes. a_ph = array_ops.placeholder(dtypes.float32) b_ph = array_ops.placeholder(dtypes.float32) @@ -93,22 +98,22 @@ class TensordotTest(test_lib.TestCase): axes_ph: axes_value }) - # Test case for 11950 + # Test case for https://github.com/tensorflow/tensorflow/issues/11950 + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def test_valid_axis(self): for axes_value in [1, 2], [[1], [2]], [[], []], 0: - with self.cached_session(): - np_a = np.ones((3, 3)) - np_b = np.array([2, 3, 1])[None, None] - np_ans = np.tensordot(np_a, np_b, axes_value) + np_a = np.ones((3, 3)) + np_b = np.array([2, 3, 1])[None, None] + np_ans = np.tensordot(np_a, np_b, axes_value) - tf_a = array_ops.ones((3, 3), dtype=dtypes.float32) - tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None] - tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value) + tf_a = array_ops.ones((3, 3), dtype=dtypes.float32) + tf_b = constant_op.constant([2, 3, 1], dtype=dtypes.float32)[None, None] + tf_ans = math_ops.tensordot(tf_a, tf_b, axes_value) - self.assertAllEqual(tf_ans.shape, np_ans.shape) - self.assertAllEqual(tf_ans, np_ans) + self.assertAllEqual(tf_ans.shape, np_ans.shape) + self.assertAllEqual(self.evaluate(tf_ans), np_ans) - @test_util.run_v1_only("b/120545219") + @test_util.run_v1_only("Shape inference test") def test_partial_shape_inference(self): for axes in ([1], [0]), 1: a = array_ops.placeholder(dtypes.float32) @@ -159,7 +164,10 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): size=np.prod(b_shape)).reshape(b_shape).astype(dtype_) return a, b, a_dims, b_dims + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def test_tensordot(self): + if dynamic_shape_ and context.executing_eagerly(): + self.skipTest("Placeholders not support in eager mode") num_trials = min(30, num_dims_ * num_dims_) if dtype_ == np.float16: tol = 0.05 @@ -187,7 +195,10 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): self.assertAllClose(tf_ans, np_ans, rtol=tol, atol=tol) self.assertAllEqual(tf_ans.shape, np_ans.shape) + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def test_tensordot_scalar_axes(self): + if dynamic_shape_ and context.executing_eagerly(): + self.skipTest("Placeholders not support in eager mode") if num_dims_ < 1: self.skipTest("Not a test") if dtype_ == np.float16: @@ -229,7 +240,7 @@ if __name__ == "__main__": for rank_b in 1, 2, 4, 5: for num_dims in range(0, min(rank_a, rank_b) + 1): # TF2 does not support placeholders under eager so we skip it - for dynamic_shape in set([False, not tf2.enabled()]): + for dynamic_shape in set([False, True]): for testcase in _get_tensordot_tests(dtype, rank_a, rank_b, num_dims, dynamic_shape): name = "%s_%s_%s_%s_%s_%s" % (testcase.__name__, dtype.__name__, diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 9a5e95d8aad..a641633b1f5 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -57,6 +57,7 @@ _BaseSlice = slice @tf_export("reshape", v1=["reshape", "manip.reshape"]) +@dispatch.add_dispatch_support def reshape(tensor, shape, name=None): # pylint: disable=redefined-outer-name r"""Reshapes a tensor. @@ -197,6 +198,7 @@ def reshape(tensor, shape, name=None): # pylint: disable=redefined-outer-name @tf_export("fill") +@dispatch.add_dispatch_support def fill(dims, value, name=None): r"""Creates a tensor filled with a scalar value. @@ -455,6 +457,7 @@ listdiff.__doc__ = gen_array_ops.list_diff.__doc__ + "\n" + listdiff.__doc__ "This op will be removed after the deprecation date. " "Please switch to tf.sets.difference().") @tf_export(v1=["setdiff1d"]) +@dispatch.add_dispatch_support def setdiff1d(x, y, index_dtype=dtypes.int32, name=None): """Computes the difference between two lists of numbers or strings. @@ -498,6 +501,7 @@ setdiff1d.__doc__ = gen_array_ops.list_diff.__doc__ @tf_export("broadcast_dynamic_shape") +@dispatch.add_dispatch_support def broadcast_dynamic_shape(shape_x, shape_y): """Computes the shape of a broadcast given symbolic shapes. @@ -523,6 +527,7 @@ def broadcast_dynamic_shape(shape_x, shape_y): @tf_export("broadcast_static_shape") +@dispatch.add_dispatch_support def broadcast_static_shape(shape_x, shape_y): """Computes the shape of a broadcast given known shapes. @@ -550,25 +555,23 @@ def broadcast_static_shape(shape_x, shape_y): @tf_export("shape", v1=[]) +@dispatch.add_dispatch_support def shape_v2(input, out_type=dtypes.int32, name=None): # pylint: disable=redefined-builtin """Returns the shape of a tensor. See also `tf.size`, `tf.rank`. - This operation returns a 1-D integer tensor representing the shape of `input`. - This represents the minimal set of known information at definition time. + `tf.shape` returns a 1-D integer tensor representing the shape of `input`. For example: >>> t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]) >>> tf.shape(t) <tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 2, 3], dtype=int32)> - >>> tf.shape(t).numpy() - array([2, 2, 3], dtype=int32) - Note: When using symbolic tensors, such as when using the Keras functional - API, tf.shape() will return the shape of the symbolic tensor. + Note: When using symbolic tensors, such as when using the Keras API, + tf.shape() will return the shape of the symbolic tensor. >>> a = tf.keras.layers.Input((None, 10)) >>> tf.shape(a) @@ -578,10 +581,13 @@ def shape_v2(input, out_type=dtypes.int32, name=None): >>> a.shape TensorShape([None, None, 10]) + + (The first `None` represents the as yet unknown batch size.) `tf.shape` and `Tensor.shape` should be identical in eager mode. Within `tf.function` or within a `compat.v1` context, not all dimensions may be - known until execution time. + known until execution time. Hence when defining custom layers and models + for graph mode, prefer the dynamic `tf.shape(x)` over the static `x.shape`. Args: input: A `Tensor` or `SparseTensor`. @@ -596,6 +602,7 @@ def shape_v2(input, out_type=dtypes.int32, name=None): @tf_export(v1=["shape"]) +@dispatch.add_dispatch_support def shape(input, name=None, out_type=dtypes.int32): # pylint: disable=redefined-builtin """Returns the shape of a tensor. @@ -650,6 +657,7 @@ def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32): @tf_export("shape_n") +@dispatch.add_dispatch_support def shape_n(input, out_type=dtypes.int32, name=None): # pylint: disable=redefined-builtin """Returns shape of tensors. @@ -1007,6 +1015,7 @@ def _slice_helper(tensor, slice_spec, var=None): # pylint: disable=undefined-variable,protected-access,redefined-outer-name @tf_export("slice") +@dispatch.add_dispatch_support def slice(input_, begin, size, name=None): # pylint: disable=redefined-builtin """Extracts a slice from a tensor. @@ -1062,6 +1071,7 @@ def slice(input_, begin, size, name=None): # pylint: disable=invalid-name @tf_export("strided_slice") +@dispatch.add_dispatch_support def strided_slice(input_, begin, end, @@ -1253,6 +1263,7 @@ ops.Tensor._override_operator("__getitem__", _slice_helper) @tf_export("parallel_stack") +@dispatch.add_dispatch_support def parallel_stack(values, name="parallel_stack"): """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor in parallel. @@ -1489,6 +1500,7 @@ ops.register_tensor_conversion_function((list, tuple), @tf_export("unstack") +@dispatch.add_dispatch_support def unstack(value, num=None, axis=0, name="unstack"): """Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors. @@ -1632,6 +1644,7 @@ def concat(values, axis, name="concat"): @tf_export(v1=["boolean_mask"]) +@dispatch.add_dispatch_support def boolean_mask(tensor, mask, name="boolean_mask", axis=None): """Apply boolean mask to tensor. @@ -1824,6 +1837,7 @@ def sparse_mask(a, mask_indices, name=None): @tf_export("unique") +@dispatch.add_dispatch_support def unique(x, out_idx=dtypes.int32, name=None): """Finds unique elements in a 1-D tensor. @@ -1871,6 +1885,7 @@ unique.__doc__ = gen_array_ops.unique.__doc__ @tf_export("unique_with_counts") +@dispatch.add_dispatch_support def unique_with_counts(x, out_idx=dtypes.int32, name=None): """Finds unique elements in a 1-D tensor. @@ -1923,6 +1938,7 @@ unique_with_counts.__doc__ = gen_array_ops.unique_with_counts.__doc__ @tf_export("split") +@dispatch.add_dispatch_support def split(value, num_or_size_splits, axis=0, num=None, name="split"): """Splits a tensor `value` into a list of sub tensors. @@ -2000,6 +2016,7 @@ def split(value, num_or_size_splits, axis=0, num=None, name="split"): @tf_export("transpose", v1=[]) +@dispatch.add_dispatch_support def transpose_v2(a, perm=None, conjugate=False, name="transpose"): """Transposes `a`, where `a` is a Tensor. @@ -2080,6 +2097,7 @@ def transpose_v2(a, perm=None, conjugate=False, name="transpose"): @tf_export(v1=["transpose"]) +@dispatch.add_dispatch_support def transpose(a, perm=None, name="transpose", conjugate=False): """Transposes `a`. @@ -2170,6 +2188,7 @@ def transpose(a, perm=None, name="transpose", conjugate=False): @tf_export( "linalg.matrix_transpose", v1=["linalg.transpose", "linalg.matrix_transpose", "matrix_transpose"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("matrix_transpose", "linalg.transpose") def matrix_transpose(a, name="matrix_transpose", conjugate=False): """Transposes last two dimensions of tensor `a`. @@ -2248,6 +2267,7 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False): @tf_export("linalg.diag", v1=["linalg.diag", "matrix_diag"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("matrix_diag") def matrix_diag(diagonal, name="diag", @@ -2416,6 +2436,7 @@ def matrix_diag(diagonal, @tf_export("linalg.diag_part", v1=["linalg.diag_part", "matrix_diag_part"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("matrix_diag_part") @dispatch.add_dispatch_support def matrix_diag_part( @@ -2556,6 +2577,7 @@ def matrix_diag_part( @tf_export("linalg.set_diag", v1=["linalg.set_diag", "matrix_set_diag"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("matrix_set_diag") def matrix_set_diag( input, # pylint:disable=redefined-builtin @@ -2719,6 +2741,7 @@ def _tag_zeros_tensor(fun): @tf_export("zeros") +@dispatch.add_dispatch_support @_tag_zeros_tensor def zeros(shape, dtype=dtypes.float32, name=None): """Creates a tensor with all elements set to zero. @@ -2971,6 +2994,7 @@ def ones_like_impl(tensor, dtype, name, optimize=True): @tf_export("ones") +@dispatch.add_dispatch_support def ones(shape, dtype=dtypes.float32, name=None): """Creates a tensor with all elements set to one (1). @@ -3182,6 +3206,7 @@ def sparse_placeholder(dtype, shape=None, name=None): @tf_export("pad", v1=[]) +@dispatch.add_dispatch_support def pad_v2(tensor, paddings, mode="CONSTANT", constant_values=0, name=None): """Pads a tensor. @@ -3240,6 +3265,7 @@ def pad_v2(tensor, paddings, mode="CONSTANT", constant_values=0, name=None): @tf_export(v1=["pad"]) +@dispatch.add_dispatch_support def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0): # pylint: disable=invalid-name """Pads a tensor. @@ -3357,6 +3383,7 @@ def _get_paddings_constant(paddings): @tf_export("meshgrid") +@dispatch.add_dispatch_support def meshgrid(*args, **kwargs): """Broadcasts parameters for evaluation on an N-D grid. @@ -3500,6 +3527,7 @@ def _TileGradShape(op): @tf_export("edit_distance") +@dispatch.add_dispatch_support def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"): """Computes the Levenshtein distance between sequences. @@ -3694,6 +3722,7 @@ def required_space_to_batch_paddings(input_shape, @tf_export(v1=["nn.space_to_batch", "space_to_batch"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("space_to_batch") def space_to_batch( # pylint: disable=missing-docstring input, # pylint: disable=redefined-builtin @@ -3717,6 +3746,7 @@ space_to_batch.__doc__ = gen_array_ops.space_to_batch.__doc__ @tf_export("space_to_batch", "nn.space_to_batch", v1=[]) +@dispatch.add_dispatch_support def space_to_batch_v2(input, block_shape, paddings, name=None): # pylint: disable=redefined-builtin return space_to_batch_nd(input, block_shape, paddings, name) @@ -3725,6 +3755,7 @@ space_to_batch_v2.__doc__ = gen_array_ops.space_to_batch_nd.__doc__ @tf_export(v1=["nn.space_to_depth", "space_to_depth"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("space_to_depth") def space_to_depth(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin return gen_array_ops.space_to_depth(input, block_size, data_format, name=name) @@ -3734,6 +3765,7 @@ space_to_depth.__doc__ = gen_array_ops.space_to_depth.__doc__ @tf_export("nn.space_to_depth", v1=[]) +@dispatch.add_dispatch_support def space_to_depth_v2(input, block_size, data_format="NHWC", name=None): # pylint: disable=redefined-builtin return gen_array_ops.space_to_depth(input, block_size, data_format, name=name) @@ -3742,6 +3774,7 @@ space_to_depth_v2.__doc__ = gen_array_ops.space_to_depth.__doc__ @tf_export(v1=["nn.depth_to_space", "depth_to_space"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("depth_to_space") def depth_to_space(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin return gen_array_ops.depth_to_space(input, block_size, data_format, name=name) @@ -3751,6 +3784,7 @@ depth_to_space.__doc__ = gen_array_ops.depth_to_space.__doc__ @tf_export("nn.depth_to_space", v1=[]) +@dispatch.add_dispatch_support def depth_to_space_v2(input, block_size, data_format="NHWC", name=None): # pylint: disable=redefined-builtin return gen_array_ops.depth_to_space(input, block_size, data_format, name=name) @@ -3759,6 +3793,7 @@ depth_to_space_v2.__doc__ = gen_array_ops.depth_to_space.__doc__ @tf_export(v1=["batch_to_space"]) +@dispatch.add_dispatch_support def batch_to_space(input, crops, block_size, name=None, block_shape=None): # pylint: disable=redefined-builtin,missing-docstring block_size = deprecation.deprecated_argument_lookup("block_shape", block_shape, "block_size", @@ -3776,6 +3811,7 @@ batch_to_space.__doc__ = gen_array_ops.batch_to_space.__doc__ @tf_export("batch_to_space", v1=[]) +@dispatch.add_dispatch_support def batch_to_space_v2(input, block_shape, crops, name=None): # pylint: disable=redefined-builtin """BatchToSpace for N-D tensors of type T. @@ -4091,6 +4127,7 @@ def _all_dimensions(x): @tf_export("sequence_mask") +@dispatch.add_dispatch_support def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None): """Returns a mask tensor representing the first N positions of each cell. @@ -4317,6 +4354,7 @@ def where(condition, x=None, y=None, name=None): @tf_export("where", v1=["where_v2"]) +@dispatch.add_dispatch_support def where_v2(condition, x=None, y=None, name=None): """Return the elements where `condition` is `True` (multiplexing `x` and `y`). @@ -4435,8 +4473,8 @@ def reverse_sequence(input, dimension `seq_axis`. The elements of `seq_lengths` must obey `seq_lengths[i] <= - input.dims[seq_dim]`, and `seq_lengths` must be a vector of length - `input.dims[batch_dim]`. + input.dims[seq_axis]`, and `seq_lengths` must be a vector of length + `input.dims[batch_axis]`. The output slice `i` along dimension `batch_axis` is then given by input slice `i`, with the first `seq_lengths[i]` slices along @@ -4458,8 +4496,8 @@ def reverse_sequence(input, Args: input: A `Tensor`. The input to reverse. seq_lengths: A `Tensor`. Must be one of the following types: `int32`, - `int64`. 1-D with length `input.dims(batch_dim)` and `max(seq_lengths) <= - input.dims(seq_dim)` + `int64`. 1-D with length `input.dims(batch_axis)` and `max(seq_lengths) <= + input.dims(seq_axis)` seq_axis: An `int`. The dimension which is partially reversed. batch_axis: An optional `int`. Defaults to `0`. The dimension along which reversal is performed. @@ -5003,6 +5041,7 @@ def batch_gather_nd(params, indices, batch_dims, name=None): # because round_mode was added later. # (And also now because of 'axis' processing). @tf_export(v1=["quantize_v2"]) +@dispatch.add_dispatch_support @deprecation.deprecated( "2017-10-25", "`tf.quantize_v2` is deprecated, please use `tf.quantization.quantize` " @@ -5056,6 +5095,7 @@ quantize_v2.__doc__ = """Please use `tf.quantization.quantize` instead.""" # tf.quantization.quantize; we can deprecate tf.quantization.quantize in next # version of TensorFlow. @tf_export("quantization.quantize", v1=["quantization.quantize", "quantize"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("quantize") def quantize( input, # pylint: disable=redefined-builtin @@ -5095,6 +5135,7 @@ def quantize( @tf_export("quantization.dequantize", v1=["quantization.dequantize", "dequantize"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("dequantize") def dequantize( # pylint: disable=missing-docstring input, # pylint: disable=redefined-builtin @@ -5130,6 +5171,7 @@ dequantize.__doc__ = gen_array_ops.dequantize.__doc__ @tf_export("quantization.quantize_and_dequantize") +@dispatch.add_dispatch_support def quantize_and_dequantize( input, # pylint: disable=redefined-builtin input_min, @@ -5189,6 +5231,7 @@ def quantize_and_dequantize( @tf_export("searchsorted") +@dispatch.add_dispatch_support def searchsorted(sorted_sequence, values, side="left", @@ -5253,6 +5296,7 @@ quantize.__doc__ = gen_array_ops.quantize_v2.__doc__ @tf_export("image.extract_patches") +@dispatch.add_dispatch_support def extract_image_patches_v2(images, sizes, strides, rates, padding, name=None): r"""Extract `patches` from `images`. @@ -5374,6 +5418,7 @@ def extract_image_patches_v2(images, sizes, strides, rates, padding, name=None): @tf_export(v1=["image.extract_image_patches", "extract_image_patches"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "ksizes is deprecated, use sizes instead", "ksizes") def extract_image_patches( # pylint: disable=missing-docstring @@ -5422,6 +5467,7 @@ extract_image_patches.__doc__ = gen_array_ops.extract_image_patches.__doc__ @tf_export("fingerprint") +@dispatch.add_dispatch_support def fingerprint(data, method="farmhash64", name=None): r"""Generates fingerprint values. @@ -5668,6 +5714,7 @@ def _with_nonzero_rank(data): @tf_export("repeat") +@dispatch.add_dispatch_support def repeat(input, repeats, axis=None, name=None): # pylint: disable=redefined-builtin """Repeat elements of `input`. diff --git a/tensorflow/python/ops/bincount.py b/tensorflow/python/ops/bincount_ops.py similarity index 51% rename from tensorflow/python/ops/bincount.py rename to tensorflow/python/ops/bincount_ops.py index 68950eaf596..758f0180a84 100644 --- a/tensorflow/python/ops/bincount.py +++ b/tensorflow/python/ops/bincount_ops.py @@ -12,21 +12,245 @@ # See the License for the specific language governing permissions and # maxlengthations under the License. # ============================================================================== -"""tf.sparse.bincount ops.""" +"""bincount ops.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import gen_count_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export +@tf_export("math.bincount", v1=[]) +def bincount(arr, + weights=None, + minlength=None, + maxlength=None, + dtype=dtypes.int32, + name=None, + axis=None, + binary_output=False): + """Counts the number of occurrences of each value in an integer array. + + If `minlength` and `maxlength` are not given, returns a vector with length + `tf.reduce_max(arr) + 1` if `arr` is non-empty, and length 0 otherwise. + If `weights` are non-None, then index `i` of the output stores the sum of the + value in `weights` at each index where the corresponding value in `arr` is + `i`. + + ```python + values = tf.constant([1,1,2,3,2,4,4,5]) + tf.math.bincount(values) #[0 2 2 1 2 1] + ``` + Vector length = Maximum element in vector `values` is 5. Adding 1, which is 6 + will be the vector length. + + Each bin value in the output indicates number of occurrences of the particular + index. Here, index 1 in output has a value 2. This indicates value 1 occurs + two times in `values`. + + ```python + values = tf.constant([1,1,2,3,2,4,4,5]) + weights = tf.constant([1,5,0,1,0,5,4,5]) + tf.math.bincount(values, weights=weights) #[0 6 0 1 9 5] + ``` + Bin will be incremented by the corresponding weight instead of 1. + Here, index 1 in output has a value 6. This is the summation of weights + corresponding to the value in `values`. + + **Bin-counting on a certain axis** + + This example takes a 2 dimensional input and returns a `Tensor` with + bincounting on each sample. + + >>> data = np.array([[1, 2, 3, 0], [0, 0, 1, 2]], dtype=np.int32) + >>> tf.math.bincount(data, axis=-1) + <tf.Tensor: shape=(2, 4), dtype=int32, numpy= + array([[1, 1, 1, 1], + [2, 1, 1, 0]], dtype=int32)> + + + **Bin-counting with binary_output** + + This example gives binary output instead of counting the occurrence. + + >>> data = np.array([[1, 2, 3, 0], [0, 0, 1, 2]], dtype=np.int32) + >>> tf.math.bincount(data, axis=-1, binary_output=True) + <tf.Tensor: shape=(2, 4), dtype=int32, numpy= + array([[1, 1, 1, 1], + [1, 1, 1, 0]], dtype=int32)> + + Args: + arr: A Tensor, RaggedTensor, or SparseTensor whose values should be counted. + These tensors must have a rank of 2 if `axis=-1`. + weights: If non-None, must be the same shape as arr. For each value in + `arr`, the bin will be incremented by the corresponding weight instead of + 1. + minlength: If given, ensures the output has length at least `minlength`, + padding with zeros at the end if necessary. + maxlength: If given, skips values in `arr` that are equal or greater than + `maxlength`, ensuring that the output has length at most `maxlength`. + dtype: If `weights` is None, determines the type of the output bins. + name: A name scope for the associated operations (optional). + axis: The axis to slice over. Axes at and below `axis` will be flattened + before bin counting. Currently, only `0`, and `-1` are supported. If None, + all axes will be flattened (identical to passing `0`). + binary_output: If True, this op will output 1 instead of the number of times + a token appears (equivalent to one_hot + reduce_any instead of one_hot + + reduce_add). Defaults to False. + + Returns: + A vector with the same dtype as `weights` or the given `dtype`. The bin + values. + + Raises: + `InvalidArgumentError` if negative values are provided as an input. + + """ + name = "bincount" if name is None else name + with ops.name_scope(name): + # Somehow forward compatible needs to be False. + if not binary_output and axis is None: + arr = ops.convert_to_tensor(arr, name="arr", dtype=dtypes.int32) + array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0 + output_size = math_ops.cast(array_is_nonempty, dtypes.int32) * ( + math_ops.reduce_max(arr) + 1) + if minlength is not None: + minlength = ops.convert_to_tensor( + minlength, name="minlength", dtype=dtypes.int32) + output_size = gen_math_ops.maximum(minlength, output_size) + if maxlength is not None: + maxlength = ops.convert_to_tensor( + maxlength, name="maxlength", dtype=dtypes.int32) + output_size = gen_math_ops.minimum(maxlength, output_size) + if weights is not None: + weights = ops.convert_to_tensor(weights, name="weights") + return gen_math_ops.unsorted_segment_sum(weights, arr, output_size) + weights = constant_op.constant([], dtype) + return gen_math_ops.bincount(arr, output_size, weights) + + if not isinstance(arr, sparse_tensor.SparseTensor): + arr = ragged_tensor.convert_to_tensor_or_ragged_tensor(arr, name="arr") + if weights is not None: + if not isinstance(weights, sparse_tensor.SparseTensor): + weights = ragged_tensor.convert_to_tensor_or_ragged_tensor( + weights, name="weights") + + if weights is not None and binary_output: + raise ValueError("binary_output and weights are mutually exclusive.") + + if not arr.dtype.is_integer: + arr = math_ops.cast(arr, dtypes.int32) + if axis is None: + axis = 0 + + if axis not in [0, -1]: + raise ValueError("Unsupported axis value %s. Only 0 and -1 are currently " + "supported." % axis) + + if isinstance(arr, ragged_tensor.RaggedTensor): + array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr.values)) > 0 + else: + array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0 + if isinstance(arr, sparse_tensor.SparseTensor): + output_size = math_ops.cast(array_is_nonempty, arr.dtype) * ( + math_ops.reduce_max(arr.values) + 1) + else: + output_size = math_ops.cast(array_is_nonempty, arr.dtype) * ( + math_ops.reduce_max(arr) + 1) + if minlength is not None: + minlength = ops.convert_to_tensor( + minlength, name="minlength", dtype=arr.dtype) + output_size = gen_math_ops.maximum(minlength, output_size) + if maxlength is not None: + maxlength = ops.convert_to_tensor( + maxlength, name="maxlength", dtype=arr.dtype) + output_size = gen_math_ops.minimum(maxlength, output_size) + + if axis == 0: + if isinstance(arr, sparse_tensor.SparseTensor): + if weights is not None: + weights = validate_sparse_weights(arr, weights, dtype) + arr = arr.values + elif isinstance(arr, ragged_tensor.RaggedTensor): + if weights is not None: + weights = validate_ragged_weights(arr, weights, dtype) + arr = arr.values + else: + if weights is not None: + weights = array_ops.reshape(weights, [-1]) + arr = array_ops.reshape(arr, [-1]) + + if isinstance(arr, sparse_tensor.SparseTensor): + weights = validate_sparse_weights(arr, weights, dtype) + return gen_math_ops.sparse_bincount( + indices=arr.indices, + values=arr.values, + dense_shape=arr.dense_shape, + size=output_size, + weights=weights, + binary_output=binary_output) + elif isinstance(arr, ragged_tensor.RaggedTensor): + weights = validate_ragged_weights(arr, weights, dtype) + return gen_math_ops.ragged_bincount( + splits=arr.row_splits, + values=arr.values, + size=output_size, + weights=weights, + binary_output=binary_output) + else: + weights = validate_dense_weights(arr, weights, dtype) + return gen_math_ops.dense_bincount( + input=arr, + size=output_size, + weights=weights, + binary_output=binary_output) + + +@tf_export(v1=["math.bincount", "bincount"]) +@deprecation.deprecated_endpoints("bincount") +def bincount_v1(arr, + weights=None, + minlength=None, + maxlength=None, + dtype=dtypes.int32): + """Counts the number of occurrences of each value in an integer array. + + If `minlength` and `maxlength` are not given, returns a vector with length + `tf.reduce_max(arr) + 1` if `arr` is non-empty, and length 0 otherwise. + If `weights` are non-None, then index `i` of the output stores the sum of the + value in `weights` at each index where the corresponding value in `arr` is + `i`. + + Args: + arr: An int32 tensor of non-negative values. + weights: If non-None, must be the same shape as arr. For each value in + `arr`, the bin will be incremented by the corresponding weight instead of + 1. + minlength: If given, ensures the output has length at least `minlength`, + padding with zeros at the end if necessary. + maxlength: If given, skips values in `arr` that are equal or greater than + `maxlength`, ensuring that the output has length at most `maxlength`. + dtype: If `weights` is None, determines the type of the output bins. + + Returns: + A vector with the same dtype as `weights` or the given `dtype`. The bin + values. + """ + return bincount(arr, weights, minlength, maxlength, dtype) + + @tf_export("sparse.bincount") def sparse_bincount(values, weights=None, @@ -45,19 +269,17 @@ def sparse_bincount(values, Args: values: A Tensor, RaggedTensor, or SparseTensor whose values should be - counted. These tensors must have a rank of 1 or 2. - weights: A 1-dimensional Tensor of weights. If specified, the input array is - weighted by the weight array, i.e. if a value `n` is found at position - `i`, `out[n]` will be increased by `weight[i]` instead of 1. + counted. These tensors must have a rank of 2 if `axis=-1`. + weights: If non-None, must be the same shape as arr. For each value in + `value`, the bin will be incremented by the corresponding weight instead + of 1. axis: The axis to slice over. Axes at and below `axis` will be flattened before bin counting. Currently, only `0`, and `-1` are supported. If None, all axes will be flattened (identical to passing `0`). - minlength: If given, skips `values` that are less than `minlength`, and - ensures that the output has a `dense_shape` of at least `minlength` in the - inner dimension. - maxlength: If given, skips `values` that are greater than or equal to - `maxlength`, and ensures that the output has a `dense_shape` of at most - `maxlength` in the inner dimension. + minlength: If given, ensures the output has length at least `minlength`, + padding with zeros at the end if necessary. + maxlength: If given, skips values in `values` that are equal or greater than + `maxlength`, ensuring that the output has length at most `maxlength`. binary_output: If True, this op will output 1 instead of the number of times a token appears (equivalent to one_hot + reduce_any instead of one_hot + reduce_add). Defaults to False. @@ -229,9 +451,11 @@ def sparse_bincount(values, return sparse_tensor.SparseTensor(c_ind, c_val, c_shape) -def validate_dense_weights(values, weights): +def validate_dense_weights(values, weights, dtype=None): """Validates the passed weight tensor or creates an empty one.""" if weights is None: + if dtype: + return array_ops.constant([], dtype=dtype) return array_ops.constant([], dtype=values.dtype) if not isinstance(weights, ops.Tensor): @@ -241,9 +465,11 @@ def validate_dense_weights(values, weights): return weights -def validate_sparse_weights(values, weights): +def validate_sparse_weights(values, weights, dtype=None): """Validates the passed weight tensor or creates an empty one.""" if weights is None: + if dtype: + return array_ops.constant([], dtype=dtype) return array_ops.constant([], dtype=values.values.dtype) if not isinstance(weights, sparse_tensor.SparseTensor): @@ -273,9 +499,11 @@ def validate_sparse_weights(values, weights): return weights -def validate_ragged_weights(values, weights): +def validate_ragged_weights(values, weights, dtype=None): """Validates the passed weight tensor or creates an empty one.""" if weights is None: + if dtype: + return array_ops.constant([], dtype=dtype) return array_ops.constant([], dtype=values.values.dtype) if not isinstance(weights, ragged_tensor.RaggedTensor): diff --git a/tensorflow/python/ops/bincount_test.py b/tensorflow/python/ops/bincount_ops_test.py similarity index 71% rename from tensorflow/python/ops/bincount_test.py rename to tensorflow/python/ops/bincount_ops_test.py index 839af8dcc35..74fd17cae2b 100644 --- a/tensorflow/python/ops/bincount_test.py +++ b/tensorflow/python/ops/bincount_ops_test.py @@ -23,9 +23,12 @@ import numpy as np from tensorflow.python.eager import context from tensorflow.python.framework import errors -from tensorflow.python.ops import bincount +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import bincount_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test @@ -151,7 +154,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): binary_output=False, weights=None, axis=-1): - y = bincount.sparse_bincount( + y = bincount_ops.sparse_bincount( x, weights=weights, minlength=minlength, @@ -349,7 +352,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): axis=-1): x_sparse = sparse_ops.from_dense(x) w_sparse = sparse_ops.from_dense(weights) if weights is not None else None - y = bincount.sparse_bincount( + y = bincount_ops.sparse_bincount( x_sparse, weights=w_sparse, minlength=minlength, @@ -496,7 +499,7 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): axis=-1): x_ragged = ragged_factory_ops.constant(x) w = ragged_factory_ops.constant(weights) if weights is not None else None - y = bincount.sparse_bincount( + y = bincount_ops.sparse_bincount( x_ragged, weights=w, minlength=minlength, @@ -508,6 +511,237 @@ class TestSparseCount(test.TestCase, parameterized.TestCase): self.assertAllEqual(expected_shape, y.dense_shape) +class TestDenseBincount(test.TestCase, parameterized.TestCase): + + @parameterized.parameters([{ + "dtype": np.int32, + }, { + "dtype": np.int64, + }]) + def test_sparse_input_all_count(self, dtype): + np.random.seed(42) + num_rows = 128 + size = 1000 + n_elems = 4096 + inp_indices = np.random.randint(0, num_rows, (n_elems, 1)) + inp_indices = np.concatenate([inp_indices, np.zeros((n_elems, 1))], axis=1) + inp_vals = np.random.randint(0, size, (n_elems,), dtype=dtype) + sparse_inp = sparse_tensor.SparseTensor(inp_indices, inp_vals, + [num_rows, 1]) + + np_out = np.bincount(inp_vals, minlength=size) + self.assertAllEqual( + np_out, self.evaluate(bincount_ops.bincount(sparse_inp, axis=0))) + + @parameterized.parameters([{ + "dtype": np.int32, + }, { + "dtype": np.int64, + }]) + def test_sparse_input_all_count_with_weights(self, dtype): + np.random.seed(42) + num_rows = 128 + size = 1000 + n_elems = 4096 + inp_indices = np.random.randint(0, num_rows, (n_elems, 1)) + inp_indices = np.concatenate([inp_indices, np.zeros((n_elems, 1))], axis=1) + inp_vals = np.random.randint(0, size, (n_elems,), dtype=dtype) + sparse_inp = sparse_tensor.SparseTensor(inp_indices, inp_vals, + [num_rows, 1]) + weight_vals = np.random.random((n_elems,)) + sparse_weights = sparse_tensor.SparseTensor(inp_indices, weight_vals, + [num_rows, 1]) + + np_out = np.bincount(inp_vals, minlength=size, weights=weight_vals) + self.assertAllEqual( + np_out, + self.evaluate(bincount_ops.bincount( + sparse_inp, sparse_weights, axis=0))) + + @parameterized.parameters([{ + "dtype": np.int32, + }, { + "dtype": np.int64, + }]) + def test_sparse_input_all_binary(self, dtype): + np.random.seed(42) + num_rows = 128 + size = 10 + n_elems = 4096 + inp_indices = np.random.randint(0, num_rows, (n_elems, 1)) + inp_indices = np.concatenate([inp_indices, np.zeros((n_elems, 1))], axis=1) + inp_vals = np.random.randint(0, size, (n_elems,), dtype=dtype) + sparse_inp = sparse_tensor.SparseTensor(inp_indices, inp_vals, + [num_rows, 1]) + + np_out = np.ones((size,)) + self.assertAllEqual( + np_out, + self.evaluate(bincount_ops.bincount(sparse_inp, binary_output=True))) + + @parameterized.parameters([{ + "dtype": np.int32, + }, { + "dtype": np.int64, + }]) + def test_sparse_input_col_reduce_count(self, dtype): + num_rows = 128 + num_cols = 27 + size = 100 + np.random.seed(42) + inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype) + np_out = np.reshape( + np.concatenate( + [np.bincount(inp[j, :], minlength=size) for j in range(num_rows)], + axis=0), (num_rows, size)) + # from_dense will filter out 0s. + inp = inp + 1 + # from_dense will cause OOM in GPU. + with ops.device("/CPU:0"): + inp_sparse = sparse_ops.from_dense(inp) + inp_sparse = sparse_tensor.SparseTensor(inp_sparse.indices, + inp_sparse.values - 1, + inp_sparse.dense_shape) + self.assertAllEqual( + np_out, self.evaluate(bincount_ops.bincount(arr=inp_sparse, axis=-1))) + + @parameterized.parameters([{ + "dtype": np.int32, + }, { + "dtype": np.int64, + }]) + def test_sparse_input_col_reduce_binary(self, dtype): + num_rows = 128 + num_cols = 27 + size = 100 + np.random.seed(42) + inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype) + np_out = np.reshape( + np.concatenate([ + np.where(np.bincount(inp[j, :], minlength=size) > 0, 1, 0) + for j in range(num_rows) + ], + axis=0), (num_rows, size)) + # from_dense will filter out 0s. + inp = inp + 1 + # from_dense will cause OOM in GPU. + with ops.device("/CPU:0"): + inp_sparse = sparse_ops.from_dense(inp) + inp_sparse = sparse_tensor.SparseTensor(inp_sparse.indices, + inp_sparse.values - 1, + inp_sparse.dense_shape) + self.assertAllEqual( + np_out, + self.evaluate( + bincount_ops.bincount(arr=inp_sparse, axis=-1, binary_output=True))) + + @parameterized.parameters([{ + "dtype": np.int32, + }, { + "dtype": np.int64, + }]) + def test_ragged_input_count(self, dtype): + x = ragged_factory_ops.constant([[], [], [3, 0, 1], [], [5, 0, 4, 4]], + dtype) + # pyformat: disable + expected_output = [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 2, 1]] + # pyformat: enable + self.assertAllEqual(expected_output, + self.evaluate(bincount_ops.bincount(arr=x, axis=-1))) + + @parameterized.parameters([{ + "dtype": np.int32, + }, { + "dtype": np.int64, + }]) + def test_ragged_input_binary(self, dtype): + x = ragged_factory_ops.constant([[], [], [3, 0, 1], [], [5, 0, 4, 4]]) + # pyformat: disable + expected_output = [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 1, 1]] + # pyformat: enable + self.assertAllEqual( + expected_output, + self.evaluate( + bincount_ops.bincount(arr=x, axis=-1, binary_output=True))) + + @parameterized.parameters([{ + "dtype": np.int32, + }, { + "dtype": np.int64, + }]) + def test_ragged_input_count_with_weights(self, dtype): + x = ragged_factory_ops.constant([[], [], [3, 0, 1], [], [5, 0, 4, 4]]) + weights = ragged_factory_ops.constant([[], [], [.1, .2, .3], [], + [.2, .5, .6, .3]]) + # pyformat: disable + expected_output = [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [.2, .3, 0, .1, 0, 0], + [0, 0, 0, 0, 0, 0], + [.5, 0, 0, 0, .9, .2]] + # pyformat: enable + self.assertAllClose( + expected_output, + self.evaluate(bincount_ops.bincount(arr=x, weights=weights, axis=-1))) + + @parameterized.parameters([{ + "dtype": np.int32, + }, { + "dtype": np.int64, + }]) + def test_ragged_input_count_np(self, dtype): + np.random.seed(42) + num_rows = 128 + num_cols = 27 + size = 1000 + inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype) + np_out = np.reshape( + np.concatenate( + [np.bincount(inp[j, :], minlength=size) for j in range(num_rows)], + axis=0), (num_rows, size)) + x = ragged_tensor.RaggedTensor.from_tensor(inp) + self.assertAllEqual( + np_out, + self.evaluate(bincount_ops.bincount(arr=x, minlength=size, axis=-1))) + + @parameterized.parameters([{ + "dtype": np.int32, + }, { + "dtype": np.int64, + }]) + def test_ragged_input_count_np_with_weights(self, dtype): + np.random.seed(42) + num_rows = 128 + num_cols = 27 + size = 1000 + inp = np.random.randint(0, size, (num_rows, num_cols), dtype=dtype) + np_weight = np.random.random((num_rows, num_cols)) + np_out = np.reshape( + np.concatenate([ + np.bincount(inp[j, :], weights=np_weight[j, :], minlength=size) + for j in range(num_rows) + ], + axis=0), (num_rows, size)) + x = ragged_tensor.RaggedTensor.from_tensor(inp) + weights = ragged_tensor.RaggedTensor.from_tensor(np_weight) + self.assertAllEqual( + np_out, + self.evaluate( + bincount_ops.bincount( + arr=x, weights=weights, minlength=size, axis=-1))) + + class TestSparseCountFailureModes(test.TestCase): def test_dense_input_sparse_weights_fails(self): @@ -515,13 +749,13 @@ class TestSparseCountFailureModes(test.TestCase): weights = sparse_ops.from_dense( np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32)) with self.assertRaisesRegexp(ValueError, "must be a tf.Tensor"): - self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1)) def test_dense_input_ragged_weights_fails(self): x = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32) weights = ragged_factory_ops.constant([[6, 0.5, 2], [14], [10, 0.25, 5, 3]]) with self.assertRaisesRegexp(ValueError, "must be a tf.Tensor"): - self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1)) def test_dense_input_wrong_shape_fails(self): x = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32) @@ -532,24 +766,24 @@ class TestSparseCountFailureModes(test.TestCase): if context.executing_eagerly(): with self.assertRaisesRegexp(errors.InvalidArgumentError, "must have the same shape"): - self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1)) else: with self.assertRaisesRegexp(ValueError, "both shapes must be equal"): - self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1)) def test_sparse_input_dense_weights_fails(self): x = sparse_ops.from_dense( np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32)) weights = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32) with self.assertRaisesRegexp(ValueError, "must be a SparseTensor"): - self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1)) def test_sparse_input_ragged_weights_fails(self): x = sparse_ops.from_dense( np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32)) weights = ragged_factory_ops.constant([[6, 0.5, 2], [14], [10, 0.25, 5, 3]]) with self.assertRaisesRegexp(ValueError, "must be a SparseTensor"): - self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1)) def test_sparse_input_wrong_indices_fails(self): x = sparse_ops.from_dense( @@ -558,7 +792,7 @@ class TestSparseCountFailureModes(test.TestCase): np.array([[3, 1, 0, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32)) with self.assertRaisesRegexp(errors.InvalidArgumentError, "must have the same indices"): - self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1)) def test_sparse_input_too_many_indices_fails(self): x = sparse_ops.from_dense( @@ -567,7 +801,7 @@ class TestSparseCountFailureModes(test.TestCase): np.array([[3, 1, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32)) with self.assertRaisesRegexp(errors.InvalidArgumentError, "Incompatible shapes"): - self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1)) def test_sparse_input_wrong_shape_fails(self): x = sparse_ops.from_dense( @@ -577,27 +811,27 @@ class TestSparseCountFailureModes(test.TestCase): dtype=np.int32)) with self.assertRaisesRegexp(errors.InvalidArgumentError, "must have the same dense shape"): - self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1)) def test_ragged_input_dense_weights_fails(self): x = ragged_factory_ops.constant([[6, 1, 2], [14], [10, 1, 5, 3]]) weights = np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32) with self.assertRaisesRegexp(ValueError, "must be a RaggedTensor"): - self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1)) def test_ragged_input_sparse_weights_fails(self): x = ragged_factory_ops.constant([[6, 1, 2], [14], [10, 1, 5, 3]]) weights = sparse_ops.from_dense( np.array([[3, 0, 1, 0], [0, 0, 0, 0], [5, 0, 4, 4]], dtype=np.int32)) with self.assertRaisesRegexp(ValueError, "must be a RaggedTensor"): - self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1)) def test_ragged_input_different_shape_fails(self): x = ragged_factory_ops.constant([[6, 1, 2], [14], [10, 1, 5, 3]]) weights = ragged_factory_ops.constant([[6, 0.5, 2], [], [10, 0.25, 5, 3]]) with self.assertRaisesRegexp(errors.InvalidArgumentError, "must have the same row splits"): - self.evaluate(bincount.sparse_bincount(x, weights=weights, axis=-1)) + self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1)) if __name__ == "__main__": diff --git a/tensorflow/python/ops/candidate_sampling_ops.py b/tensorflow/python/ops/candidate_sampling_ops.py index 56f76a49d51..6c1a36e65c9 100644 --- a/tensorflow/python/ops/candidate_sampling_ops.py +++ b/tensorflow/python/ops/candidate_sampling_ops.py @@ -24,12 +24,14 @@ from tensorflow.python.ops import array_ops # pylint: disable=unused-import from tensorflow.python.ops import gen_candidate_sampling_ops from tensorflow.python.ops import math_ops # pylint: disable=unused-import from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @tf_export( 'random.uniform_candidate_sampler', v1=['random.uniform_candidate_sampler', 'nn.uniform_candidate_sampler']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('nn.uniform_candidate_sampler') def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=None, name=None): @@ -92,6 +94,7 @@ def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, 'random.log_uniform_candidate_sampler', 'nn.log_uniform_candidate_sampler' ]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('nn.log_uniform_candidate_sampler') def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=None, name=None): @@ -154,6 +157,7 @@ def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, @tf_export( 'random.learned_unigram_candidate_sampler', 'nn.learned_unigram_candidate_sampler') +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints(['nn.learned_unigram_candidate_sampler']) def learned_unigram_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=None, name=None): @@ -213,6 +217,7 @@ def learned_unigram_candidate_sampler(true_classes, num_true, num_sampled, @tf_export('random.fixed_unigram_candidate_sampler', 'nn.fixed_unigram_candidate_sampler') +@dispatch.add_dispatch_support def fixed_unigram_candidate_sampler(true_classes, num_true, num_sampled, @@ -341,6 +346,7 @@ def all_candidate_sampler(true_classes, num_true, num_sampled, unique, @tf_export('nn.compute_accidental_hits') +@dispatch.add_dispatch_support def compute_accidental_hits(true_classes, sampled_candidates, num_true, seed=None, name=None): """Compute the position ids in `sampled_candidates` matching `true_classes`. diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index cefca5defae..b50313753d6 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import compat from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export NUMERIC_TYPES = frozenset( @@ -375,6 +376,7 @@ def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize, @tf_export( 'debugging.assert_proper_iterable', v1=['debugging.assert_proper_iterable', 'assert_proper_iterable']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_proper_iterable') def assert_proper_iterable(values): """Static assert that values is a "proper" iterable. @@ -404,6 +406,7 @@ def assert_proper_iterable(values): @tf_export('debugging.assert_negative', v1=[]) +@dispatch.add_dispatch_support def assert_negative_v2(x, message=None, summarize=None, name=None): """Assert the condition `x < 0` holds element-wise. @@ -436,6 +439,7 @@ def assert_negative_v2(x, message=None, summarize=None, name=None): @tf_export(v1=['debugging.assert_negative', 'assert_negative']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_negative') @_unary_assert_doc('< 0', 'negative') def assert_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring @@ -456,6 +460,7 @@ def assert_negative(x, data=None, summarize=None, message=None, name=None): # p @tf_export('debugging.assert_positive', v1=[]) +@dispatch.add_dispatch_support def assert_positive_v2(x, message=None, summarize=None, name=None): """Assert the condition `x > 0` holds element-wise. @@ -488,6 +493,7 @@ def assert_positive_v2(x, message=None, summarize=None, name=None): @tf_export(v1=['debugging.assert_positive', 'assert_positive']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_positive') @_unary_assert_doc('> 0', 'positive') def assert_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring @@ -507,6 +513,7 @@ def assert_positive(x, data=None, summarize=None, message=None, name=None): # p @tf_export('debugging.assert_non_negative', v1=[]) +@dispatch.add_dispatch_support def assert_non_negative_v2(x, message=None, summarize=None, name=None): """Assert the condition `x >= 0` holds element-wise. @@ -541,6 +548,7 @@ def assert_non_negative_v2(x, message=None, summarize=None, name=None): @tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_non_negative') @_unary_assert_doc('>= 0', 'non-negative') def assert_non_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring @@ -561,6 +569,7 @@ def assert_non_negative(x, data=None, summarize=None, message=None, name=None): @tf_export('debugging.assert_non_positive', v1=[]) +@dispatch.add_dispatch_support def assert_non_positive_v2(x, message=None, summarize=None, name=None): """Assert the condition `x <= 0` holds element-wise. @@ -595,6 +604,7 @@ def assert_non_positive_v2(x, message=None, summarize=None, name=None): @tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_non_positive') @_unary_assert_doc('<= 0', 'non-positive') def assert_non_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring @@ -615,6 +625,7 @@ def assert_non_positive(x, data=None, summarize=None, message=None, name=None): @tf_export('debugging.assert_equal', 'assert_equal', v1=[]) +@dispatch.add_dispatch_support def assert_equal_v2(x, y, message=None, summarize=None, name=None): """Assert the condition `x == y` holds element-wise. @@ -649,6 +660,7 @@ def assert_equal_v2(x, y, message=None, summarize=None, name=None): @tf_export(v1=['debugging.assert_equal', 'assert_equal']) +@dispatch.add_dispatch_support @_binary_assert_doc('==') def assert_equal(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring with ops.name_scope(name, 'assert_equal', [x, y, data]): @@ -660,6 +672,7 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): # p @tf_export('debugging.assert_none_equal', v1=[]) +@dispatch.add_dispatch_support def assert_none_equal_v2(x, y, summarize=None, message=None, name=None): """Assert the condition `x != y` holds for all elements. @@ -698,6 +711,7 @@ def assert_none_equal_v2(x, y, summarize=None, message=None, name=None): @tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_none_equal') @_binary_assert_doc('!=') def assert_none_equal( @@ -707,6 +721,7 @@ def assert_none_equal( @tf_export('debugging.assert_near', v1=[]) +@dispatch.add_dispatch_support def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None, name=None): """Assert the condition `x` and `y` are close element-wise. @@ -750,9 +765,9 @@ def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None, statically known. @compatibility(numpy) - Similar to `numpy.assert_allclose`, except tolerance depends on data type. - This is due to the fact that `TensorFlow` is often used with `32bit`, `64bit`, - and even `16bit` data. + Similar to `numpy.testing.assert_allclose`, except tolerance depends on data + type. This is due to the fact that `TensorFlow` is often used with `32bit`, + `64bit`, and even `16bit` data. @end_compatibility """ return assert_near(x=x, y=y, rtol=rtol, atol=atol, summarize=summarize, @@ -760,6 +775,7 @@ def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None, @tf_export(v1=['debugging.assert_near', 'assert_near']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_near') def assert_near( x, y, rtol=None, atol=None, data=None, summarize=None, message=None, @@ -802,9 +818,9 @@ def assert_near( Op that raises `InvalidArgumentError` if `x` and `y` are not close enough. @compatibility(numpy) - Similar to `numpy.assert_allclose`, except tolerance depends on data type. - This is due to the fact that `TensorFlow` is often used with `32bit`, `64bit`, - and even `16bit` data. + Similar to `numpy.testing.assert_allclose`, except tolerance depends on data + type. This is due to the fact that `TensorFlow` is often used with `32bit`, + `64bit`, and even `16bit` data. @end_compatibility """ message = message or '' @@ -839,6 +855,7 @@ def assert_near( @tf_export('debugging.assert_less', 'assert_less', v1=[]) +@dispatch.add_dispatch_support def assert_less_v2(x, y, message=None, summarize=None, name=None): """Assert the condition `x < y` holds element-wise. @@ -874,6 +891,7 @@ def assert_less_v2(x, y, message=None, summarize=None, name=None): @tf_export(v1=['debugging.assert_less', 'assert_less']) +@dispatch.add_dispatch_support @_binary_assert_doc('<') def assert_less(x, y, data=None, summarize=None, message=None, name=None): return _binary_assert('<', 'assert_less', math_ops.less, np.less, x, y, data, @@ -881,6 +899,7 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None): @tf_export('debugging.assert_less_equal', v1=[]) +@dispatch.add_dispatch_support def assert_less_equal_v2(x, y, message=None, summarize=None, name=None): """Assert the condition `x <= y` holds element-wise. @@ -917,6 +936,7 @@ def assert_less_equal_v2(x, y, message=None, summarize=None, name=None): @tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_less_equal') @_binary_assert_doc('<=') def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None): @@ -925,6 +945,7 @@ def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None): @tf_export('debugging.assert_greater', 'assert_greater', v1=[]) +@dispatch.add_dispatch_support def assert_greater_v2(x, y, message=None, summarize=None, name=None): """Assert the condition `x > y` holds element-wise. @@ -961,6 +982,7 @@ def assert_greater_v2(x, y, message=None, summarize=None, name=None): @tf_export(v1=['debugging.assert_greater', 'assert_greater']) +@dispatch.add_dispatch_support @_binary_assert_doc('>') def assert_greater(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring return _binary_assert('>', 'assert_greater', math_ops.greater, np.greater, x, @@ -968,6 +990,7 @@ def assert_greater(x, y, data=None, summarize=None, message=None, name=None): # @tf_export('debugging.assert_greater_equal', v1=[]) +@dispatch.add_dispatch_support def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None): """Assert the condition `x >= y` holds element-wise. @@ -1005,6 +1028,7 @@ def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None): @tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_greater_equal') @_binary_assert_doc('>=') def assert_greater_equal(x, y, data=None, summarize=None, message=None, @@ -1062,6 +1086,7 @@ def _assert_rank_condition( @tf_export('debugging.assert_rank', 'assert_rank', v1=[]) +@dispatch.add_dispatch_support def assert_rank_v2(x, rank, message=None, name=None): """Assert that `x` has rank equal to `rank`. @@ -1095,6 +1120,7 @@ def assert_rank_v2(x, rank, message=None, name=None): @tf_export(v1=['debugging.assert_rank', 'assert_rank']) +@dispatch.add_dispatch_support def assert_rank(x, rank, data=None, summarize=None, message=None, name=None): """Assert `x` has rank equal to `rank`. @@ -1157,6 +1183,7 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None): @tf_export('debugging.assert_rank_at_least', v1=[]) +@dispatch.add_dispatch_support def assert_rank_at_least_v2(x, rank, message=None, name=None): """Assert that `x` has rank of at least `rank`. @@ -1190,6 +1217,7 @@ def assert_rank_at_least_v2(x, rank, message=None, name=None): @tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_rank_at_least') def assert_rank_at_least( x, rank, data=None, summarize=None, message=None, name=None): @@ -1322,6 +1350,7 @@ def _assert_ranks_condition( @tf_export('debugging.assert_rank_in', v1=[]) +@dispatch.add_dispatch_support def assert_rank_in_v2(x, ranks, message=None, name=None): """Assert that `x` has a rank in `ranks`. @@ -1354,6 +1383,7 @@ def assert_rank_in_v2(x, ranks, message=None, name=None): @tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_rank_in') def assert_rank_in( x, ranks, data=None, summarize=None, message=None, name=None): @@ -1417,6 +1447,7 @@ def assert_rank_in( @tf_export('debugging.assert_integer', v1=[]) +@dispatch.add_dispatch_support def assert_integer_v2(x, message=None, name=None): """Assert that `x` is of integer dtype. @@ -1437,6 +1468,7 @@ def assert_integer_v2(x, message=None, name=None): @tf_export(v1=['debugging.assert_integer', 'assert_integer']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_integer') def assert_integer(x, message=None, name=None): """Assert that `x` is of integer dtype. @@ -1476,6 +1508,7 @@ def assert_integer(x, message=None, name=None): @tf_export('debugging.assert_type', v1=[]) +@dispatch.add_dispatch_support def assert_type_v2(tensor, tf_type, message=None, name=None): """Asserts that the given `Tensor` is of the specified type. @@ -1495,6 +1528,7 @@ def assert_type_v2(tensor, tf_type, message=None, name=None): @tf_export(v1=['debugging.assert_type', 'assert_type']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_type') def assert_type(tensor, tf_type, message=None, name=None): """Statically asserts that the given `Tensor` is of the specified type. @@ -1584,6 +1618,7 @@ _TensorDimSizes = collections.namedtuple( @tf_export('debugging.assert_shapes', v1=[]) +@dispatch.add_dispatch_support def assert_shapes_v2(shapes, data=None, summarize=None, message=None, name=None): """Assert tensor shapes and dimension size relationships between tensors. @@ -1650,6 +1685,7 @@ def assert_shapes_v2(shapes, data=None, summarize=None, message=None, @tf_export(v1=['debugging.assert_shapes']) +@dispatch.add_dispatch_support def assert_shapes(shapes, data=None, summarize=None, message=None, name=None): """Assert tensor shapes and dimension size relationships between tensors. @@ -1939,6 +1975,7 @@ def is_numeric_tensor(tensor): 'math.is_non_decreasing', 'debugging.is_non_decreasing', 'is_non_decreasing' ]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('debugging.is_non_decreasing', 'is_non_decreasing') def is_non_decreasing(x, name=None): @@ -1980,6 +2017,7 @@ def is_non_decreasing(x, name=None): 'math.is_strictly_increasing', 'debugging.is_strictly_increasing', 'is_strictly_increasing' ]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('debugging.is_strictly_increasing', 'is_strictly_increasing') def is_strictly_increasing(x, name=None): @@ -2066,6 +2104,7 @@ def _assert_same_base_type(items, expected_type=None): @tf_export( 'debugging.assert_same_float_dtype', v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_same_float_dtype') def assert_same_float_dtype(tensors=None, dtype=None): """Validate and return float type based on `tensors` and `dtype`. @@ -2098,6 +2137,7 @@ def assert_same_float_dtype(tensors=None, dtype=None): @tf_export('debugging.assert_scalar', v1=[]) +@dispatch.add_dispatch_support def assert_scalar_v2(tensor, message=None, name=None): """Asserts that the given `tensor` is a scalar. @@ -2120,6 +2160,7 @@ def assert_scalar_v2(tensor, message=None, name=None): @tf_export(v1=['debugging.assert_scalar', 'assert_scalar']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('assert_scalar') def assert_scalar(tensor, name=None, message=None): """Asserts that the given `tensor` is a scalar (i.e. zero-dimensional). @@ -2154,6 +2195,7 @@ def assert_scalar(tensor, name=None, message=None): @tf_export('ensure_shape') +@dispatch.add_dispatch_support def ensure_shape(x, shape, name=None): """Updates the shape of a tensor and checks at runtime that the shape holds. diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py index edb35afa52c..f7662516b4f 100644 --- a/tensorflow/python/ops/clip_ops.py +++ b/tensorflow/python/ops/clip_ops.py @@ -152,6 +152,7 @@ def _clip_by_value_grad(op, grad): @tf_export("clip_by_norm") +@dispatch.add_dispatch_support def clip_by_norm(t, clip_norm, axes=None, name=None): """Clips tensor values to a maximum L2-norm. @@ -235,6 +236,7 @@ def clip_by_norm(t, clip_norm, axes=None, name=None): @tf_export("linalg.global_norm", v1=["linalg.global_norm", "global_norm"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("global_norm") def global_norm(t_list, name=None): """Computes the global norm of multiple tensors. @@ -285,6 +287,7 @@ def global_norm(t_list, name=None): @tf_export("clip_by_global_norm") +@dispatch.add_dispatch_support def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None): """Clips values of multiple tensors by the ratio of the sum of their norms. @@ -382,6 +385,7 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None): "use clip_by_norm(t, clip_norm * tf.cast(tf.size(t), tf.float32), name) " "instead.") @tf_export(v1=["clip_by_average_norm"]) +@dispatch.add_dispatch_support def clip_by_average_norm(t, clip_norm, name=None): """Clips tensor values to a maximum average L2-norm. diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py index 3e885975b03..39177defe57 100644 --- a/tensorflow/python/ops/confusion_matrix.py +++ b/tensorflow/python/ops/confusion_matrix.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @@ -93,6 +94,7 @@ def remove_squeezable_dimensions( @tf_export('math.confusion_matrix', v1=[]) +@dispatch.add_dispatch_support def confusion_matrix(labels, predictions, num_classes=None, @@ -202,6 +204,7 @@ def confusion_matrix(labels, @tf_export(v1=['math.confusion_matrix', 'confusion_matrix']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('confusion_matrix', 'train.confusion_matrix') def confusion_matrix_v1(labels, predictions, diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 58948f7d52a..918c989432d 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -54,6 +54,7 @@ from tensorflow.python.ops.gen_control_flow_ops import * from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util import nest from tensorflow.python.util import tf_should_use from tensorflow.python.util.lazy_loader import LazyLoader @@ -110,6 +111,7 @@ def _summarize_eager(tensor, summarize=None): # Assert and Print are special symbols in python, so we must # use an upper-case version of them. @tf_export("debugging.Assert", "Assert") +@dispatch.add_dispatch_support @tf_should_use.should_use_result def Assert(condition, data, summarize=None, name=None): """Asserts that the given condition is true. @@ -1095,6 +1097,7 @@ def _UnpackIfSingleton(res): # pylint: disable=redefined-outer-name # pylint: disable=g-doc-args @tf_export(v1=["cond"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args( None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.", "fn1", "fn2") @@ -1318,6 +1321,7 @@ def _cast_indexed_slice_indices(a, b): @tf_export("cond", v1=[]) +@dispatch.add_dispatch_support def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None): """Return `true_fn()` if the predicate `pred` is true else `false_fn()`. @@ -2942,6 +2946,7 @@ def group(*inputs, **kwargs): @tf_export("tuple", v1=[]) +@dispatch.add_dispatch_support def tuple_v2(tensors, control_inputs=None, name=None): """Group tensors together. @@ -2978,6 +2983,7 @@ def tuple_v2(tensors, control_inputs=None, name=None): @tf_export(v1=["tuple"]) +@dispatch.add_dispatch_support def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined-builtin """Group tensors together. @@ -3312,6 +3318,7 @@ def _indexed_case_helper(branch_fns, default, branch_index, name): @tf_export("case", v1=[]) +@dispatch.add_dispatch_support def case_v2(pred_fn_pairs, default=None, exclusive=False, @@ -3416,6 +3423,7 @@ def case_v2(pred_fn_pairs, @tf_export(v1=["case"]) +@dispatch.add_dispatch_support def case(pred_fn_pairs, default=None, exclusive=False, diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py index d989bc0be44..6c9cdf1dd08 100644 --- a/tensorflow/python/ops/ctc_ops.py +++ b/tensorflow/python/ops/ctc_ops.py @@ -43,6 +43,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.nn_grad import _BroadcastMul from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -70,6 +71,7 @@ def _generate_defun_backend(unique_api_name, preferred_device, func): # pylint: disable=protected-access, invalid-name @tf_export(v1=["nn.ctc_loss"]) +@dispatch.add_dispatch_support def ctc_loss(labels, inputs=None, sequence_length=None, @@ -284,6 +286,7 @@ def _CTCLossV2Grad(op, grad_loss, _): @tf_export("nn.ctc_greedy_decoder") +@dispatch.add_dispatch_support def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True): """Performs greedy decoding on the logits given in input (best path). @@ -333,6 +336,7 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True): @tf_export(v1=["nn.ctc_beam_search_decoder"]) +@dispatch.add_dispatch_support def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100, @@ -395,6 +399,7 @@ def ctc_beam_search_decoder(inputs, @tf_export("nn.ctc_beam_search_decoder", v1=["nn.ctc_beam_search_decoder_v2"]) +@dispatch.add_dispatch_support def ctc_beam_search_decoder_v2(inputs, sequence_length, beam_width=100, @@ -731,6 +736,7 @@ def _ctc_loss_shape(op): # pylint: disable=protected-access, invalid-name @tf_export(v1=["nn.ctc_loss_v2"]) +@dispatch.add_dispatch_support def ctc_loss_v2(labels, logits, label_length, @@ -825,6 +831,7 @@ def ctc_loss_v2(labels, @tf_export("nn.ctc_loss", v1=[]) +@dispatch.add_dispatch_support def ctc_loss_v3(labels, logits, label_length, @@ -1056,6 +1063,7 @@ def ctc_loss_dense(labels, @tf_export("nn.collapse_repeated") +@dispatch.add_dispatch_support def collapse_repeated(labels, seq_length, name=None): """Merge repeated labels into single labels. @@ -1153,6 +1161,7 @@ def dense_labels_to_sparse(dense, length): @tf_export("nn.ctc_unique_labels") +@dispatch.add_dispatch_support def ctc_unique_labels(labels, name=None): """Get unique labels and indices for batched labels for `tf.nn.ctc_loss`. diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py index 4040a4db038..2a9194fb146 100644 --- a/tensorflow/python/ops/custom_gradient.py +++ b/tensorflow/python/ops/custom_gradient.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import op_selector from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator @@ -351,13 +352,8 @@ def _graph_mode_decorator(f, args, kwargs): "argument 'variables'.") if variables_in_signature and not variables: # User seems to intend to use variables but none were captured. - if not variable_scope.get_variable_scope().use_resource: - raise TypeError("If using @custom_gradient with a function that " - "uses variables, the enclosing variable scope must " - "have use_resource=True.") - else: - logging.warn("@custom_gradient grad_fn has 'variables' in signature, but " - "no ResourceVariables were used on the forward pass.") + logging.warn("@custom_gradient grad_fn has 'variables' in signature, but " + "no ResourceVariables were used on the forward pass.") flat_result = nest.flatten(result) flat_result_len = len(flat_result) @@ -482,28 +478,47 @@ def recompute_grad(f): def inner(*args, **kwargs): """Inner function closure for calculating gradients.""" current_var_scope = variable_scope.get_variable_scope() + with tape_lib.stop_recording(): + result = f(*args, **kwargs) - result = f(*args, **kwargs) + def grad_wrapper(*wrapper_args, **grad_kwargs): + """Wrapper function to accomodate lack of kwargs in graph mode decorator.""" - def grad(*dresult, **grad_kwargs): - """Gradient function calculation for inner function.""" - variables = grad_kwargs.get("variables") - with backprop.GradientTape() as t: - id_args = [gen_array_ops.identity(x) for x in args] - t.watch(id_args) + @custom_gradient + def inner_recompute_grad(*dresult): + """Nested custom gradient function for computing grads in reverse and forward mode autodiff.""" + # Gradient calculation for reverse mode autodiff. + variables = grad_kwargs.get("variables") + with backprop.GradientTape() as t: + id_args = [gen_array_ops.identity(x) for x in args] + t.watch(id_args) + if variables is not None: + t.watch(variables) + with ops.control_dependencies(dresult): + with variable_scope.variable_scope(current_var_scope): + result = f(*id_args, **kwargs) + kw_vars = [] if variables is not None: - t.watch(variables) - with ops.control_dependencies(dresult): - with variable_scope.variable_scope(current_var_scope): - result = f(*id_args, **kwargs) - kw_vars = [] - if variables is not None: - kw_vars = list(variables) - grads = t.gradient( - result, list(id_args) + kw_vars, output_gradients=dresult) - return grads[:len(id_args)], grads[len(id_args):] + kw_vars = list(variables) + grads = t.gradient( + result, + list(id_args) + kw_vars, + output_gradients=dresult, + unconnected_gradients=UnconnectedGradients.ZERO) - return result, grad + def transpose(*t_args, **t_kwargs): + """Gradient function calculation for forward mode autodiff.""" + # Just throw an error since gradients / activations are not stored on tape for recompute. + raise NotImplementedError( + "recompute_grad tried to transpose grad of {}. " + "Consider not using recompute_grad in forward mode" + "autodiff".format(f.__name__)) + + return (grads[:len(id_args)], grads[len(id_args):]), transpose + + return inner_recompute_grad(*wrapper_args) + + return result, grad_wrapper return inner diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index 2fdae49b1f6..1c7b204fa58 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_functional_ops from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @@ -250,6 +251,7 @@ def _embedding_lookup_and_transform(params, @tf_export(v1=["nn.embedding_lookup"]) +@dispatch.add_dispatch_support def embedding_lookup( params, ids, @@ -327,6 +329,7 @@ def embedding_lookup( @tf_export("nn.embedding_lookup", v1=[]) +@dispatch.add_dispatch_support def embedding_lookup_v2(params, ids, max_norm=None, name=None): """Looks up embeddings for the given `ids` from a list of tensors. @@ -392,6 +395,7 @@ def embedding_lookup_v2(params, ids, max_norm=None, name=None): @tf_export(v1=["nn.embedding_lookup_sparse"]) +@dispatch.add_dispatch_support def embedding_lookup_sparse(params, sp_ids, sp_weights, @@ -574,6 +578,7 @@ def embedding_lookup_sparse(params, @tf_export("nn.embedding_lookup_sparse", v1=[]) +@dispatch.add_dispatch_support def embedding_lookup_sparse_v2(params, sp_ids, sp_weights, @@ -664,6 +669,7 @@ def embedding_lookup_sparse_v2(params, @tf_export("nn.safe_embedding_lookup_sparse", v1=[]) +@dispatch.add_dispatch_support def safe_embedding_lookup_sparse_v2(embedding_weights, sparse_ids, sparse_weights=None, @@ -765,6 +771,7 @@ def safe_embedding_lookup_sparse_v2(embedding_weights, @tf_export(v1=["nn.safe_embedding_lookup_sparse"]) +@dispatch.add_dispatch_support def safe_embedding_lookup_sparse(embedding_weights, sparse_ids, sparse_weights=None, diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index 8ec925824de..37b41a55eb9 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -38,6 +38,7 @@ from tensorflow.python.ops.gen_functional_ops import remote_call from tensorflow.python.ops.gen_functional_ops import symbolic_gradient from tensorflow.python.util import compat from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util import function_utils from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -45,6 +46,7 @@ from tensorflow.python.util.tf_export import tf_export # TODO(yuanbyu, mrry): Handle stride to support sliding windows. @tf_export(v1=["foldl"]) +@dispatch.add_dispatch_support def foldl(fn, elems, initializer=None, @@ -162,6 +164,7 @@ def foldl(fn, @tf_export("foldl", v1=[]) +@dispatch.add_dispatch_support @deprecation.deprecated_arg_values( None, """back_prop=False is deprecated. Consider using tf.stop_gradient instead. @@ -238,6 +241,7 @@ def foldl_v2(fn, @tf_export(v1=["foldr"]) +@dispatch.add_dispatch_support def foldr(fn, elems, initializer=None, @@ -356,6 +360,7 @@ def foldr(fn, @tf_export("foldr", v1=[]) +@dispatch.add_dispatch_support @deprecation.deprecated_arg_values( None, """back_prop=False is deprecated. Consider using tf.stop_gradient instead. @@ -432,6 +437,7 @@ def foldr_v2(fn, @tf_export(v1=["scan"]) +@dispatch.add_dispatch_support def scan(fn, elems, initializer=None, @@ -686,6 +692,7 @@ def scan(fn, @tf_export("scan", v1=[]) +@dispatch.add_dispatch_support @deprecation.deprecated_arg_values( None, """back_prop=False is deprecated. Consider using tf.stop_gradient instead. diff --git a/tensorflow/python/ops/gradient_checker_test.py b/tensorflow/python/ops/gradient_checker_test.py index 92ca9c2971e..c8ebf12569a 100644 --- a/tensorflow/python/ops/gradient_checker_test.py +++ b/tensorflow/python/ops/gradient_checker_test.py @@ -149,7 +149,7 @@ class GradientCheckerTest(test.TestCase): self.assertAllEqual(correct, analytical) self.assertAllClose(correct, numerical, rtol=1e-4) self.assertLess( - gradient_checker.compute_gradient_error(x, size, y, size), 2e-4) + gradient_checker.compute_gradient_error(x, size, y, size), 3e-4) @test_util.run_deprecated_v1 def testComplexConj(self): diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 817d8a1adbe..a06be7af74b 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -59,6 +59,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.ops.nn_ops import bias_add from tensorflow.python.platform import googletest +from tensorflow.python.ops import gradient_checker_v2 class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase): @@ -1340,6 +1341,46 @@ class VariablesGradientTest(test_util.TensorFlowTestCase): return grads_re, grads + def _grad(self, f, argnums=0): + """Return a function which computes the gradient of `f`.""" + + def _f(*params): + with backprop.GradientTape() as tape: + tape.watch(params) + outputs = f(*params) + return tape.gradient( + outputs, + params[argnums], + unconnected_gradients=unconnected_gradients.UnconnectedGradients.ZERO) + + return _f + + def _test_gradients(self, f, inputs, order, delta=1e-3, rtol=1e-2, atol=1e-6): + """Tests backward jacobians of `f`'s [0, `order`)-order gradients.""" + if order < 1: + raise ValueError( + "`order` should be a positive integer, got '{}'.".format(order)) + if order > 1: + self._test_gradients( + f=self._grad(f), + inputs=inputs, + order=order - 1, + delta=delta, + rtol=rtol, + atol=atol) + sym_jac_back, num_jac = gradient_checker_v2.compute_gradient( + f, inputs, delta=delta) + self.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol) + + @test_util.run_v2_only + def testCustomGradientRecomputeGradHigherOrder(self): + + @custom_gradient.recompute_grad + def f(x): + return math_ops.reduce_prod(math_ops.tanh(x)**2) + + self._test_gradients(f, [constant_op.constant([1.])], order=3) + @test_util.run_in_graph_and_eager_modes def testFnRecompute(self): """Checks that recompute_grad works grads of function args.""" @@ -1356,8 +1397,8 @@ class VariablesGradientTest(test_util.TensorFlowTestCase): shape=10, trainable=True, ) - - test_input = constant(np.zeros((10, 10), dtype=np.float32)) + self.evaluate(test_var.assign(np.ones([10]))) + test_input = constant(np.ones((10, 10), dtype=np.float32)) grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn, test_input) @@ -1400,6 +1441,7 @@ class VariablesGradientTest(test_util.TensorFlowTestCase): shape=10, trainable=True, ) + self.evaluate(test_var.assign(np.ones([10]))) return input_t * test_var test_input_t = constant(np.zeros((10, 10), dtype=np.float32)) @@ -1442,6 +1484,8 @@ class VariablesGradientTest(test_util.TensorFlowTestCase): out_re = test_fn_re(test_input_t) out = TestFn(test_input_t) + init = variables.global_variables_initializer() + self.evaluate(init) grads_re = gradients.gradients(out_re, variables.trainable_variables()) grads = gradients.gradients(out, variables.trainable_variables()) diff --git a/tensorflow/python/ops/histogram_ops.py b/tensorflow/python/ops/histogram_ops.py index 92f3e7a24ba..233ea46c48b 100644 --- a/tensorflow/python/ops/histogram_ops.py +++ b/tensorflow/python/ops/histogram_ops.py @@ -26,10 +26,12 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @tf_export('histogram_fixed_width_bins') +@dispatch.add_dispatch_support def histogram_fixed_width_bins(values, value_range, nbins=100, @@ -62,17 +64,14 @@ def histogram_fixed_width_bins(values, Examples: - ```python - # Bins will be: (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf) - nbins = 5 - value_range = [0.0, 5.0] - new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15] - - with tf.compat.v1.get_default_session() as sess: - indices = tf.histogram_fixed_width_bins(new_values, value_range, nbins=5) - variables.global_variables_initializer().run() - sess.run(indices) # [0, 0, 1, 2, 4, 4] - ``` + >>> # Bins will be: (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf) + ... + >>> nbins = 5 + >>> value_range = [0.0, 5.0] + >>> new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15] + >>> indices = tf.histogram_fixed_width_bins(new_values, value_range, nbins=5) + >>> indices.numpy() + array([0, 0, 1, 2, 4, 4], dtype=int32) """ with ops.name_scope(name, 'histogram_fixed_width_bins', [values, value_range, nbins]): @@ -101,6 +100,7 @@ def histogram_fixed_width_bins(values, @tf_export('histogram_fixed_width') +@dispatch.add_dispatch_support def histogram_fixed_width(values, value_range, nbins=100, @@ -131,17 +131,14 @@ def histogram_fixed_width(values, Examples: - ```python - # Bins will be: (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf) - nbins = 5 - value_range = [0.0, 5.0] - new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15] - - with tf.compat.v1.get_default_session() as sess: - hist = tf.histogram_fixed_width(new_values, value_range, nbins=5) - variables.global_variables_initializer().run() - sess.run(hist) => [2, 1, 1, 0, 2] - ``` + >>> # Bins will be: (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf) + ... + >>> nbins = 5 + >>> value_range = [0.0, 5.0] + >>> new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15] + >>> hist = tf.histogram_fixed_width(new_values, value_range, nbins=5) + >>> hist.numpy() + array([2, 1, 1, 0, 2], dtype=int32) """ with ops.name_scope(name, 'histogram_fixed_width', [values, value_range, nbins]) as name: diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 52b65efad67..0532b24edca 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -39,6 +39,7 @@ from tensorflow.python.ops import sort_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export ops.NotDifferentiable('RandomCrop') @@ -323,6 +324,7 @@ def fix_image_flip_shape(image, result): @tf_export('image.random_flip_up_down') +@dispatch.add_dispatch_support def random_flip_up_down(image, seed=None): """Randomly flips an image vertically (upside down). @@ -363,6 +365,7 @@ def random_flip_up_down(image, seed=None): @tf_export('image.random_flip_left_right') +@dispatch.add_dispatch_support def random_flip_left_right(image, seed=None): """Randomly flip an image horizontally (left to right). @@ -450,6 +453,7 @@ def _random_flip(image, flip_index, seed, scope_name): @tf_export('image.flip_left_right') +@dispatch.add_dispatch_support def flip_left_right(image): """Flip an image horizontally (left to right). @@ -484,6 +488,7 @@ def flip_left_right(image): @tf_export('image.flip_up_down') +@dispatch.add_dispatch_support def flip_up_down(image): """Flip an image vertically (upside down). @@ -549,6 +554,7 @@ def _flip(image, flip_index, scope_name): @tf_export('image.rot90') +@dispatch.add_dispatch_support def rot90(image, k=1, name=None): """Rotate image(s) counter-clockwise by 90 degrees. @@ -660,6 +666,7 @@ def _rot90_4D(images, k, name_scope): @tf_export('image.transpose', v1=['image.transpose', 'image.transpose_image']) +@dispatch.add_dispatch_support def transpose(image, name=None): """Transpose image(s) by swapping the height and width dimension. @@ -718,6 +725,7 @@ def transpose(image, name=None): @tf_export('image.central_crop') +@dispatch.add_dispatch_support def central_crop(image, central_fraction): """Crop the central region of the image(s). @@ -850,6 +858,7 @@ def central_crop(image, central_fraction): @tf_export('image.pad_to_bounding_box') +@dispatch.add_dispatch_support def pad_to_bounding_box(image, offset_height, offset_width, target_height, target_width): """Pad `image` with zeros to the specified `height` and `width`. @@ -959,6 +968,7 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height, @tf_export('image.crop_to_bounding_box') +@dispatch.add_dispatch_support def crop_to_bounding_box(image, offset_height, offset_width, target_height, target_width): """Crops an image to a specified bounding box. @@ -1041,6 +1051,7 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height, @tf_export( 'image.resize_with_crop_or_pad', v1=['image.resize_with_crop_or_pad', 'image.resize_image_with_crop_or_pad']) +@dispatch.add_dispatch_support def resize_image_with_crop_or_pad(image, target_height, target_width): """Crops and/or pads an image to a target width and height. @@ -1258,6 +1269,7 @@ def _resize_images_common(images, resizer_fn, size, preserve_aspect_ratio, name, @tf_export(v1=['image.resize_images', 'image.resize']) +@dispatch.add_dispatch_support def resize_images(images, size, method=ResizeMethodV1.BILINEAR, @@ -1343,6 +1355,7 @@ def resize_images(images, @tf_export('image.resize', v1=[]) +@dispatch.add_dispatch_support def resize_images_v2(images, size, method=ResizeMethod.BILINEAR, @@ -1594,6 +1607,7 @@ def _resize_image_with_pad_common(image, target_height, target_width, @tf_export(v1=['image.resize_image_with_pad']) +@dispatch.add_dispatch_support def resize_image_with_pad_v1(image, target_height, target_width, @@ -1636,6 +1650,7 @@ def resize_image_with_pad_v1(image, @tf_export('image.resize_with_pad', v1=[]) +@dispatch.add_dispatch_support def resize_image_with_pad_v2(image, target_height, target_width, @@ -1676,6 +1691,7 @@ def resize_image_with_pad_v2(image, @tf_export('image.per_image_standardization') +@dispatch.add_dispatch_support def per_image_standardization(image): """Linearly scales each image in `image` to have mean 0 and variance 1. @@ -1721,6 +1737,7 @@ def per_image_standardization(image): @tf_export('image.random_brightness') +@dispatch.add_dispatch_support def random_brightness(image, max_delta, seed=None): """Adjust the brightness of images by a random factor. @@ -1756,6 +1773,7 @@ def random_brightness(image, max_delta, seed=None): @tf_export('image.random_contrast') +@dispatch.add_dispatch_support def random_contrast(image, lower, upper, seed=None): """Adjust the contrast of an image or images by a random factor. @@ -1796,6 +1814,7 @@ def random_contrast(image, lower, upper, seed=None): @tf_export('image.adjust_brightness') +@dispatch.add_dispatch_support def adjust_brightness(image, delta): """Adjust the brightness of RGB or Grayscale images. @@ -1847,6 +1866,7 @@ def adjust_brightness(image, delta): @tf_export('image.adjust_contrast') +@dispatch.add_dispatch_support def adjust_contrast(images, contrast_factor): """Adjust contrast of RGB or grayscale images. @@ -1903,6 +1923,7 @@ def adjust_contrast(images, contrast_factor): @tf_export('image.adjust_gamma') +@dispatch.add_dispatch_support def adjust_gamma(image, gamma=1, gain=1): """Performs [Gamma Correction](http://en.wikipedia.org/wiki/Gamma_correction). @@ -1967,6 +1988,7 @@ def adjust_gamma(image, gamma=1, gain=1): @tf_export('image.convert_image_dtype') +@dispatch.add_dispatch_support def convert_image_dtype(image, dtype, saturate=False, name=None): """Convert `image` to `dtype`, scaling its values if needed. @@ -2066,6 +2088,7 @@ def convert_image_dtype(image, dtype, saturate=False, name=None): @tf_export('image.rgb_to_grayscale') +@dispatch.add_dispatch_support def rgb_to_grayscale(images, name=None): """Converts one or more images from RGB to Grayscale. @@ -2101,6 +2124,7 @@ def rgb_to_grayscale(images, name=None): @tf_export('image.grayscale_to_rgb') +@dispatch.add_dispatch_support def grayscale_to_rgb(images, name=None): """Converts one or more images from Grayscale to RGB. @@ -2137,6 +2161,7 @@ def grayscale_to_rgb(images, name=None): # pylint: disable=invalid-name @tf_export('image.random_hue') +@dispatch.add_dispatch_support def random_hue(image, max_delta, seed=None): """Adjust the hue of RGB images by a random factor. @@ -2179,6 +2204,7 @@ def random_hue(image, max_delta, seed=None): @tf_export('image.adjust_hue') +@dispatch.add_dispatch_support def adjust_hue(image, delta, name=None): """Adjust hue of RGB images. @@ -2246,6 +2272,7 @@ def adjust_hue(image, delta, name=None): # pylint: disable=invalid-name @tf_export('image.random_jpeg_quality') +@dispatch.add_dispatch_support def random_jpeg_quality(image, min_jpeg_quality, max_jpeg_quality, seed=None): """Randomly changes jpeg encoding quality for inducing jpeg noise. @@ -2293,6 +2320,7 @@ def random_jpeg_quality(image, min_jpeg_quality, max_jpeg_quality, seed=None): @tf_export('image.adjust_jpeg_quality') +@dispatch.add_dispatch_support def adjust_jpeg_quality(image, jpeg_quality, name=None): """Adjust jpeg encoding quality of an image. @@ -2343,6 +2371,7 @@ def adjust_jpeg_quality(image, jpeg_quality, name=None): @tf_export('image.random_saturation') +@dispatch.add_dispatch_support def random_saturation(image, lower, upper, seed=None): """Adjust the saturation of RGB images by a random factor. @@ -2389,6 +2418,7 @@ def random_saturation(image, lower, upper, seed=None): @tf_export('image.adjust_saturation') +@dispatch.add_dispatch_support def adjust_saturation(image, saturation_factor, name=None): """Adjust saturation of RGB images. @@ -2480,42 +2510,43 @@ tf_export( 'io.decode_and_crop_jpeg', 'image.decode_and_crop_jpeg', v1=['io.decode_and_crop_jpeg', 'image.decode_and_crop_jpeg'])( - gen_image_ops.decode_and_crop_jpeg) + dispatch.add_dispatch_support(gen_image_ops.decode_and_crop_jpeg)) tf_export( 'io.decode_bmp', 'image.decode_bmp', v1=['io.decode_bmp', 'image.decode_bmp'])( - gen_image_ops.decode_bmp) + dispatch.add_dispatch_support(gen_image_ops.decode_bmp)) tf_export( 'io.decode_gif', 'image.decode_gif', v1=['io.decode_gif', 'image.decode_gif'])( - gen_image_ops.decode_gif) + dispatch.add_dispatch_support(gen_image_ops.decode_gif)) tf_export( 'io.decode_jpeg', 'image.decode_jpeg', v1=['io.decode_jpeg', 'image.decode_jpeg'])( - gen_image_ops.decode_jpeg) + dispatch.add_dispatch_support(gen_image_ops.decode_jpeg)) tf_export( 'io.decode_png', 'image.decode_png', v1=['io.decode_png', 'image.decode_png'])( - gen_image_ops.decode_png) + dispatch.add_dispatch_support(gen_image_ops.decode_png)) tf_export( 'io.encode_jpeg', 'image.encode_jpeg', v1=['io.encode_jpeg', 'image.encode_jpeg'])( - gen_image_ops.encode_jpeg) + dispatch.add_dispatch_support(gen_image_ops.encode_jpeg)) tf_export( 'io.extract_jpeg_shape', 'image.extract_jpeg_shape', v1=['io.extract_jpeg_shape', 'image.extract_jpeg_shape'])( - gen_image_ops.extract_jpeg_shape) + dispatch.add_dispatch_support(gen_image_ops.extract_jpeg_shape)) @tf_export('io.encode_png', 'image.encode_png') +@dispatch.add_dispatch_support def encode_png(image, compression=-1, name=None): r"""PNG-encode an image. @@ -2548,6 +2579,7 @@ def encode_png(image, compression=-1, name=None): 'io.decode_image', 'image.decode_image', v1=['io.decode_image', 'image.decode_image']) +@dispatch.add_dispatch_support def decode_image(contents, channels=None, dtype=dtypes.uint8, @@ -2661,6 +2693,7 @@ def decode_image(contents, @tf_export('image.total_variation') +@dispatch.add_dispatch_support def total_variation(images, name=None): """Calculate and return the total variation for one or more images. @@ -2732,6 +2765,7 @@ def total_variation(images, name=None): @tf_export('image.sample_distorted_bounding_box', v1=[]) +@dispatch.add_dispatch_support def sample_distorted_bounding_box_v2(image_size, bounding_boxes, seed=0, @@ -2831,6 +2865,7 @@ def sample_distorted_bounding_box_v2(image_size, @tf_export(v1=['image.sample_distorted_bounding_box']) +@dispatch.add_dispatch_support @deprecation.deprecated( date=None, instructions='`seed2` arg is deprecated.' @@ -2945,6 +2980,7 @@ def sample_distorted_bounding_box(image_size, @tf_export('image.non_max_suppression') +@dispatch.add_dispatch_support def non_max_suppression(boxes, scores, max_output_size, @@ -2997,6 +3033,7 @@ def non_max_suppression(boxes, @tf_export('image.non_max_suppression_with_scores') +@dispatch.add_dispatch_support def non_max_suppression_with_scores(boxes, scores, max_output_size, @@ -3083,6 +3120,7 @@ def non_max_suppression_with_scores(boxes, @tf_export('image.non_max_suppression_overlaps') +@dispatch.add_dispatch_support def non_max_suppression_with_overlaps(overlaps, scores, max_output_size, @@ -3134,6 +3172,7 @@ _rgb_to_yiq_kernel = [[0.299, 0.59590059, 0.2115], @tf_export('image.rgb_to_yiq') +@dispatch.add_dispatch_support def rgb_to_yiq(images): """Converts one or more images from RGB to YIQ. @@ -3167,6 +3206,7 @@ _yiq_to_rgb_kernel = [[1, 1, 1], [0.95598634, -0.27201283, -1.10674021], @tf_export('image.yiq_to_rgb') +@dispatch.add_dispatch_support def yiq_to_rgb(images): """Converts one or more images from YIQ to RGB. @@ -3195,6 +3235,7 @@ _rgb_to_yuv_kernel = [[0.299, -0.14714119, 0.61497538], @tf_export('image.rgb_to_yuv') +@dispatch.add_dispatch_support def rgb_to_yuv(images): """Converts one or more images from RGB to YUV. @@ -3221,6 +3262,7 @@ _yuv_to_rgb_kernel = [[1, 1, 1], [0, -0.394642334, 2.03206185], @tf_export('image.yuv_to_rgb') +@dispatch.add_dispatch_support def yuv_to_rgb(images): """Converts one or more images from YUV to RGB. @@ -3314,6 +3356,7 @@ def _verify_compatible_image_shapes(img1, img2): @tf_export('image.psnr') +@dispatch.add_dispatch_support def psnr(a, b, max_val, name=None): """Returns the Peak Signal-to-Noise Ratio between a and b. @@ -3525,6 +3568,7 @@ def _ssim_per_channel(img1, @tf_export('image.ssim') +@dispatch.add_dispatch_support def ssim(img1, img2, max_val, @@ -3604,6 +3648,7 @@ _MSSSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333) @tf_export('image.ssim_multiscale') +@dispatch.add_dispatch_support def ssim_multiscale(img1, img2, max_val, @@ -3731,6 +3776,7 @@ def ssim_multiscale(img1, @tf_export('image.image_gradients') +@dispatch.add_dispatch_support def image_gradients(image): """Returns image gradients (dy, dx) for each color channel. @@ -3804,6 +3850,7 @@ def image_gradients(image): @tf_export('image.sobel_edges') +@dispatch.add_dispatch_support def sobel_edges(image): """Returns a tensor holding Sobel edge maps. @@ -3888,21 +3935,22 @@ resize_area_deprecation = deprecation.deprecated( instructions=( 'Use `tf.image.resize(...method=ResizeMethod.AREA...)` instead.')) tf_export(v1=['image.resize_area'])( - resize_area_deprecation(gen_image_ops.resize_area)) + resize_area_deprecation( + dispatch.add_dispatch_support(gen_image_ops.resize_area))) resize_bicubic_deprecation = deprecation.deprecated( date=None, instructions=( 'Use `tf.image.resize(...method=ResizeMethod.BICUBIC...)` instead.')) tf_export(v1=['image.resize_bicubic'])( - resize_bicubic_deprecation(resize_bicubic)) + dispatch.add_dispatch_support(resize_bicubic_deprecation(resize_bicubic))) resize_bilinear_deprecation = deprecation.deprecated( date=None, instructions=( 'Use `tf.image.resize(...method=ResizeMethod.BILINEAR...)` instead.')) tf_export(v1=['image.resize_bilinear'])( - resize_bilinear_deprecation(resize_bilinear)) + dispatch.add_dispatch_support(resize_bilinear_deprecation(resize_bilinear))) resize_nearest_neighbor_deprecation = deprecation.deprecated( date=None, @@ -3910,10 +3958,12 @@ resize_nearest_neighbor_deprecation = deprecation.deprecated( 'Use `tf.image.resize(...method=ResizeMethod.NEAREST_NEIGHBOR...)` ' 'instead.')) tf_export(v1=['image.resize_nearest_neighbor'])( - resize_nearest_neighbor_deprecation(resize_nearest_neighbor)) + dispatch.add_dispatch_support( + resize_nearest_neighbor_deprecation(resize_nearest_neighbor))) @tf_export('image.crop_and_resize', v1=[]) +@dispatch.add_dispatch_support def crop_and_resize_v2(image, boxes, box_indices, @@ -3997,6 +4047,7 @@ def crop_and_resize_v2(image, @tf_export(v1=['image.crop_and_resize']) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, 'box_ind is deprecated, use box_indices instead', 'box_ind') @@ -4019,6 +4070,7 @@ crop_and_resize_v1.__doc__ = gen_image_ops.crop_and_resize.__doc__ @tf_export(v1=['image.extract_glimpse']) +@dispatch.add_dispatch_support def extract_glimpse( input, # pylint: disable=redefined-builtin size, @@ -4062,8 +4114,8 @@ def extract_glimpse( ... [[6.0], ... [7.0], ... [8.0]]]] - >>> tf.image.extract_glimpse(x, size=(2, 2), offsets=[[1, 1]], - ... centered=False, normalized=False) + >>> tf.compat.v1.image.extract_glimpse(x, size=(2, 2), offsets=[[1, 1]], + ... centered=False, normalized=False) <tf.Tensor: shape=(1, 2, 2, 1), dtype=float32, numpy= array([[[[0.], [1.]], @@ -4104,6 +4156,7 @@ def extract_glimpse( @tf_export('image.extract_glimpse', v1=[]) +@dispatch.add_dispatch_support def extract_glimpse_v2( input, # pylint: disable=redefined-builtin size, @@ -4150,10 +4203,10 @@ def extract_glimpse_v2( >>> tf.image.extract_glimpse(x, size=(2, 2), offsets=[[1, 1]], ... centered=False, normalized=False) <tf.Tensor: shape=(1, 2, 2, 1), dtype=float32, numpy= - array([[[[0.], - [1.]], - [[3.], - [4.]]]], dtype=float32)> + array([[[[4.], + [5.]], + [[7.], + [8.]]]], dtype=float32)> Args: input: A `Tensor` of type `float32`. A 4-D float tensor of shape @@ -4178,7 +4231,7 @@ def extract_glimpse_v2( Returns: A `Tensor` of type `float32`. """ - return gen_image_ops.extract_glimpse( + return gen_image_ops.extract_glimpse_v2( input=input, size=size, offsets=offsets, @@ -4190,6 +4243,7 @@ def extract_glimpse_v2( @tf_export('image.combined_non_max_suppression') +@dispatch.add_dispatch_support def combined_non_max_suppression(boxes, scores, max_output_size_per_class, @@ -4442,6 +4496,7 @@ def _suppression_loop_body(boxes, iou_threshold, output_size, idx, tile_size): @tf_export('image.non_max_suppression_padded') +@dispatch.add_dispatch_support def non_max_suppression_padded(boxes, scores, max_output_size, @@ -4816,6 +4871,7 @@ def non_max_suppression_padded_v1(boxes, @tf_export('image.draw_bounding_boxes', v1=[]) +@dispatch.add_dispatch_support def draw_bounding_boxes_v2(images, boxes, colors, name=None): """Draw bounding boxes on a batch of images. @@ -4870,6 +4926,7 @@ def draw_bounding_boxes_v2(images, boxes, colors, name=None): @tf_export(v1=['image.draw_bounding_boxes']) +@dispatch.add_dispatch_support def draw_bounding_boxes(images, boxes, name=None, colors=None): """Draw bounding boxes on a batch of images. @@ -4922,6 +4979,7 @@ def draw_bounding_boxes(images, boxes, name=None, colors=None): @tf_export('image.generate_bounding_box_proposals') +@dispatch.add_dispatch_support def generate_bounding_box_proposals(scores, bbox_deltas, image_info, diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py index f7617d83caf..82acd09caec 100644 --- a/tensorflow/python/ops/linalg/linalg_impl.py +++ b/tensorflow/python/ops/linalg/linalg_impl.py @@ -41,7 +41,7 @@ cholesky = linalg_ops.cholesky cholesky_solve = linalg_ops.cholesky_solve det = linalg_ops.matrix_determinant slogdet = gen_linalg_ops.log_matrix_determinant -tf_export('linalg.slogdet')(slogdet) +tf_export('linalg.slogdet')(dispatch.add_dispatch_support(slogdet)) diag = array_ops.matrix_diag diag_part = array_ops.matrix_diag_part eigh = linalg_ops.self_adjoint_eig @@ -51,7 +51,7 @@ eye = linalg_ops.eye inv = linalg_ops.matrix_inverse logm = gen_linalg_ops.matrix_logarithm lu = gen_linalg_ops.lu -tf_export('linalg.logm')(logm) +tf_export('linalg.logm')(dispatch.add_dispatch_support(logm)) lstsq = linalg_ops.matrix_solve_ls norm = linalg_ops.norm qr = linalg_ops.qr @@ -230,6 +230,7 @@ def _matrix_exp_pade13(matrix): @tf_export('linalg.expm') +@dispatch.add_dispatch_support def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin r"""Computes the matrix exponential of one or more square matrices. @@ -340,6 +341,7 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin @tf_export('linalg.tridiagonal_solve') +@dispatch.add_dispatch_support def tridiagonal_solve(diagonals, rhs, diagonals_format='compact', @@ -541,6 +543,7 @@ def _tridiagonal_solve_compact_format(diagonals, rhs, transpose_rhs, @tf_export('linalg.tridiagonal_matmul') +@dispatch.add_dispatch_support def tridiagonal_matmul(diagonals, rhs, diagonals_format='compact', name=None): r"""Multiplies tridiagonal matrix by matrix. @@ -638,6 +641,7 @@ def _maybe_validate_matrix(a, validate_args): @tf_export('linalg.matrix_rank') +@dispatch.add_dispatch_support def matrix_rank(a, tol=None, validate_args=False, name=None): """Compute the matrix rank of one or more matrices. @@ -676,6 +680,7 @@ def matrix_rank(a, tol=None, validate_args=False, name=None): @tf_export('linalg.pinv') +@dispatch.add_dispatch_support def pinv(a, rcond=None, validate_args=False, name=None): """Compute the Moore-Penrose pseudo-inverse of one or more matrices. @@ -805,6 +810,7 @@ def pinv(a, rcond=None, validate_args=False, name=None): @tf_export('linalg.lu_solve') +@dispatch.add_dispatch_support def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None): """Solves systems of linear eqns `A X = RHS`, given LU factorizations. @@ -902,6 +908,7 @@ def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None): @tf_export('linalg.lu_matrix_inverse') +@dispatch.add_dispatch_support def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None): """Computes the inverse given the LU decomposition(s) of one or more matrices. @@ -966,6 +973,7 @@ def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None): @tf_export('linalg.lu_reconstruct') +@dispatch.add_dispatch_support def lu_reconstruct(lower_upper, perm, validate_args=False, name=None): """The reconstruct one or more matrices from their LU decomposition(s). diff --git a/tensorflow/python/ops/linalg/sparse/conjugate_gradient.py b/tensorflow/python/ops/linalg/sparse/conjugate_gradient.py index 613309f856d..6794636c3fd 100644 --- a/tensorflow/python/ops/linalg/sparse/conjugate_gradient.py +++ b/tensorflow/python/ops/linalg/sparse/conjugate_gradient.py @@ -27,10 +27,12 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.linalg import linalg_impl as linalg +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @tf_export('linalg.experimental.conjugate_gradient') +@dispatch.add_dispatch_support def conjugate_gradient(operator, rhs, preconditioner=None, diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py index abca7df19e0..03b7b98119d 100644 --- a/tensorflow/python/ops/linalg_ops.py +++ b/tensorflow/python/ops/linalg_ops.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.gen_linalg_ops import * # pylint: enable=wildcard-import from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export # Names below are lower_case. @@ -82,6 +83,7 @@ def _RegularizedGramianCholesky(matrix, l2_regularizer, first_kind): @tf_export( 'linalg.triangular_solve', v1=['linalg.triangular_solve', 'matrix_triangular_solve']) +@dispatch.add_dispatch_support def matrix_triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None): """Solve systems of linear equations with upper or lower triangular matrices. @@ -143,6 +145,7 @@ def matrix_triangular_solve(matrix, rhs, lower=True, adjoint=False, name=None): @tf_export( 'linalg.cholesky_solve', v1=['linalg.cholesky_solve', 'cholesky_solve']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('cholesky_solve') def cholesky_solve(chol, rhs, name=None): """Solves systems of linear eqns `A X = RHS`, given Cholesky factorizations. @@ -187,6 +190,7 @@ def cholesky_solve(chol, rhs, name=None): @tf_export('eye', 'linalg.eye') +@dispatch.add_dispatch_support def eye(num_rows, num_columns=None, batch_shape=None, @@ -234,6 +238,7 @@ def eye(num_rows, @tf_export('linalg.lstsq', v1=['linalg.lstsq', 'matrix_solve_ls']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('matrix_solve_ls') def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None): r"""Solves one or more linear least-squares problems. @@ -371,6 +376,7 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None): @tf_export('linalg.eig', 'eig', v1=[]) +@dispatch.add_dispatch_support def eig(tensor, name=None): """Computes the eigen decomposition of a batch of matrices. @@ -401,6 +407,7 @@ def eig(tensor, name=None): @tf_export('linalg.eigvals', 'eigvals', v1=[]) +@dispatch.add_dispatch_support def eigvals(tensor, name=None): """Computes the eigenvalues of one or more matrices. @@ -427,6 +434,7 @@ def eigvals(tensor, name=None): @tf_export('linalg.eigh', v1=['linalg.eigh', 'self_adjoint_eig']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('self_adjoint_eig') def self_adjoint_eig(tensor, name=None): """Computes the eigen decomposition of a batch of self-adjoint matrices. @@ -450,6 +458,7 @@ def self_adjoint_eig(tensor, name=None): @tf_export('linalg.eigvalsh', v1=['linalg.eigvalsh', 'self_adjoint_eigvals']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('self_adjoint_eigvals') def self_adjoint_eigvals(tensor, name=None): """Computes the eigenvalues of one or more self-adjoint matrices. @@ -473,6 +482,7 @@ def self_adjoint_eigvals(tensor, name=None): @tf_export('linalg.svd', v1=['linalg.svd', 'svd']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('svd') def svd(tensor, full_matrices=False, compute_uv=True, name=None): r"""Computes the singular value decompositions of one or more matrices. @@ -544,6 +554,7 @@ def svd(tensor, full_matrices=False, compute_uv=True, name=None): # pylint: disable=redefined-builtin @tf_export('norm', 'linalg.norm', v1=[]) +@dispatch.add_dispatch_support def norm_v2(tensor, ord='euclidean', axis=None, @@ -615,6 +626,7 @@ def norm_v2(tensor, # pylint: disable=redefined-builtin @tf_export(v1=['norm', 'linalg.norm']) +@dispatch.add_dispatch_support @deprecation.deprecated_args( None, 'keep_dims is deprecated, use keepdims instead', 'keep_dims') def norm(tensor, diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py index 7e980a0dbb3..8ca63f55987 100644 --- a/tensorflow/python/ops/logging_ops.py +++ b/tensorflow/python/ops/logging_ops.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import string_ops from tensorflow.python.ops.gen_logging_ops import * # pylint: enable=wildcard-import from tensorflow.python.platform import tf_logging +from tensorflow.python.util import dispatch from tensorflow.python.util import nest from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export @@ -71,6 +72,7 @@ except NameError: "only a concern in graph mode. Below is an example " "of how to ensure tf.print executes in graph mode:\n") @tf_export(v1=["Print"]) +@dispatch.add_dispatch_support def Print(input_, data, message=None, first_n=None, summarize=None, name=None): """Prints a list of tensors. @@ -136,6 +138,7 @@ def _is_filepath(output_stream): # function definition. # pylint: disable=g-doc-args @tf_export("print") +@dispatch.add_dispatch_support def print_v2(*inputs, **kwargs): """Print the specified inputs. diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index 556c646f2a7..6a7b4b68420 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.ops.losses import util +from tensorflow.python.util import dispatch from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.deprecation import deprecated_argument_lookup from tensorflow.python.util.tf_export import tf_export @@ -136,6 +137,7 @@ def _num_elements(losses): @tf_export(v1=["losses.compute_weighted_loss"]) +@dispatch.add_dispatch_support def compute_weighted_loss( losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): @@ -204,6 +206,7 @@ def compute_weighted_loss( @tf_export(v1=["losses.absolute_difference"]) +@dispatch.add_dispatch_support def absolute_difference( labels, predictions, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -257,6 +260,7 @@ def absolute_difference( @tf_export(v1=["losses.cosine_distance"]) +@dispatch.add_dispatch_support @deprecated_args(None, "dim is deprecated, use axis instead", "dim") def cosine_distance( labels, predictions, axis=None, weights=1.0, scope=None, @@ -313,6 +317,7 @@ def cosine_distance( @tf_export(v1=["losses.hinge_loss"]) +@dispatch.add_dispatch_support def hinge_loss(labels, logits, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): @@ -363,6 +368,7 @@ def hinge_loss(labels, logits, weights=1.0, scope=None, @tf_export(v1=["losses.huber_loss"]) +@dispatch.add_dispatch_support def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): @@ -439,6 +445,7 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None, @tf_export(v1=["losses.log_loss"]) +@dispatch.add_dispatch_support def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=Reduction.SUM_BY_NONZERO_WEIGHTS): @@ -496,6 +503,7 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None, # TODO(b/37208492): Add reduction arg. @tf_export(v1=["losses.mean_pairwise_squared_error"]) +@dispatch.add_dispatch_support def mean_pairwise_squared_error( labels, predictions, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES): @@ -592,6 +600,7 @@ def mean_pairwise_squared_error( @tf_export(v1=["losses.mean_squared_error"]) +@dispatch.add_dispatch_support def mean_squared_error( labels, predictions, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -645,6 +654,7 @@ def mean_squared_error( @tf_export(v1=["losses.sigmoid_cross_entropy"]) +@dispatch.add_dispatch_support def sigmoid_cross_entropy( multi_class_labels, logits, weights=1.0, label_smoothing=0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -709,6 +719,7 @@ def sigmoid_cross_entropy( @tf_export(v1=["losses.softmax_cross_entropy"]) +@dispatch.add_dispatch_support def softmax_cross_entropy( onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -831,6 +842,7 @@ def _remove_squeezable_dimensions( @tf_export(v1=["losses.sparse_softmax_cross_entropy"]) +@dispatch.add_dispatch_support def sparse_softmax_cross_entropy( labels, logits, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, diff --git a/tensorflow/python/ops/manip_ops.py b/tensorflow/python/ops/manip_ops.py index 56e8a894c24..fe99696f82f 100644 --- a/tensorflow/python/ops/manip_ops.py +++ b/tensorflow/python/ops/manip_ops.py @@ -20,11 +20,13 @@ from __future__ import print_function from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export # pylint: disable=protected-access @tf_export('roll', v1=['roll', 'manip.roll']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('manip.roll') def roll(input, shift, axis, name=None): # pylint: disable=redefined-builtin return _gen_manip_ops.roll(input, shift, axis, name) diff --git a/tensorflow/python/ops/map_fn.py b/tensorflow/python/ops/map_fn.py index 2c9c678336e..40f8edfcdd1 100644 --- a/tensorflow/python/ops/map_fn.py +++ b/tensorflow/python/ops/map_fn.py @@ -22,6 +22,8 @@ from __future__ import print_function import re +from tensorflow.python.autograph.core import ag_ctx as autograph_ctx +from tensorflow.python.autograph.impl import api as autograph from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops @@ -477,7 +479,9 @@ def map_fn(fn, elems_value_flat = _elems_value_batchable_to_flat(elems_value_batchable, elems_flat_signature) elems_value = elems_unflatten(elems_value_flat) - result_value = fn(elems_value) + ag_ctx = autograph_ctx.control_status_ctx() + autographed_fn = autograph.tf_convert(fn, ag_ctx) + result_value = autographed_fn(elems_value) nest.assert_same_structure(fn_output_signature or elems, result_value) result_value_flat = nest.flatten(result_value) result_value_batchable = _result_value_flat_to_batchable( diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 4c4982c6fd5..ed1db4f539d 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -104,6 +104,7 @@ nextafter = gen_math_ops.next_after @tf_export("linspace", v1=["lin_space", "linspace"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("lin_space") def linspace_nd(start, stop, num, name=None, axis=0): r"""Generates evenly-spaced values in an interval along a given axis. @@ -214,8 +215,8 @@ linspace = linspace_nd arg_max = deprecation.deprecated(None, "Use `tf.math.argmax` instead")(arg_max) # pylint: disable=used-before-assignment arg_min = deprecation.deprecated(None, "Use `tf.math.argmin` instead")(arg_min) # pylint: disable=used-before-assignment -tf_export(v1=["arg_max"])(arg_max) -tf_export(v1=["arg_min"])(arg_min) +tf_export(v1=["arg_max"])(dispatch.add_dispatch_support(arg_max)) +tf_export(v1=["arg_min"])(dispatch.add_dispatch_support(arg_min)) # This is set by resource_variable_ops.py. It is included in this way since @@ -234,6 +235,7 @@ def _set_doc(doc): # pylint: disable=redefined-builtin @tf_export(v1=["math.argmax", "argmax"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "Use the `axis` argument instead", "dimension") @_set_doc( @@ -250,10 +252,11 @@ def argmax(input, @tf_export("math.argmax", "argmax", v1=[]) +@dispatch.add_dispatch_support def argmax_v2(input, axis=None, output_type=dtypes.int64, name=None): """Returns the index with the largest value across axes of a tensor. - Note that in case of ties the identity of the return value is not guaranteed. + In case of identity returns the smallest index. For example: @@ -266,6 +269,9 @@ def argmax_v2(input, axis=None, output_type=dtypes.int64, name=None): <tf.Tensor: shape=(5,), dtype=int64, numpy=array([2, 2, 0, 2, 2])> >>> tf.math.argmax(B, 1) <tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 2, 1])> + >>> C = tf.constant([0, 0, 0, 0]) + >>> tf.math.argmax(C) # Returns smallest index in case of ties + <tf.Tensor: shape=(), dtype=int64, numpy=0> Args: input: A `Tensor`. @@ -283,6 +289,7 @@ def argmax_v2(input, axis=None, output_type=dtypes.int64, name=None): @tf_export(v1=["math.argmin", "argmin"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "Use the `axis` argument instead", "dimension") @_set_doc( @@ -299,10 +306,11 @@ def argmin(input, @tf_export("math.argmin", "argmin", v1=[]) +@dispatch.add_dispatch_support def argmin_v2(input, axis=None, output_type=dtypes.int64, name=None): """Returns the index with the smallest value across axes of a tensor. - Note that in case of ties the identity of the return value is not guaranteed. + Returns the smallest index in case of ties. Args: input: A `Tensor`. Must be one of the following types: `float32`, `float64`, @@ -549,6 +557,7 @@ def _neg(x, name=None): @tf_export(v1=["math.scalar_mul", "scalar_mul"]) +@dispatch.add_dispatch_support def scalar_mul(scalar, x, name=None): """Multiplies a scalar times a `Tensor` or `IndexedSlices` object. @@ -581,6 +590,7 @@ def scalar_mul(scalar, x, name=None): @tf_export("math.scalar_mul", "scalar_mul", v1=[]) +@dispatch.add_dispatch_support @_set_doc(scalar_mul.__doc__) def scalar_mul_v2(scalar, x, name=None): with ops.name_scope(name, "scalar_mul", [x]) as name: @@ -701,6 +711,7 @@ def sign(x, name=None): @tf_export("math.real", v1=["math.real", "real"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("real") @dispatch.add_dispatch_support def real(input, name=None): @@ -735,6 +746,7 @@ def real(input, name=None): @tf_export("math.imag", v1=["math.imag", "imag"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("imag") @dispatch.add_dispatch_support def imag(input, name=None): @@ -768,6 +780,7 @@ def imag(input, name=None): @tf_export("math.angle", v1=["math.angle", "angle"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("angle") @dispatch.add_dispatch_support def angle(input, name=None): @@ -937,6 +950,7 @@ def saturate_cast(value, dtype, name=None): @deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.") @tf_export(v1=["to_float"]) +@dispatch.add_dispatch_support def to_float(x, name="ToFloat"): """Casts a tensor to type `float32`. @@ -956,6 +970,7 @@ def to_float(x, name="ToFloat"): @deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.") @tf_export(v1=["to_double"]) +@dispatch.add_dispatch_support def to_double(x, name="ToDouble"): """Casts a tensor to type `float64`. @@ -975,6 +990,7 @@ def to_double(x, name="ToDouble"): @deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.") @tf_export(v1=["to_int32"]) +@dispatch.add_dispatch_support def to_int32(x, name="ToInt32"): """Casts a tensor to type `int32`. @@ -994,6 +1010,7 @@ def to_int32(x, name="ToInt32"): @deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.") @tf_export(v1=["to_int64"]) +@dispatch.add_dispatch_support def to_int64(x, name="ToInt64"): """Casts a tensor to type `int64`. @@ -1013,6 +1030,7 @@ def to_int64(x, name="ToInt64"): @deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.") @tf_export(v1=["to_bfloat16"]) +@dispatch.add_dispatch_support def to_bfloat16(x, name="ToBFloat16"): """Casts a tensor to type `bfloat16`. @@ -1032,6 +1050,7 @@ def to_bfloat16(x, name="ToBFloat16"): @deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.") @tf_export(v1=["to_complex64"]) +@dispatch.add_dispatch_support def to_complex64(x, name="ToComplex64"): """Casts a tensor to type `complex64`. @@ -1051,6 +1070,7 @@ def to_complex64(x, name="ToComplex64"): @deprecation.deprecated(date=None, instructions="Use `tf.cast` instead.") @tf_export(v1=["to_complex128"]) +@dispatch.add_dispatch_support def to_complex128(x, name="ToComplex128"): """Casts a tensor to type `complex128`. @@ -1090,21 +1110,26 @@ def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor): def binary_op_wrapper(x, y): with ops.name_scope(None, op_name, [x, y]) as name: - if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor): + try: return func(x, y, name=name) - elif not isinstance(y, sparse_tensor.SparseTensor): - try: - y = ops.convert_to_tensor_v2( - y, dtype_hint=x.dtype.base_dtype, name="y") - except TypeError: - # If the RHS is not a tensor, it might be a tensor aware object - # that can implement the operator with knowledge of itself - # and the tensor. - if hasattr(type(y), "__r%s__" % op_name): - return NotImplemented - else: - raise - return func(x, y, name=name) + except (TypeError, ValueError) as e: + # Even if dispatching the op failed, the RHS may be a tensor aware + # object that can implement the operator with knowledge of itself + # and the tensor. + # If the RHS is not tensor aware we still want to raise the + # original error from the LHS, because it may be more + # informative. + if hasattr(type(y), "__r%s__" % op_name): + try: + r_op = getattr(y, "__r%s__" % op_name) + out = r_op(x) + if out == NotImplemented: + raise + return out + except (TypeError, ValueError): + raise e + else: + raise def binary_op_wrapper_sparse(sp_x, y): with ops.name_scope(None, op_name, [sp_x, y]) as name: @@ -1184,7 +1209,7 @@ def _sparse_dense_truediv(sp_indices, sp_values, sp_shape, y, name=None): def _truediv_python3(x, y, name=None): with ops.name_scope(name, "truediv", [x, y]) as name: x = ops.convert_to_tensor(x, name="x") - y = ops.convert_to_tensor(y, name="y") + y = ops.convert_to_tensor(y, dtype_hint=x.dtype.base_dtype, name="y") x_dtype = x.dtype.base_dtype y_dtype = y.dtype.base_dtype if x_dtype != y_dtype: @@ -1265,6 +1290,7 @@ def truediv(x, y, name=None): date=None, instructions="Deprecated in favor of operator or tf.math.divide.") @tf_export(v1=["div"]) +@dispatch.add_dispatch_support def div(x, y, name=None): """Divides x / y elementwise (using Python 2 division operator semantics). @@ -1288,6 +1314,7 @@ def div(x, y, name=None): @tf_export("math.divide_no_nan", v1=["math.divide_no_nan", "div_no_nan"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("div_no_nan") @dispatch.add_dispatch_support def div_no_nan(x, y, name=None): @@ -1380,6 +1407,9 @@ floormod = gen_math_ops.floor_mod def _add_dispatch(x, y, name=None): """Dispatches to add for strings and add_v2 for all other types.""" + if not isinstance(y, ops.Tensor) and not isinstance( + y, sparse_tensor.SparseTensor): + y = ops.convert_to_tensor(y, dtype_hint=x.dtype.base_dtype, name="y") if x.dtype == dtypes.string: return gen_math_ops.add(x, y, name=name) else: @@ -1388,14 +1418,12 @@ def _add_dispatch(x, y, name=None): def _mul_dispatch(x, y, name=None): """Dispatches cwise mul for "Dense*Dense" and "Dense*Sparse".""" - is_tensor_y = isinstance(y, ops.Tensor) - if is_tensor_y: - return gen_math_ops.mul(x, y, name=name) - else: - assert isinstance(y, sparse_tensor.SparseTensor) # Case: Dense * Sparse. + if isinstance(y, sparse_tensor.SparseTensor): # Case: Dense * Sparse. new_vals = gen_sparse_ops.sparse_dense_cwise_mul(y.indices, y.values, y.dense_shape, x, name) return sparse_tensor.SparseTensor(y.indices, new_vals, y.dense_shape) + else: + return multiply(x, y, name=name) # NOTE(aselle): When integer division is added for sparse_dense_cwise, @@ -1409,10 +1437,10 @@ _OverrideBinaryOperatorHelper(gen_sparse_ops.sparse_dense_cwise_mul, "mul", sparse_tensor.SparseTensor) _OverrideBinaryOperatorHelper(_add_dispatch, "add") -_OverrideBinaryOperatorHelper(gen_math_ops.sub, "sub") +_OverrideBinaryOperatorHelper(subtract, "sub") _OverrideBinaryOperatorHelper(_mul_dispatch, "mul") -_OverrideBinaryOperatorHelper(_div_python2, "div") -_OverrideBinaryOperatorHelper(_truediv_python3, "truediv") +_OverrideBinaryOperatorHelper(div, "div") +_OverrideBinaryOperatorHelper(truediv, "truediv") _OverrideBinaryOperatorHelper(floordiv, "floordiv") _OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod") _OverrideBinaryOperatorHelper(pow, "pow") @@ -1509,7 +1537,7 @@ def logical_and(x, y, name=None): return gen_math_ops.logical_and(x, y, name) -_OverrideBinaryOperatorHelper(gen_math_ops.logical_and, "and") +_OverrideBinaryOperatorHelper(logical_and, "and") _OverrideBinaryOperatorHelper(gen_math_ops.logical_or, "or") _OverrideBinaryOperatorHelper(logical_xor, "xor") @@ -1620,6 +1648,7 @@ ops.Tensor._override_operator("__ne__", tensor_not_equals) @tf_export("range") +@dispatch.add_dispatch_support def range(start, limit=None, delta=1, dtype=None, name="range"): # pylint: disable=redefined-builtin """Creates a sequence of numbers. @@ -1751,6 +1780,7 @@ def _may_reduce_to_scalar(keepdims, axis, output): @tf_export(v1=["math.reduce_sum", "reduce_sum"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "keep_dims is deprecated, use keepdims instead", "keep_dims") @@ -1885,6 +1915,7 @@ def reduce_sum_with_dims(input_tensor, @tf_export("math.reduce_euclidean_norm") +@dispatch.add_dispatch_support def reduce_euclidean_norm(input_tensor, axis=None, keepdims=False, name=None): """Computes the Euclidean norm of elements across dimensions of a tensor. @@ -1928,6 +1959,7 @@ def reduce_euclidean_norm(input_tensor, axis=None, keepdims=False, name=None): @tf_export(v1=["math.count_nonzero", "count_nonzero"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "keep_dims is deprecated, use keepdims instead", "keep_dims") @@ -2005,6 +2037,7 @@ def count_nonzero(input_tensor=None, @tf_export("math.count_nonzero", v1=[]) +@dispatch.add_dispatch_support def count_nonzero_v2( input, # pylint: disable=redefined-builtin axis=None, @@ -2072,6 +2105,7 @@ def count_nonzero_v2( @tf_export(v1=["math.reduce_mean", "reduce_mean"]) +@dispatch.add_dispatch_support def reduce_mean_v1(input_tensor, axis=None, keepdims=None, @@ -2198,6 +2232,7 @@ def reduce_mean(input_tensor, axis=None, keepdims=False, name=None): @tf_export("math.reduce_variance") +@dispatch.add_dispatch_support def reduce_variance(input_tensor, axis=None, keepdims=False, name=None): """Computes the variance of elements across dimensions of a tensor. @@ -2246,6 +2281,7 @@ def reduce_variance(input_tensor, axis=None, keepdims=False, name=None): @tf_export("math.reduce_std") +@dispatch.add_dispatch_support def reduce_std(input_tensor, axis=None, keepdims=False, name=None): """Computes the standard deviation of elements across dimensions of a tensor. @@ -2328,6 +2364,7 @@ def reduce_prod(input_tensor, axis=None, keepdims=False, name=None): @tf_export(v1=["math.reduce_prod", "reduce_prod"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "keep_dims is deprecated, use keepdims instead", "keep_dims") @@ -2373,6 +2410,7 @@ def reduce_prod_v1(input_tensor, @tf_export(v1=["math.reduce_min", "reduce_min"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "keep_dims is deprecated, use keepdims instead", "keep_dims") @@ -2459,6 +2497,7 @@ def reduce_min(input_tensor, axis=None, keepdims=False, name=None): @tf_export(v1=["math.reduce_max", "reduce_max"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "keep_dims is deprecated, use keepdims instead", "keep_dims") @@ -2563,6 +2602,7 @@ def reduce_max_with_dims(input_tensor, @tf_export(v1=["math.reduce_all", "reduce_all"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "keep_dims is deprecated, use keepdims instead", "keep_dims") @@ -2662,6 +2702,7 @@ def reduce_all(input_tensor, axis=None, keepdims=False, name=None): @tf_export(v1=["math.reduce_any", "reduce_any"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "keep_dims is deprecated, use keepdims instead", "keep_dims") @@ -2761,6 +2802,7 @@ def reduce_any(input_tensor, axis=None, keepdims=False, name=None): @tf_export(v1=["math.reduce_logsumexp", "reduce_logsumexp"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "keep_dims is deprecated, use keepdims instead", "keep_dims") @@ -2817,6 +2859,7 @@ def reduce_logsumexp_v1(input_tensor, @tf_export("math.reduce_logsumexp", "reduce_logsumexp", v1=[]) +@dispatch.add_dispatch_support def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None): """Computes log(sum(exp(elements across dimensions of a tensor))). @@ -2877,6 +2920,7 @@ def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None): @tf_export("linalg.trace", v1=["linalg.trace", "trace"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("trace") @dispatch.add_dispatch_support def trace(x, name=None): @@ -3050,10 +3094,10 @@ def matmul(a, if not isinstance(a, (ops.EagerTensor, _resource_variable_type)): a = ops.convert_to_tensor(a, name="a") if not isinstance(b, (ops.EagerTensor, _resource_variable_type)): - b = ops.convert_to_tensor(b, name="b") + b = ops.convert_to_tensor(b, dtype_hint=a.dtype.base_dtype, name="b") else: a = ops.convert_to_tensor(a, name="a") - b = ops.convert_to_tensor(b, name="b") + b = ops.convert_to_tensor(b, dtype_hint=a.dtype.base_dtype, name="b") # TODO(apassos) remove _shape_tuple here when it is not needed. a_shape = a._shape_tuple() # pylint: disable=protected-access @@ -3116,6 +3160,7 @@ def matmul(a, @tf_export("linalg.matvec") +@dispatch.add_dispatch_support def matvec(a, b, transpose_a=False, @@ -3219,6 +3264,7 @@ _OverrideBinaryOperatorHelper(matmul, "matmul") sparse_matmul = deprecation.deprecated(None, "Use `tf.linalg.matmul` instead")( gen_math_ops.sparse_mat_mul) tf_export(v1=["sparse_matmul"])(sparse_matmul) +@dispatch.add_dispatch_support @ops.RegisterStatistics("MatMul", "flops") @@ -3371,6 +3417,7 @@ def add_n(inputs, name=None): @tf_export("math.accumulate_n", v1=["math.accumulate_n", "accumulate_n"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("accumulate_n") def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): """Returns the element-wise sum of a list of tensors. @@ -3449,6 +3496,7 @@ def _accumulate_n_grad(op, grad): @tf_export("math.sigmoid", "nn.sigmoid", "sigmoid") +@dispatch.add_dispatch_support def sigmoid(x, name=None): r"""Computes sigmoid of `x` element-wise. @@ -3520,115 +3568,8 @@ def log_sigmoid(x, name=None): return gen_math_ops.neg(gen_nn_ops.softplus(-x), name=name) -@tf_export("math.bincount", v1=[]) -def bincount(arr, - weights=None, - minlength=None, - maxlength=None, - dtype=dtypes.int32, - name=None): - """Counts the number of occurrences of each value in an integer array. - - If `minlength` and `maxlength` are not given, returns a vector with length - `tf.reduce_max(arr) + 1` if `arr` is non-empty, and length 0 otherwise. - If `weights` are non-None, then index `i` of the output stores the sum of the - value in `weights` at each index where the corresponding value in `arr` is - `i`. - - ```python - values = tf.constant([1,1,2,3,2,4,4,5]) - tf.math.bincount(values) #[0 2 2 1 2 1] - ``` - Vector length = Maximum element in vector `values` is 5. Adding 1, which is 6 - will be the vector length. - - Each bin value in the output indicates number of occurrences of the particular - index. Here, index 1 in output has a value 2. This indicates value 1 occurs - two times in `values`. - - ```python - values = tf.constant([1,1,2,3,2,4,4,5]) - weights = tf.constant([1,5,0,1,0,5,4,5]) - tf.math.bincount(values, weights=weights) #[0 6 0 1 9 5] - ``` - Bin will be incremented by the corresponding weight instead of 1. - Here, index 1 in output has a value 6. This is the summation of weights - corresponding to the value in `values`. - - Args: - arr: An int32 tensor of non-negative values. - weights: If non-None, must be the same shape as arr. For each value in - `arr`, the bin will be incremented by the corresponding weight instead of - 1. - minlength: If given, ensures the output has length at least `minlength`, - padding with zeros at the end if necessary. - maxlength: If given, skips values in `arr` that are equal or greater than - `maxlength`, ensuring that the output has length at most `maxlength`. - dtype: If `weights` is None, determines the type of the output bins. - name: A name scope for the associated operations (optional). - - Returns: - A vector with the same dtype as `weights` or the given `dtype`. The bin - values. - - Raises: - `InvalidArgumentError` if negative values are provided as an input. - - """ - name = "bincount" if name is None else name - with ops.name_scope(name): - arr = ops.convert_to_tensor(arr, name="arr", dtype=dtypes.int32) - array_is_nonempty = reduce_prod(array_ops.shape(arr)) > 0 - output_size = cast(array_is_nonempty, dtypes.int32) * (reduce_max(arr) + 1) - if minlength is not None: - minlength = ops.convert_to_tensor( - minlength, name="minlength", dtype=dtypes.int32) - output_size = gen_math_ops.maximum(minlength, output_size) - if maxlength is not None: - maxlength = ops.convert_to_tensor( - maxlength, name="maxlength", dtype=dtypes.int32) - output_size = gen_math_ops.minimum(maxlength, output_size) - if weights is not None: - weights = ops.convert_to_tensor(weights, name="weights") - return gen_math_ops.unsorted_segment_sum(weights, arr, output_size) - weights = constant_op.constant([], dtype) - return gen_math_ops.bincount(arr, output_size, weights) - - -@tf_export(v1=["math.bincount", "bincount"]) -@deprecation.deprecated_endpoints("bincount") -def bincount_v1(arr, - weights=None, - minlength=None, - maxlength=None, - dtype=dtypes.int32): - """Counts the number of occurrences of each value in an integer array. - - If `minlength` and `maxlength` are not given, returns a vector with length - `tf.reduce_max(arr) + 1` if `arr` is non-empty, and length 0 otherwise. - If `weights` are non-None, then index `i` of the output stores the sum of the - value in `weights` at each index where the corresponding value in `arr` is - `i`. - - Args: - arr: An int32 tensor of non-negative values. - weights: If non-None, must be the same shape as arr. For each value in - `arr`, the bin will be incremented by the corresponding weight instead of - 1. - minlength: If given, ensures the output has length at least `minlength`, - padding with zeros at the end if necessary. - maxlength: If given, skips values in `arr` that are equal or greater than - `maxlength`, ensuring that the output has length at most `maxlength`. - dtype: If `weights` is None, determines the type of the output bins. - - Returns: - A vector with the same dtype as `weights` or the given `dtype`. The bin - values. - """ - return bincount(arr, weights, minlength, maxlength, dtype) - - @tf_export("math.cumsum", "cumsum") +@dispatch.add_dispatch_support def cumsum(x, axis=0, exclusive=False, reverse=False, name=None): """Compute the cumulative sum of the tensor `x` along `axis`. @@ -3700,6 +3641,7 @@ def cumsum(x, axis=0, exclusive=False, reverse=False, name=None): @tf_export("math.cumprod", v1=["math.cumprod", "cumprod"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("cumprod") def cumprod(x, axis=0, exclusive=False, reverse=False, name=None): """Compute the cumulative product of the tensor `x` along `axis`. @@ -3753,6 +3695,7 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None): @tf_export("math.cumulative_logsumexp", v1=["math.cumulative_logsumexp"]) +@dispatch.add_dispatch_support def cumulative_logsumexp(x, axis=0, exclusive=False, reverse=False, name=None): """Compute the cumulative log-sum-exp of the tensor `x` along `axis`. @@ -3912,6 +3855,7 @@ def _unsorted_segment_N(data, segment_ids, num_segments): @tf_export( "math.unsorted_segment_mean", v1=["math.unsorted_segment_mean", "unsorted_segment_mean"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("unsorted_segment_mean") @dispatch.add_dispatch_support def unsorted_segment_mean(data, segment_ids, num_segments, name=None): @@ -3958,6 +3902,7 @@ def unsorted_segment_mean(data, segment_ids, num_segments, name=None): @tf_export( "math.unsorted_segment_sqrt_n", v1=["math.unsorted_segment_sqrt_n", "unsorted_segment_sqrt_n"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("unsorted_segment_sqrt_n") @dispatch.add_dispatch_support def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None): @@ -4307,6 +4252,7 @@ def sparse_segment_sqrt_n_v2(data, @tf_export("tensordot", "linalg.tensordot") +@dispatch.add_dispatch_support def tensordot(a, b, axes, name=None): r"""Tensor contraction of a and b along specified axes and outer product. @@ -4493,6 +4439,7 @@ def tensordot(a, b, axes, name=None): @tf_export("math.polyval") +@dispatch.add_dispatch_support def polyval(coeffs, x, name=None): r"""Computes the elementwise value of a polynomial. @@ -4505,9 +4452,9 @@ def polyval(coeffs, x, name=None): p(x) = coeffs[n-1] + x * (coeffs[n-2] + ... + x * (coeffs[1] + x * coeffs[0])) - + Usage Example: - + >>> coefficients = [1.0, 2.5, -4.2] >>> x = 5.0 >>> y = tf.math.polyval(coefficients, x) @@ -4563,6 +4510,7 @@ def polyval(coeffs, x, name=None): @tf_export("math.reciprocal_no_nan") +@dispatch.add_dispatch_support def reciprocal_no_nan(x, name=None): """Performs a safe reciprocal operation, element wise. @@ -4665,6 +4613,7 @@ def ndtri(x, name=None): @tf_export("math.ceil", v1=["math.ceil", "ceil"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("ceil") @dispatch.add_dispatch_support def ceil(x, name=None): @@ -4778,6 +4727,7 @@ def exp(x, name=None): @tf_export("math.sobol_sample") +@dispatch.add_dispatch_support def sobol_sample(dim, num_results, skip=0, dtype=dtypes.float32, name=None): """Generates points from the Sobol sequence. @@ -4802,6 +4752,7 @@ def sobol_sample(dim, num_results, skip=0, dtype=dtypes.float32, name=None): @tf_export("math.rsqrt", v1=["math.rsqrt", "rsqrt"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("rsqrt") @dispatch.add_dispatch_support def rsqrt(x, name=None): diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 2405eec9e49..9093a06b84a 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -682,6 +682,45 @@ class BinaryOpsTest(test_util.TensorFlowTestCase): a = array_ops.ones([1], dtype=dtypes.int32) + 1.0 self.evaluate(a) + def testRHSDispatchingAndErrorRaising(self): + if context.executing_eagerly(): + error = ValueError + error_message = ( + r"Attempt to convert a value .* with an unsupported type") + else: + error = TypeError + error_message = ( + r"Failed to convert object of type .* to Tensor") + + class RHSReturnsTrue(object): + + def __radd__(self, other): + return True + a = array_ops.ones([1], dtype=dtypes.int32) + RHSReturnsTrue() + self.assertEqual(a, True) + + class RHSRaisesError(object): + + def __radd__(self, other): + raise TypeError("RHS not implemented") + with self.assertRaisesRegexp(error, error_message): + a = array_ops.ones([1], dtype=dtypes.int32) + RHSRaisesError() + self.evaluate(a) + + class RHSReturnsNotImplemented(object): + + def __radd__(self, other): + return NotImplemented + with self.assertRaisesRegexp(error, error_message): + a = array_ops.ones([1], dtype=dtypes.int32) + RHSReturnsNotImplemented() + self.evaluate(a) + + class RHSNotImplemented(object): + pass + with self.assertRaisesRegexp(error, error_message): + a = array_ops.ones([1], dtype=dtypes.int32) + RHSNotImplemented() + self.evaluate(a) + class SignTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 03c1289246e..4bda85077bc 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -39,12 +39,14 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variables from tensorflow.python.ops.losses import util as losses_util from tensorflow.python.platform import device_context +from tensorflow.python.util import dispatch from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.deprecation import deprecated_argument_lookup from tensorflow.python.util.tf_export import tf_export @tf_export("nn.log_poisson_loss") +@dispatch.add_dispatch_support def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None): """Computes log Poisson loss given `log_input`. @@ -110,6 +112,7 @@ def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None): @tf_export(v1=["nn.sigmoid_cross_entropy_with_logits"]) +@dispatch.add_dispatch_support def sigmoid_cross_entropy_with_logits( # pylint: disable=invalid-name _sentinel=None, labels=None, @@ -192,6 +195,7 @@ def sigmoid_cross_entropy_with_logits( # pylint: disable=invalid-name # Note: intentionally calling this v2 to not allow existing code with indirect # imports to ignore the sentinel behavior. @tf_export("nn.sigmoid_cross_entropy_with_logits", v1=[]) +@dispatch.add_dispatch_support def sigmoid_cross_entropy_with_logits_v2( # pylint: disable=invalid-name labels=None, logits=None, @@ -242,6 +246,7 @@ def sigmoid_cross_entropy_with_logits_v2( # pylint: disable=invalid-name @tf_export("nn.weighted_cross_entropy_with_logits", v1=[]) +@dispatch.add_dispatch_support def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight, name=None): """Computes a weighted cross entropy. @@ -320,6 +325,7 @@ def weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight, @tf_export(v1=["nn.weighted_cross_entropy_with_logits"]) +@dispatch.add_dispatch_support @deprecated_args(None, "targets is deprecated, use labels instead", "targets") def weighted_cross_entropy_with_logits(labels=None, logits=None, @@ -384,6 +390,7 @@ def weighted_cross_entropy_with_logits(labels=None, @tf_export("nn.compute_average_loss") +@dispatch.add_dispatch_support def compute_average_loss(per_example_loss, sample_weight=None, global_batch_size=None): @@ -440,6 +447,7 @@ def compute_average_loss(per_example_loss, @tf_export("nn.scale_regularization_loss") +@dispatch.add_dispatch_support def scale_regularization_loss(regularization_loss): """Scales the sum of the given regularization losses by number of replicas. @@ -478,6 +486,7 @@ def scale_regularization_loss(regularization_loss): @tf_export(v1=["nn.relu_layer"]) +@dispatch.add_dispatch_support def relu_layer(x, weights, biases, name=None): """Computes Relu(x * weight + biases). @@ -501,6 +510,7 @@ def relu_layer(x, weights, biases, name=None): @tf_export("nn.swish") +@dispatch.add_dispatch_support @custom_gradient.custom_gradient def swish(features): # pylint: disable=g-doc-args @@ -538,6 +548,7 @@ def swish(features): # pylint: disable=redefined-builtin @tf_export("linalg.normalize") +@dispatch.add_dispatch_support def normalize(tensor, ord="euclidean", axis=None, name=None): """Normalizes `tensor` along dimension `axis` using specified norm. @@ -590,6 +601,7 @@ def normalize(tensor, ord="euclidean", axis=None, name=None): @tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"]) +@dispatch.add_dispatch_support @deprecated_args(None, "dim is deprecated, use axis instead", "dim") def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None): """Normalizes along dimension `axis` using an L2 norm. @@ -618,6 +630,7 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None): @tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize", v1=[]) +@dispatch.add_dispatch_support def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None): """Normalizes along dimension `axis` using an L2 norm. @@ -668,6 +681,7 @@ def _count_nonzero(input_tensor, dtype=dtypes.int64): @tf_export("math.zero_fraction", "nn.zero_fraction") +@dispatch.add_dispatch_support def zero_fraction(value, name=None): """Returns the fraction of zeros in `value`. @@ -710,6 +724,7 @@ def zero_fraction(value, name=None): # pylint: disable=redefined-builtin @tf_export(v1=["nn.depthwise_conv2d"]) +@dispatch.add_dispatch_support def depthwise_conv2d(input, filter, strides, @@ -838,6 +853,7 @@ def depthwise_conv2d(input, @tf_export("nn.depthwise_conv2d", v1=[]) +@dispatch.add_dispatch_support def depthwise_conv2d_v2(input, filter, strides, @@ -935,6 +951,7 @@ def depthwise_conv2d_v2(input, # pylint: disable=redefined-builtin,line-too-long @tf_export(v1=["nn.separable_conv2d"]) +@dispatch.add_dispatch_support def separable_conv2d(input, depthwise_filter, pointwise_filter, @@ -1042,6 +1059,7 @@ def separable_conv2d(input, @tf_export("nn.separable_conv2d", v1=[]) +@dispatch.add_dispatch_support def separable_conv2d_v2( input, depthwise_filter, @@ -1117,6 +1135,7 @@ def separable_conv2d_v2( @tf_export(v1=["nn.sufficient_statistics"]) +@dispatch.add_dispatch_support def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None, keepdims=None): """Calculate the sufficient statistics for the mean and variance of `x`. @@ -1174,6 +1193,7 @@ def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None, @tf_export("nn.sufficient_statistics", v1=[]) +@dispatch.add_dispatch_support def sufficient_statistics_v2(x, axes, shift=None, keepdims=False, name=None): """Calculate the sufficient statistics for the mean and variance of `x`. @@ -1203,6 +1223,7 @@ def sufficient_statistics_v2(x, axes, shift=None, keepdims=False, name=None): @tf_export("nn.normalize_moments") +@dispatch.add_dispatch_support def normalize_moments(counts, mean_ss, variance_ss, shift, name=None): """Calculate the mean and variance of based on the sufficient statistics. @@ -1235,6 +1256,7 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None): @tf_export(v1=["nn.moments"]) +@dispatch.add_dispatch_support def moments( x, axes, @@ -1300,6 +1322,7 @@ def moments( @tf_export("nn.moments", v1=[]) +@dispatch.add_dispatch_support def moments_v2( x, axes, @@ -1336,6 +1359,7 @@ def moments_v2( @tf_export(v1=["nn.weighted_moments"]) +@dispatch.add_dispatch_support def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None, keepdims=None): """Returns the frequency-weighted mean and variance of `x`. @@ -1414,6 +1438,7 @@ def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None, @tf_export("nn.weighted_moments", v1=[]) +@dispatch.add_dispatch_support def weighted_moments_v2(x, axes, frequency_weights, keepdims=False, name=None): """Returns the frequency-weighted mean and variance of `x`. @@ -1438,6 +1463,7 @@ def weighted_moments_v2(x, axes, frequency_weights, keepdims=False, name=None): @tf_export("nn.batch_normalization") +@dispatch.add_dispatch_support def batch_normalization(x, mean, variance, @@ -1508,6 +1534,7 @@ def batch_normalization(x, @tf_export(v1=["nn.fused_batch_norm"]) +@dispatch.add_dispatch_support def fused_batch_norm( x, scale, @@ -1631,6 +1658,7 @@ def fused_batch_norm( @tf_export(v1=["nn.batch_norm_with_global_normalization"]) +@dispatch.add_dispatch_support def batch_norm_with_global_normalization(t=None, m=None, v=None, @@ -1685,6 +1713,7 @@ def batch_norm_with_global_normalization(t=None, # pylint: disable=redefined-builtin,line-too-long @tf_export("nn.batch_norm_with_global_normalization", v1=[]) +@dispatch.add_dispatch_support def batch_norm_with_global_normalization_v2(input, mean, variance, @@ -1934,6 +1963,7 @@ def _compute_sampled_logits(weights, @tf_export("nn.nce_loss", v1=[]) +@dispatch.add_dispatch_support def nce_loss_v2(weights, biases, labels, @@ -2038,6 +2068,7 @@ def nce_loss_v2(weights, @tf_export(v1=["nn.nce_loss"]) +@dispatch.add_dispatch_support def nce_loss(weights, biases, labels, @@ -2149,6 +2180,7 @@ def nce_loss(weights, @tf_export("nn.sampled_softmax_loss", v1=[]) +@dispatch.add_dispatch_support def sampled_softmax_loss_v2(weights, biases, labels, @@ -2240,6 +2272,7 @@ def sampled_softmax_loss_v2(weights, @tf_export(v1=["nn.sampled_softmax_loss"]) +@dispatch.add_dispatch_support def sampled_softmax_loss(weights, biases, labels, diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 248c57c1ba5..b7dd1d20aae 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import collections +import functools import numbers import os @@ -130,9 +131,9 @@ def _non_atrous_convolution( """ with ops.name_scope(name, "non_atrous_convolution", [input, filter]) as scope: input = ops.convert_to_tensor(input, name="input") # pylint: disable=redefined-builtin - input_shape = input.get_shape() + input_shape = input.shape filter = ops.convert_to_tensor(filter, name="filter") # pylint: disable=redefined-builtin - filter_shape = filter.get_shape() + filter_shape = filter.shape op = _NonAtrousConvolution( input_shape, filter_shape=filter_shape, @@ -147,36 +148,51 @@ class _NonAtrousConvolution(object): """Helper class for _non_atrous_convolution. Note that this class assumes that shapes of input and filter passed to - __call__ are compatible with input_shape and filter_shape passed to the + `__call__` are compatible with `input_shape` and filter_shape passed to the constructor. Arguments: - input_shape: static input shape, i.e. input.get_shape(). - filter_shape: static filter shape, i.e. filter.get_shape(). + input_shape: static input shape, i.e. input.shape. + filter_shape: static filter shape, i.e. filter.shape. padding: see _non_atrous_convolution. data_format: see _non_atrous_convolution. strides: see _non_atrous_convolution. name: see _non_atrous_convolution. + num_batch_dims: (Optional.) The number of batch dimensions in the input; + if not provided, the default of `1` is used. """ def __init__( self, input_shape, - filter_shape, # pylint: disable=redefined-builtin + filter_shape, padding, data_format=None, strides=None, - name=None): - filter_shape = filter_shape.with_rank(input_shape.ndims) + name=None, + num_batch_dims=1): + # filter shape is always rank num_spatial_dims + 2 + # and num_spatial_dims == input_shape.ndims - num_batch_dims - 1 + if input_shape.ndims is not None: + filter_shape = filter_shape.with_rank( + input_shape.ndims - num_batch_dims + 1) self.padding = padding self.name = name - input_shape = input_shape.with_rank(filter_shape.ndims) + # input shape is == num_spatial_dims + num_batch_dims + 1 + # and filter_shape is always rank num_spatial_dims + 2 + if filter_shape.ndims is not None: + input_shape = input_shape.with_rank( + filter_shape.ndims + num_batch_dims - 1) if input_shape.ndims is None: - raise ValueError("Rank of convolution must be known") - if input_shape.ndims < 3 or input_shape.ndims > 5: raise ValueError( - "`input` and `filter` must have rank at least 3 and at most 5") - conv_dims = input_shape.ndims - 2 + "Rank of convolution must be known, but saw input_shape.ndims == {}" + .format(input_shape.ndims)) + if input_shape.ndims < 3 or input_shape.ndims - num_batch_dims + 1 > 5: + raise ValueError( + "`input_shape.ndims - num_batch_dims + 1` must be at least 3 and at " + "most 5 but saw `input_shape.ndims == {}` and `num_batch_dims == {}`" + .format(input_shape.ndims, num_batch_dims)) + conv_dims = input_shape.ndims - num_batch_dims - 1 if strides is None: strides = [1] * conv_dims elif len(strides) != conv_dims: @@ -238,7 +254,57 @@ class _NonAtrousConvolution(object): name=self.name) +def _squeeze_batch_dims(inp, op, inner_rank, name): + """Returns `unsqueeze_batch(op(squeeze_batch(inp)))`. + + Where `squeeze_batch` reshapes `inp` to shape + `[prod(inp.shape[:-inner_rank])] + inp.shape[-inner_rank:]` + and `unsqueeze_batch` does the reverse reshape but on the output. + + Args: + inp: A tensor with dims `batch_shape + inner_shape` where `inner_shape` + is length `inner_rank`. + op: A callable that takes a single input tensor and returns a single. + output tensor. + inner_rank: A python integer. + name: A string. + + Returns: + `unsqueeze_batch_op(squeeze_batch(inp))`. + """ + with ops.name_scope(name, "Convolution", [inp]): + inp = ops.convert_to_tensor(inp, name="input") + shape = inp.shape + + inner_shape = shape[-inner_rank:] + if not inner_shape.is_fully_defined(): + inner_shape = array_ops.shape(inp)[-inner_rank:] + + batch_shape = shape[:-inner_rank] + if not batch_shape.is_fully_defined(): + batch_shape = array_ops.shape(inp)[:-inner_rank] + + if isinstance(inner_shape, tensor_shape.TensorShape): + inp_reshaped = array_ops.reshape(inp, [-1] + inner_shape.as_list()) + else: + inp_reshaped = array_ops.reshape( + inp, array_ops.concat(([-1], inner_shape), axis=-1)) + + out_reshaped = op(inp_reshaped) + + out_inner_shape = out_reshaped.shape[-inner_rank:] + if not out_inner_shape.is_fully_defined(): + out_inner_shape = array_ops.shape(out_reshaped)[-inner_rank:] + + out = array_ops.reshape( + out_reshaped, array_ops.concat((batch_shape, out_inner_shape), axis=-1)) + + out.set_shape(inp.shape[:-inner_rank] + out.shape[-inner_rank:]) + return out + + @tf_export("nn.dilation2d", v1=[]) +@dispatch.add_dispatch_support def dilation2d_v2( input, # pylint: disable=redefined-builtin filters, # pylint: disable=redefined-builtin @@ -306,6 +372,7 @@ def dilation2d_v2( @tf_export(v1=["nn.dilation2d"]) +@dispatch.add_dispatch_support def dilation2d_v1( # pylint: disable=missing-docstring input, # pylint: disable=redefined-builtin filter=None, # pylint: disable=redefined-builtin @@ -324,6 +391,7 @@ dilation2d_v1.__doc__ = gen_nn_ops.dilation2d.__doc__ @tf_export("nn.with_space_to_batch") +@dispatch.add_dispatch_support def with_space_to_batch( input, # pylint: disable=redefined-builtin dilation_rate, @@ -467,7 +535,7 @@ def with_space_to_batch( """ input = ops.convert_to_tensor(input, name="input") # pylint: disable=redefined-builtin - input_shape = input.get_shape() + input_shape = input.shape def build_op(num_spatial_dims, padding): return lambda inp, _: op(inp, num_spatial_dims, padding) @@ -487,18 +555,19 @@ class _WithSpaceToBatch(object): """Helper class for with_space_to_batch. Note that this class assumes that shapes of input and filter passed to - __call__ are compatible with input_shape and filter_shape passed to the - constructor. + `__call__` are compatible with `input_shape`, `filter_shape`, and + `spatial_dims` passed to the constructor. Arguments - input_shape: static shape of input. i.e. input.get_shape(). - dilation_rate: see with_space_to_batch - padding: see with_space_to_batch + input_shape: static shape of input. i.e. input.shape. + dilation_rate: see `with_space_to_batch`. + padding: see `with_space_to_batch`. build_op: Function that maps (num_spatial_dims, paddings) -> (function that maps (input, filter) -> output). - filter_shape: see with_space_to_batch - spatial_dims: see with_space_to_batch - data_format: see with_space_to_batch + filter_shape: see `with_space_to_batch`. + spatial_dims: `see with_space_to_batch`. + data_format: see `with_space_to_batch`. + num_batch_dims: (Optional). Number of batch dims in `input_shape`. """ def __init__(self, @@ -508,24 +577,25 @@ class _WithSpaceToBatch(object): build_op, filter_shape=None, spatial_dims=None, - data_format=None): + data_format=None, + num_batch_dims=1): """Helper class for _with_space_to_batch.""" dilation_rate = ops.convert_to_tensor( dilation_rate, dtypes.int32, name="dilation_rate") - try: - rate_shape = dilation_rate.get_shape().with_rank(1) - except ValueError: - raise ValueError("rate must be rank 1") + if dilation_rate.shape.ndims not in (None, 1): + raise ValueError( + "rate must be rank 1 but saw {}".format(dilation_rate.shape.ndims)) - if not dilation_rate.get_shape().is_fully_defined(): - raise ValueError("rate must have known shape") + if not dilation_rate.shape.is_fully_defined(): + raise ValueError("rate must have known shape, but saw {}" + .format(dilation_rate.shape)) - num_spatial_dims = rate_shape.dims[0].value + num_spatial_dims = dilation_rate.shape.dims[0].value if data_format is not None and data_format.startswith("NC"): - starting_spatial_dim = 2 + starting_spatial_dim = num_batch_dims + 1 else: - starting_spatial_dim = 1 + starting_spatial_dim = num_batch_dims if spatial_dims is None: spatial_dims = range(starting_spatial_dim, @@ -535,7 +605,7 @@ class _WithSpaceToBatch(object): if spatial_dims != orig_spatial_dims or any(x < 1 for x in spatial_dims): raise ValueError( "spatial_dims must be a monotonically increasing sequence of " - "positive integers") + "positive integers, but saw: {}".format(orig_spatial_dims)) if data_format is not None and data_format.startswith("NC"): expected_input_rank = spatial_dims[-1] @@ -546,14 +616,16 @@ class _WithSpaceToBatch(object): input_shape.with_rank_at_least(expected_input_rank) except ValueError: raise ValueError( - "input tensor must have rank %d at least" % (expected_input_rank)) + "input tensor must have rank at least {}, but saw rank {}" + .format(expected_input_rank, input_shape.ndims)) const_rate = tensor_util.constant_value(dilation_rate) rate_or_const_rate = dilation_rate if const_rate is not None: rate_or_const_rate = const_rate if np.any(const_rate < 1): - raise ValueError("dilation_rate must be positive") + raise ValueError("dilation_rate must be positive, but saw: {}" + .format(const_rate)) if np.all(const_rate == 1): self.call = build_op(num_spatial_dims, padding) return @@ -619,6 +691,7 @@ class _WithSpaceToBatch(object): filter_shape = array_ops.shape(filter) base_paddings = _with_space_to_batch_base_paddings( filter_shape, self.num_spatial_dims, self.rate_or_const_rate) + paddings, crops = array_ops.required_space_to_batch_paddings( input_shape=input_spatial_shape, base_paddings=base_paddings, @@ -772,6 +845,7 @@ def _get_strides_and_dilation_rate(num_spatial_dims, strides, dilation_rate): @tf_export(v1=["nn.convolution"]) +@dispatch.add_dispatch_support def convolution( input, # pylint: disable=redefined-builtin filter, # pylint: disable=redefined-builtin @@ -907,7 +981,8 @@ def convolution( @tf_export("nn.convolution", v1=[]) -def convolution_v2( +@dispatch.add_dispatch_support +def convolution_v2( # pylint: disable=missing-docstring input, # pylint: disable=redefined-builtin filters, strides=None, @@ -939,31 +1014,84 @@ def convolution_internal( data_format=None, dilations=None, name=None, - call_from_convolution=True): - """Internal function which performs rank agnostic convolution.""" - if isinstance(input.shape, tensor_shape.TensorShape) and \ - input.shape.rank is not None: - n = len(input.shape) - 2 - elif not isinstance(input.shape, tensor_shape.TensorShape) and \ - input.shape is not None: - n = len(input.shape) - 2 - elif isinstance(filters.shape, tensor_shape.TensorShape) and \ - filters.shape.rank is not None: + call_from_convolution=True, + num_spatial_dims=None): + """Internal function which performs rank agnostic convolution. + + Args: + input: See `convolution`. + filters: See `convolution`. + strides: See `convolution`. + padding: See `convolution`. + data_format: See `convolution`. + dilations: See `convolution`. + name: See `convolution`. + call_from_convolution: See `convolution`. + num_spatial_dims: (Optional.). It is a integer describing the + rank of the spatial dimensions. For `1-D`, `2-D` and `3-D` convolutions, + the value of `num_spatial_dims` is `1`, `2`, and `3`, respectively. + This argument is only required to disambiguate the rank of `batch_shape` + when `filter_shape.ndims is None` and `len(batch_shape) > 1`. For + backwards compatibility, if `num_spatial_dims is None` and + `filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be + `1` (i.e., the input is expected to be + `[batch_size, num_channels] + input_spatial_shape` + or `[batch_size] + input_spatial_shape + [num_channels]`. + + Returns: + A tensor of shape and dtype matching that of `input`. + + Raises: + ValueError: If input and filter both have unknown shapes, or if + `num_spatial_dims` is provided and incompatible with the value + estimated from `filters.shape`. + """ + n = None + if getattr(filters, 'shape', None) is None: + with ops.name_scope(name, 'convolution_internal', [filters, input]): + filters = ops.convert_to_tensor(filters, name='filters') + if (isinstance(filters.shape, tensor_shape.TensorShape) + and filters.shape.rank is not None): n = len(filters.shape) - 2 - elif not isinstance(filters.shape, tensor_shape.TensorShape) and \ - filters.shape is not None: + elif (not isinstance(filters.shape, tensor_shape.TensorShape) + and filters.shape is not None): n = len(filters.shape) - 2 + + if (isinstance(input.shape, tensor_shape.TensorShape) + and input.shape.rank is not None): + if n is None: + n = (num_spatial_dims if num_spatial_dims is not None + else len(input.shape) - 2) + num_batch_dims = len(input.shape) - n - 1 + elif (not isinstance(input.shape, tensor_shape.TensorShape) + and input.shape is not None): + if n is None: + n = (num_spatial_dims if num_spatial_dims is not None + else len(input.shape) - 2) + num_batch_dims = len(input.shape) - n - 1 else: + num_batch_dims = 1 # Default behavior if it cannot be estimated. + + if n is None: raise ValueError("rank of input or filter must be known") + if num_spatial_dims is not None and n != num_spatial_dims: + raise ValueError( + "inconsistent estimate of spatial dims ({}) vs. actual passed " + "num_spatial_dims ({}). n was estimated as len(filters.shape) - 2, " + "but filters shape is: {}".format(n, num_spatial_dims, filters.shape)) + if not 1 <= n <= 3: raise ValueError( - "Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2)) + "num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one " + "of 1, 2 or 3 but saw {}. num_batch_dims: {}." + .format(n, num_batch_dims)) if data_format is None: - channel_index = n + 1 + channel_index = num_batch_dims + n else: - channel_index = 1 if data_format.startswith("NC") else n + 1 + channel_index = ( + num_batch_dims if data_format.startswith("NC") else n + num_batch_dims) strides = _get_sequence(strides, n, channel_index, "strides") dilations = _get_sequence(dilations, n, channel_index, "dilations") @@ -976,7 +1104,7 @@ def convolution_internal( scope = "convolution" with ops.name_scope(name, scope, [input, filters]) as name: - conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d} + conv_ops = {1: conv1d, 2: _conv2d_expanded_batch, 3: gen_nn_ops.conv3d} if device_context.enclosing_tpu_context() is not None or all( i == 1 for i in dilations): @@ -1006,7 +1134,8 @@ def convolution_internal( strides=strides, dilation_rate=dilations, name=name, - data_format=data_format) + data_format=data_format, + num_spatial_dims=n) return op(input, filters) @@ -1014,17 +1143,34 @@ class Convolution(object): """Helper class for convolution. Note that this class assumes that shapes of input and filter passed to - __call__ are compatible with input_shape and filter_shape passed to the - constructor. + `__call__` are compatible with `input_shape`, `filter_shape`, and + `num_spatial_dims` passed to the constructor. Arguments - input_shape: static shape of input. i.e. input.get_shape(). - filter_shape: static shape of the filter. i.e. filter.get_shape(). - padding: see convolution. + input_shape: static shape of input. i.e. input.shape. Its length is + `batch_shape + input_spatial_shape + [num_channels]` if `data_format` + does not start with `NC`, or + `batch_shape + [num_channels] + input_spatial_shape` if `data_format` + starts with `NC`. + filter_shape: static shape of the filter. i.e. filter.shape. + padding: The padding algorithm, must be "SAME" or "VALID". strides: see convolution. dilation_rate: see convolution. name: see convolution. - data_format: see convolution. + data_format: A string or `None`. Specifies whether the channel dimension of + the `input` and output is the last dimension (if `data_format` is `None` + or does not start with `NC`), or the first post-batch dimension (i.e. if + `data_format` starts with `NC`). + num_spatial_dims: (Usually optional.) Python integer, the rank of the + spatial and channel dimensions. For `1-D`, `2-D` and `3-D` convolutions, + the value of `num_spatial_dims` is `1`, `2`, and `3`, respectively. + This argument is only required to disambiguate the rank of `batch_shape` + when `filter_shape.ndims is None` and `len(batch_shape) > 1`. For + backwards compatibility, if `num_spatial_dims is None` and + `filter_shape.ndims is None`, then `len(batch_shape)` is assumed to be + `1` (i.e., the input is expected to be + `[batch_size, num_channels] + input_spatial_shape` + or `[batch_size] + input_spatial_shape + [num_channels]`. """ def __init__(self, @@ -1034,40 +1180,72 @@ class Convolution(object): strides=None, dilation_rate=None, name=None, - data_format=None): + data_format=None, + num_spatial_dims=None): """Helper function for convolution.""" - num_total_dims = filter_shape.ndims - if num_total_dims is None: - num_total_dims = input_shape.ndims - if num_total_dims is None: - raise ValueError("rank of input or filter must be known") + num_batch_dims = None + filter_shape = tensor_shape.as_shape(filter_shape) + input_shape = tensor_shape.as_shape(input_shape) - num_spatial_dims = num_total_dims - 2 + if filter_shape.ndims is not None: + if (num_spatial_dims is not None and + filter_shape.ndims != num_spatial_dims + 2): + raise ValueError( + "Expected filter_shape.ndims == num_spatial_dims + 2, " + "but saw filter_shape.ndims == {} and num_spatial_dims == {}" + .format(filter_shape.ndims, num_spatial_dims)) + else: + num_spatial_dims = filter_shape.ndims - 2 - try: - input_shape.with_rank(num_spatial_dims + 2) - except ValueError: + if input_shape.ndims is not None and num_spatial_dims is not None: + num_batch_dims = input_shape.ndims - num_spatial_dims - 1 + + if num_spatial_dims is None: + num_spatial_dims = input_shape.ndims - 2 + else: + if input_shape.ndims is not None: + if input_shape.ndims < num_spatial_dims + 2: + raise ValueError( + "Expected input_shape.ndims >= num_spatial_dims + 2, but saw " + "input_shape.ndims == {} and num_spatial_dims == {}" + .format(input_shape.ndims, num_spatial_dims)) + else: + if num_batch_dims is None: + num_batch_dims = input_shape.ndims - num_spatial_dims - 1 + + if num_spatial_dims is None: raise ValueError( - "input tensor must have rank %d" % (num_spatial_dims + 2)) + "Cannot estimate num_spatial_dims since input_shape.ndims is None, " + "filter_shape.ndims is None, and argument num_spatial_dims is also " + "None.") - try: - filter_shape.with_rank(num_spatial_dims + 2) - except ValueError: + if num_batch_dims is None: + num_batch_dims = 1 + + if num_batch_dims < 1: raise ValueError( - "filter tensor must have rank %d" % (num_spatial_dims + 2)) + "num_batch_dims should be >= 1, but saw {}. num_batch_dims was " + "estimated as `input_shape.ndims - num_spatial_dims - 1` and " + "num_spatial_dims was either provided or estimated as " + "`filter_shape.ndims - 2`. input_shape.ndims: {}, " + "num_spatial_dims: {}, filter_shape.ndims: {}" + .format(num_batch_dims, input_shape.ndims, num_spatial_dims, + filter_shape.ndims)) if data_format is None or not data_format.startswith("NC"): input_channels_dim = tensor_shape.dimension_at_index( - input_shape, num_spatial_dims + 1) - spatial_dims = range(1, num_spatial_dims + 1) + input_shape, num_spatial_dims + num_batch_dims) + spatial_dims = range(num_batch_dims, num_spatial_dims + num_batch_dims) else: - input_channels_dim = tensor_shape.dimension_at_index(input_shape, 1) - spatial_dims = range(2, num_spatial_dims + 2) + input_channels_dim = tensor_shape.dimension_at_index( + input_shape, num_batch_dims) + spatial_dims = range( + num_batch_dims + 1, num_spatial_dims + num_batch_dims + 1) if not input_channels_dim.is_compatible_with( filter_shape[num_spatial_dims]): raise ValueError( - "number of input channels does not match corresponding dimension of " + "Number of input channels does not match corresponding dimension of " "filter, {} != {}".format(input_channels_dim, filter_shape[num_spatial_dims])) @@ -1081,6 +1259,8 @@ class Convolution(object): self.padding = padding self.name = name self.dilation_rate = dilation_rate + self.num_batch_dims = num_batch_dims + self.num_spatial_dims = num_spatial_dims self.conv_op = _WithSpaceToBatch( input_shape, dilation_rate=dilation_rate, @@ -1088,7 +1268,8 @@ class Convolution(object): build_op=self._build_op, filter_shape=filter_shape, spatial_dims=spatial_dims, - data_format=data_format) + data_format=data_format, + num_batch_dims=num_batch_dims) def _build_op(self, _, padding): return _NonAtrousConvolution( @@ -1097,7 +1278,8 @@ class Convolution(object): padding=padding, data_format=self.data_format, strides=self.strides, - name=self.name) + name=self.name, + num_batch_dims=self.num_batch_dims) def __call__(self, inp, filter): # pylint: disable=redefined-builtin # TPU convolution supports dilations greater than 1. @@ -1110,12 +1292,14 @@ class Convolution(object): data_format=self.data_format, dilations=self.dilation_rate, name=self.name, - call_from_convolution=False) + call_from_convolution=False, + num_spatial_dims=self.num_spatial_dims) else: return self.conv_op(inp, filter) @tf_export(v1=["nn.pool"]) +@dispatch.add_dispatch_support def pool( input, # pylint: disable=redefined-builtin window_shape, @@ -1290,6 +1474,7 @@ def pool( @tf_export("nn.pool", v1=[]) +@dispatch.add_dispatch_support def pool_v2( input, # pylint: disable=redefined-builtin window_shape, @@ -1389,6 +1574,7 @@ def pool_v2( @tf_export("nn.atrous_conv2d") +@dispatch.add_dispatch_support def atrous_conv2d(value, filters, rate, padding, name=None): """Atrous convolution (a.k.a. convolution with holes or dilated convolution). @@ -1576,6 +1762,7 @@ def convert_padding(padding): @tf_export(v1=["nn.conv1d"]) +@dispatch.add_dispatch_support @deprecation.deprecated_arg_values( None, "`NCHW` for data_format is deprecated, use `NCW` instead", @@ -1674,6 +1861,7 @@ def conv1d( @tf_export("nn.conv1d", v1=[]) +@dispatch.add_dispatch_support def conv1d_v2( input, # pylint: disable=redefined-builtin filters, @@ -1739,6 +1927,7 @@ def conv1d_v2( @tf_export("nn.conv1d_transpose") +@dispatch.add_dispatch_support def conv1d_transpose( input, # pylint: disable=redefined-builtin filters, @@ -1827,6 +2016,7 @@ def conv1d_transpose( @tf_export("nn.conv2d", v1=[]) +@dispatch.add_dispatch_support def conv2d_v2(input, # pylint: disable=redefined-builtin filters, strides, @@ -1835,12 +2025,15 @@ def conv2d_v2(input, # pylint: disable=redefined-builtin dilations=None, name=None): # pylint: disable=line-too-long - r"""Computes a 2-D convolution given 4-D `input` and `filters` tensors. + r"""Computes a 2-D convolution given `input` and 4-D `filters` tensors. - Given an input tensor of shape `[batch, in_height, in_width, in_channels]` - and a filter / kernel tensor of shape - `[filter_height, filter_width, in_channels, out_channels]`, this op - performs the following: + The `input` tensor may have rank `4` or higher, where shape dimensions `[:-3]` + are considered batch dimensions (`batch_shape`). + + Given an input tensor of shape + `batch_shape + [in_height, in_width, in_channels]` and a filter / kernel + tensor of shape `[filter_height, filter_width, in_channels, out_channels]`, + this op performs the following: 1. Flattens the filter to a 2-D matrix with shape `[filter_height * filter_width * in_channels, output_channels]`. @@ -1878,8 +2071,9 @@ def conv2d_v2(input, # pylint: disable=redefined-builtin Args: input: A `Tensor`. Must be one of the following types: `half`, `bfloat16`, `float32`, `float64`. - A 4-D tensor. The dimension order is interpreted according to the value - of `data_format`, see below for details. + A 4+-D tensor. The dimension order is interpreted according to the value + of `data_format`; with the all-but-inner-3 dimensions acting as batch + dimensions. See below for details. filters: A `Tensor`. Must have the same type as `input`. A 4-D tensor of shape `[filter_height, filter_width, in_channels, out_channels]` @@ -1899,9 +2093,9 @@ def conv2d_v2(input, # pylint: disable=redefined-builtin Defaults to `"NHWC"`. Specify the data format of the input and output data. With the default format "NHWC", the data is stored in the order of: - [batch, height, width, channels]. + `batch_shape + [height, width, channels]`. Alternatively, the format could be "NCHW", the data storage order of: - [batch, channels, height, width]. + `batch_shape + [channels, height, width]`. dilations: An int or list of `ints` that has length `1`, `2` or `4`, defaults to 1. The dilation factor for each dimension of`input`. If a single value is given it is replicated in the `H` and `W` dimension. By @@ -1913,7 +2107,7 @@ def conv2d_v2(input, # pylint: disable=redefined-builtin name: A name for the operation (optional). Returns: - A `Tensor`. Has the same type as `input`. + A `Tensor`. Has the same type as `input` and the same outer batch shape. """ # pylint: enable=line-too-long return conv2d(input, # pylint: disable=redefined-builtin @@ -1927,6 +2121,7 @@ def conv2d_v2(input, # pylint: disable=redefined-builtin @tf_export(v1=["nn.conv2d"]) +@dispatch.add_dispatch_support def conv2d( # pylint: disable=redefined-builtin,dangerous-default-value input, filter=None, @@ -2012,18 +2207,40 @@ def conv2d( # pylint: disable=redefined-builtin,dangerous-default-value strides = _get_sequence(strides, 2, channel_index, "strides") dilations = _get_sequence(dilations, 2, channel_index, "dilations") - return gen_nn_ops.conv2d(input, # pylint: disable=redefined-builtin - filter, - strides, - padding, - use_cudnn_on_gpu=use_cudnn_on_gpu, - explicit_paddings=explicit_paddings, - data_format=data_format, - dilations=dilations, - name=name) + + # Try really hard to avoid modifying the legacy name scopes - return early. + shape = getattr(input, "shape", None) + if shape is not None: + ndims = getattr(shape, "ndims", -1) + if ndims == -1: ndims = len(shape) + if ndims in (4, 3, 2, 1, 0, None): + return gen_nn_ops.conv2d( + input, + filter=filter, + strides=strides, + padding=padding, + use_cudnn_on_gpu=use_cudnn_on_gpu, + explicit_paddings=explicit_paddings, + data_format=data_format, + dilations=dilations, + name=name) + return _squeeze_batch_dims( + input, + functools.partial( + gen_nn_ops.conv2d, + filter=filter, + strides=strides, + padding=padding, + use_cudnn_on_gpu=use_cudnn_on_gpu, + explicit_paddings=explicit_paddings, + data_format=data_format, + dilations=dilations), + inner_rank=3, + name=name) @tf_export(v1=["nn.conv2d_backprop_filter"]) +@dispatch.add_dispatch_support def conv2d_backprop_filter( # pylint: disable=redefined-builtin,dangerous-default-value input, filter_sizes, @@ -2084,6 +2301,7 @@ def conv2d_backprop_filter( # pylint: disable=redefined-builtin,dangerous-defau @tf_export(v1=["nn.conv2d_backprop_input"]) +@dispatch.add_dispatch_support def conv2d_backprop_input( # pylint: disable=redefined-builtin,dangerous-default-value input_sizes, filter=None, @@ -2148,6 +2366,7 @@ def conv2d_backprop_input( # pylint: disable=redefined-builtin,dangerous-defaul @tf_export(v1=["nn.conv2d_transpose"]) +@dispatch.add_dispatch_support def conv2d_transpose( value=None, filter=None, # pylint: disable=redefined-builtin @@ -2224,6 +2443,7 @@ def conv2d_transpose( @tf_export("nn.conv2d_transpose", v1=[]) +@dispatch.add_dispatch_support def conv2d_transpose_v2( input, # pylint: disable=redefined-builtin filters, # pylint: disable=redefined-builtin @@ -2300,7 +2520,44 @@ def conv2d_transpose_v2( name=name) +def _conv2d_expanded_batch( + input, # pylint: disable=redefined-builtin + filters, + strides, + padding, + data_format, + dilations, + name): + """Helper function for `convolution_internal`; handles expanded batches.""" + # Try really hard to avoid modifying the legacy name scopes - return early. + shape = getattr(input, "shape", None) + if shape is not None: + ndims = getattr(shape, "ndims", -1) + if ndims == -1: ndims = len(shape) + if ndims in (4, 3, 2, 1, 0, None): + return gen_nn_ops.conv2d( + input, + filter=filters, + strides=strides, + padding=padding, + data_format=data_format, + dilations=dilations, + name=name) + return _squeeze_batch_dims( + input, + functools.partial( + gen_nn_ops.conv2d, + filter=filters, + strides=strides, + padding=padding, + data_format=data_format, + dilations=dilations), + inner_rank=3, + name=name) + + @tf_export("nn.atrous_conv2d_transpose") +@dispatch.add_dispatch_support def atrous_conv2d_transpose(value, filters, output_shape, @@ -2459,6 +2716,7 @@ def atrous_conv2d_transpose(value, @tf_export(v1=["nn.depthwise_conv2d_native"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("nn.depthwise_conv2d_native") def depthwise_conv2d_native( # pylint: disable=redefined-builtin,dangerous-default-value input, @@ -2538,6 +2796,7 @@ def depthwise_conv2d_native( # pylint: disable=redefined-builtin,dangerous-defa "nn.depthwise_conv2d_native_backprop_input", "nn.depthwise_conv2d_backprop_input" ]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_input") def depthwise_conv2d_native_backprop_input( # pylint: disable=redefined-builtin,dangerous-default-value input_sizes, @@ -2607,6 +2866,7 @@ def depthwise_conv2d_native_backprop_input( # pylint: disable=redefined-builtin "nn.depthwise_conv2d_native_backprop_filter", "nn.depthwise_conv2d_backprop_filter" ]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("nn.depthwise_conv2d_native_backprop_filter") def depthwise_conv2d_native_backprop_filter( # pylint: disable=redefined-builtin,dangerous-default-value input, @@ -2672,6 +2932,7 @@ def depthwise_conv2d_native_backprop_filter( # pylint: disable=redefined-builti @tf_export("nn.conv3d", v1=[]) +@dispatch.add_dispatch_support def conv3d_v2(input, # pylint: disable=redefined-builtin,missing-docstring filters, strides, @@ -2691,6 +2952,7 @@ def conv3d_v2(input, # pylint: disable=redefined-builtin,missing-docstring @tf_export(v1=["nn.conv3d"]) +@dispatch.add_dispatch_support def conv3d_v1( # pylint: disable=missing-docstring,dangerous-default-value input, # pylint: disable=redefined-builtin filter=None, # pylint: disable=redefined-builtin @@ -2711,6 +2973,7 @@ conv3d_v1.__doc__ = gen_nn_ops.conv3d.__doc__ @tf_export(v1=["nn.conv3d_transpose"]) +@dispatch.add_dispatch_support def conv3d_transpose( value, filter=None, # pylint: disable=redefined-builtin @@ -2782,6 +3045,7 @@ def conv3d_transpose( @tf_export("nn.conv3d_transpose", v1=[]) +@dispatch.add_dispatch_support def conv3d_transpose_v2(input, # pylint: disable=redefined-builtin filters, output_shape, @@ -2797,12 +3061,12 @@ def conv3d_transpose_v2(input, # pylint: disable=redefined-builtin rather than an actual deconvolution. Args: - input: A 5-D `Tensor` of type `float` and shape `[batch, height, width, - in_channels]` for `NHWC` data format or `[batch, in_channels, height, - width]` for `NCHW` data format. - filters: A 5-D `Tensor` with the same type as `value` and shape `[height, - width, output_channels, in_channels]`. `filter`'s `in_channels` dimension - must match that of `value`. + input: A 5-D `Tensor` of type `float` and shape `[batch, depth, height, + width, in_channels]` for `NDHWC` data format or `[batch, in_channels, + depth, height, width]` for `NCDHW` data format. + filters: A 5-D `Tensor` with the same type as `value` and shape `[depth, + height, width, output_channels, in_channels]`. `filter`'s `in_channels` + dimension must match that of `value`. output_shape: A 1-D `Tensor` representing the output shape of the deconvolution op. strides: An int or list of `ints` that has length `1`, `3` or `5`. The @@ -2861,6 +3125,7 @@ CONV_TRANSPOSE_OPS = ( @tf_export("nn.conv_transpose") +@dispatch.add_dispatch_support def conv_transpose(input, # pylint: disable=redefined-builtin filters, output_shape, @@ -2958,6 +3223,7 @@ _tf_deterministic_ops.value = None @tf_export("nn.bias_add") +@dispatch.add_dispatch_support def bias_add(value, bias, data_format=None, name=None): """Adds `bias` to `value`. @@ -3047,6 +3313,7 @@ def bias_add_v1(value, bias, name=None): @tf_export(v1=["nn.crelu"]) +@dispatch.add_dispatch_support def crelu(features, name=None, axis=-1): """Computes Concatenated ReLU. @@ -3079,12 +3346,14 @@ def crelu(features, name=None, axis=-1): @tf_export("nn.crelu", v1=[]) +@dispatch.add_dispatch_support def crelu_v2(features, axis=-1, name=None): return crelu(features, name=name, axis=axis) crelu_v2.__doc__ = crelu.__doc__ @tf_export("nn.relu6") +@dispatch.add_dispatch_support def relu6(features, name=None): """Computes Rectified Linear 6: `min(max(features, 0), 6)`. @@ -3107,6 +3376,7 @@ def relu6(features, name=None): @tf_export("nn.leaky_relu") +@dispatch.add_dispatch_support def leaky_relu(features, alpha=0.2, name=None): """Compute the Leaky ReLU activation function. @@ -3245,6 +3515,7 @@ def _softmax(logits, compute_op, dim=-1, name=None): @tf_export(v1=["nn.softmax", "math.softmax"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim") def softmax(logits, axis=None, name=None, dim=None): """Computes softmax activations. @@ -3289,6 +3560,7 @@ def softmax(logits, axis=None, name=None, dim=None): @tf_export("nn.softmax", "math.softmax", v1=[]) +@dispatch.add_dispatch_support def softmax_v2(logits, axis=None, name=None): """Computes softmax activations. @@ -3316,6 +3588,7 @@ def softmax_v2(logits, axis=None, name=None): @tf_export(v1=["nn.log_softmax", "math.log_softmax"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim") def log_softmax(logits, axis=None, name=None, dim=None): """Computes log softmax activations. @@ -3346,6 +3619,7 @@ def log_softmax(logits, axis=None, name=None, dim=None): @tf_export("nn.log_softmax", "math.log_softmax", v1=[]) +@dispatch.add_dispatch_support def log_softmax_v2(logits, axis=None, name=None): """Computes log softmax activations. @@ -3382,6 +3656,7 @@ def _ensure_xent_args(name, sentinel, labels, logits): @tf_export("nn.softmax_cross_entropy_with_logits", v1=[]) +@dispatch.add_dispatch_support def softmax_cross_entropy_with_logits_v2(labels, logits, axis=-1, name=None): """Computes softmax cross entropy between `logits` and `labels`. @@ -3444,6 +3719,7 @@ def softmax_cross_entropy_with_logits_v2(labels, logits, axis=-1, name=None): @tf_export(v1=["nn.softmax_cross_entropy_with_logits_v2"]) +@dispatch.add_dispatch_support @deprecated_args(None, "dim is deprecated, use axis instead", "dim") def softmax_cross_entropy_with_logits_v2_helper( labels, logits, axis=None, name=None, dim=None): @@ -3571,6 +3847,7 @@ See `tf.nn.softmax_cross_entropy_with_logits_v2`. @tf_export(v1=["nn.softmax_cross_entropy_with_logits"]) +@dispatch.add_dispatch_support @deprecation.deprecated(date=None, instructions=_XENT_DEPRECATION) def softmax_cross_entropy_with_logits( _sentinel=None, # pylint: disable=invalid-name @@ -3639,6 +3916,7 @@ def softmax_cross_entropy_with_logits( @tf_export(v1=["nn.sparse_softmax_cross_entropy_with_logits"]) +@dispatch.add_dispatch_support def sparse_softmax_cross_entropy_with_logits( _sentinel=None, # pylint: disable=invalid-name labels=None, @@ -3764,6 +4042,7 @@ def sparse_softmax_cross_entropy_with_logits( @tf_export("nn.sparse_softmax_cross_entropy_with_logits", v1=[]) +@dispatch.add_dispatch_support def sparse_softmax_cross_entropy_with_logits_v2(labels, logits, name=None): """Computes sparse softmax cross entropy between `logits` and `labels`. @@ -3816,6 +4095,7 @@ def sparse_softmax_cross_entropy_with_logits_v2(labels, logits, name=None): @tf_export("nn.avg_pool", v1=["nn.avg_pool_v2"]) +@dispatch.add_dispatch_support def avg_pool_v2(input, ksize, strides, padding, data_format=None, name=None): # pylint: disable=redefined-builtin """Performs the avg pooling on the input. @@ -3878,6 +4158,7 @@ def avg_pool_v2(input, ksize, strides, padding, data_format=None, name=None): # @tf_export(v1=["nn.avg_pool", "nn.avg_pool2d"]) +@dispatch.add_dispatch_support def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None, input=None): # pylint: disable=redefined-builtin """Performs the average pooling on the input. @@ -3922,6 +4203,7 @@ def avg_pool(value, ksize, strides, padding, data_format="NHWC", @tf_export("nn.avg_pool2d", v1=[]) +@dispatch.add_dispatch_support def avg_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None): # pylint: disable=redefined-builtin """Performs the average pooling on the input. @@ -3961,6 +4243,7 @@ def avg_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None): @tf_export("nn.avg_pool1d") +@dispatch.add_dispatch_support def avg_pool1d(input, ksize, strides, padding, data_format="NWC", name=None): # pylint: disable=redefined-builtin """Performs the average pooling on the input. @@ -4006,6 +4289,7 @@ def avg_pool1d(input, ksize, strides, padding, data_format="NWC", name=None): # @tf_export("nn.avg_pool3d") +@dispatch.add_dispatch_support def avg_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None): # pylint: disable=redefined-builtin """Performs the average pooling on the input. @@ -4046,6 +4330,7 @@ def avg_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None): # pylint: disable=redefined-builtin @tf_export("nn.max_pool", v1=["nn.max_pool_v2"]) +@dispatch.add_dispatch_support def max_pool_v2(input, ksize, strides, padding, data_format=None, name=None): """Performs the max pooling on the input. @@ -4106,6 +4391,7 @@ def max_pool_v2(input, ksize, strides, padding, data_format=None, name=None): @tf_export(v1=["nn.max_pool"]) +@dispatch.add_dispatch_support def max_pool(value, ksize, strides, @@ -4155,6 +4441,7 @@ def max_pool(value, # pylint: disable=redefined-builtin @tf_export("nn.max_pool1d") +@dispatch.add_dispatch_support def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None): """Performs the max pooling on the input. @@ -4199,6 +4486,7 @@ def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None): # pylint: disable=redefined-builtin @tf_export("nn.max_pool2d") +@dispatch.add_dispatch_support def max_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None): """Performs the max pooling on the input. @@ -4237,6 +4525,7 @@ def max_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None): # pylint: disable=redefined-builtin @tf_export("nn.max_pool3d") +@dispatch.add_dispatch_support def max_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None): """Performs the max pooling on the input. @@ -4279,6 +4568,7 @@ def max_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None): @tf_export("nn.max_pool_with_argmax", v1=[]) +@dispatch.add_dispatch_support def max_pool_with_argmax_v2( input, # pylint: disable=redefined-builtin ksize, @@ -4348,6 +4638,7 @@ def max_pool_with_argmax_v2( @tf_export(v1=["nn.max_pool_with_argmax"]) +@dispatch.add_dispatch_support def max_pool_with_argmax_v1( # pylint: disable=missing-docstring,invalid-name input, # pylint: disable=redefined-builtin ksize, @@ -4442,6 +4733,7 @@ def _calc_bias_add_flops(graph, node): @tf_export(v1=["nn.xw_plus_b"]) +@dispatch.add_dispatch_support def xw_plus_b(x, weights, biases, name=None): # pylint: disable=invalid-name """Computes matmul(x, weights) + biases. @@ -4691,6 +4983,7 @@ def dropout_v2(x, rate, noise_shape=None, seed=None, name=None): @tf_export("math.top_k", "nn.top_k") +@dispatch.add_dispatch_support def top_k(input, k=1, sorted=True, name=None): # pylint: disable=redefined-builtin """Finds values and indices of the `k` largest entries for the last dimension. @@ -4751,6 +5044,7 @@ def nth_element(input, n, reverse=False, name=None): # pylint: disable=redefine @tf_export(v1=["nn.fractional_max_pool"]) +@dispatch.add_dispatch_support @deprecation.deprecated(date=None, instructions="`seed2` and `deterministic` " "args are deprecated. Use fractional_max_pool_v2.") def fractional_max_pool(value, @@ -4837,6 +5131,7 @@ def fractional_max_pool(value, @tf_export("nn.fractional_max_pool", v1=[]) +@dispatch.add_dispatch_support def fractional_max_pool_v2(value, pooling_ratio, pseudo_random=False, @@ -4922,6 +5217,7 @@ def fractional_max_pool_v2(value, @tf_export(v1=["nn.fractional_avg_pool"]) +@dispatch.add_dispatch_support @deprecation.deprecated(date=None, instructions="`seed2` and `deterministic` " "args are deprecated. Use fractional_avg_pool_v2.") def fractional_avg_pool(value, @@ -4987,6 +5283,7 @@ def fractional_avg_pool(value, @tf_export("nn.fractional_avg_pool", v1=[]) +@dispatch.add_dispatch_support def fractional_avg_pool_v2(value, pooling_ratio, pseudo_random=False, @@ -5065,6 +5362,7 @@ def _calc_dilation2d_flops(graph, node): @tf_export(v1=["nn.erosion2d"]) +@dispatch.add_dispatch_support def erosion2d(value, kernel, strides, rates, padding, name=None): """Computes the grayscale erosion of 4-D `value` and 3-D `kernel` tensors. @@ -5124,6 +5422,7 @@ def erosion2d(value, kernel, strides, rates, padding, name=None): @tf_export("nn.erosion2d", v1=[]) +@dispatch.add_dispatch_support def erosion2d_v2(value, filters, strides, @@ -5193,6 +5492,7 @@ def erosion2d_v2(value, @tf_export(v1=["math.in_top_k", "nn.in_top_k"]) +@dispatch.add_dispatch_support def in_top_k(predictions, targets, k, name=None): r"""Says whether the targets are in the top `K` predictions. @@ -5227,6 +5527,7 @@ def in_top_k(predictions, targets, k, name=None): @tf_export("math.in_top_k", "nn.in_top_k", v1=[]) +@dispatch.add_dispatch_support def in_top_k_v2(targets, predictions, k, name=None): return in_top_k(predictions, targets, k, name) @@ -5234,7 +5535,11 @@ def in_top_k_v2(targets, predictions, k, name=None): in_top_k_v2.__doc__ = in_top_k.__doc__ -tf_export(v1=["nn.quantized_avg_pool"])(gen_nn_ops.quantized_avg_pool) -tf_export(v1=["nn.quantized_conv2d"])(gen_nn_ops.quantized_conv2d) -tf_export(v1=["nn.quantized_relu_x"])(gen_nn_ops.quantized_relu_x) -tf_export(v1=["nn.quantized_max_pool"])(gen_nn_ops.quantized_max_pool) +tf_export(v1=["nn.quantized_avg_pool"])( + dispatch.add_dispatch_support(gen_nn_ops.quantized_avg_pool)) +tf_export(v1=["nn.quantized_conv2d"])( + dispatch.add_dispatch_support(gen_nn_ops.quantized_conv2d)) +tf_export(v1=["nn.quantized_relu_x"])( + dispatch.add_dispatch_support(gen_nn_ops.quantized_relu_x)) +tf_export(v1=["nn.quantized_max_pool"])( + dispatch.add_dispatch_support(gen_nn_ops.quantized_max_pool)) diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 860bdc60387..0088c04f909 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -1199,6 +1199,30 @@ class DataFormatVectorPermuteTest(test_lib.TestCase): y_val = self.evaluate(y) self.assertAllEqual(y_val, [7, 3, 4, 9]) + def testNHWCToNCHW_Size2(self): + x_val = [4, 9] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x) + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [4, 9]) + + def testNHWCToWHCN(self): + x_val = [7, 4, 9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="WHCN") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [9, 4, 3, 7]) + + def testNHWCToWHCN_Size2(self): + x_val = [4, 9] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="WHCN") + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [9, 4]) + def testNCHWToNHWC(self): x_val = [7, 4, 9, 3] x = constant_op.constant(x_val) @@ -1207,6 +1231,14 @@ class DataFormatVectorPermuteTest(test_lib.TestCase): y_val = self.evaluate(y) self.assertAllEqual(y_val, [7, 9, 3, 4]) + def testNCHWToNHWC_Size2(self): + x_val = [9, 3] + x = constant_op.constant(x_val) + y = nn_ops.data_format_vec_permute(x) + with test_util.use_gpu(): + y_val = self.evaluate(y) + self.assertAllEqual(y_val, [9, 3]) + def testNHWCToHWNC(self): x_val = [7, 4, 9, 3] x = constant_op.constant(x_val) diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py index 9f9e7229442..81a532bb150 100644 --- a/tensorflow/python/ops/numerics.py +++ b/tensorflow/python/ops/numerics.py @@ -25,10 +25,12 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @tf_export(v1=["debugging.assert_all_finite", "verify_tensor_all_finite"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("verify_tensor_all_finite") def verify_tensor_all_finite(t=None, msg=None, name=None, x=None, message=None): """Assert that the tensor does not contain any NaN's or Inf's. @@ -50,6 +52,7 @@ def verify_tensor_all_finite(t=None, msg=None, name=None, x=None, message=None): @tf_export("debugging.assert_all_finite", v1=[]) +@dispatch.add_dispatch_support def verify_tensor_all_finite_v2(x, message, name=None): """Assert that the tensor does not contain any NaN's or Inf's. diff --git a/tensorflow/python/ops/numpy_ops/BUILD b/tensorflow/python/ops/numpy_ops/BUILD new file mode 100644 index 00000000000..5b4dae352d6 --- /dev/null +++ b/tensorflow/python/ops/numpy_ops/BUILD @@ -0,0 +1,16 @@ +# TF numpy API + +package( + default_visibility = [ + "//tensorflow:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +py_library( + name = "numpy_ops", + srcs = [ + "__init__.py", + ], + srcs_version = "PY2AND3", +) diff --git a/tensorflow/lite/micro/tools/make/download_dependencies.sh b/tensorflow/python/ops/numpy_ops/__init__.py old mode 100755 new mode 100644 similarity index 75% rename from tensorflow/lite/micro/tools/make/download_dependencies.sh rename to tensorflow/python/ops/numpy_ops/__init__.py index df2caedb28d..d78a4c3a6fb --- a/tensorflow/lite/micro/tools/make/download_dependencies.sh +++ b/tensorflow/python/ops/numpy_ops/__init__.py @@ -1,5 +1,4 @@ -#!/bin/bash -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Tensorflow numpy API.""" -set -e - -echo "download_dependencies.sh is no longer needed, just use 'make -f tensorflow/lite/micro/tools/make/Makefile'." >&2 -exit 1 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index 01776808525..243471553d9 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -47,6 +47,7 @@ from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradients as gradient_ops from tensorflow.python.ops import image_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops @@ -884,6 +885,136 @@ class TensorArrayTest(PForTestCase): self.assertAllClose(actual_grad, computed_grad) +class TensorListTest(PForTestCase): + + def test_create_outside_and_write(self): + handle1 = list_ops.tensor_list_reserve([], 2, dtypes.int32) + handle2 = list_ops.tensor_list_reserve([], 2, dtypes.int32) + + def loop_fn(i): + h1 = list_ops.tensor_list_set_item(handle1, 0, i) + h1 = list_ops.tensor_list_set_item(h1, 1, 1) + h2 = list_ops.tensor_list_set_item(handle2, 0, 1) + return (list_ops.tensor_list_stack(h1, dtypes.int32), + list_ops.tensor_list_stack(h2, dtypes.int32)) + + self._test_loop_fn(loop_fn, 3) + + def test_create_inside_and_write(self): + + def loop_fn(i): + h1 = list_ops.tensor_list_reserve([], 2, dtypes.int32) + h1 = list_ops.tensor_list_set_item(h1, 0, i) + h1 = list_ops.tensor_list_set_item(h1, 1, 1) + h2 = list_ops.tensor_list_reserve([], 2, dtypes.int32) + h2 = list_ops.tensor_list_set_item(h2, 0, 1) + return (list_ops.tensor_list_stack(h1, dtypes.int32), + list_ops.tensor_list_stack(h2, dtypes.int32)) + + self._test_loop_fn(loop_fn, 3) + + def test_create_outside_and_read(self): + handle = list_ops.tensor_list_reserve([], 2, dtypes.int32) + handle = list_ops.tensor_list_set_item(handle, 0, 0) + handle = list_ops.tensor_list_set_item(handle, 1, 1) + + def loop_fn(i): + return (list_ops.tensor_list_get_item(handle, i, dtypes.int32), + list_ops.tensor_list_get_item(handle, 0, dtypes.int32), + list_ops.tensor_list_length(handle), + list_ops.tensor_list_element_shape(handle, dtypes.int32), + list_ops.tensor_list_element_shape(handle, dtypes.int64)) + + self._test_loop_fn(loop_fn, 2) + + def test_create_inside_and_read(self): + + def loop_fn(i): + handle = list_ops.tensor_list_reserve([], 2, dtypes.int32) + handle = list_ops.tensor_list_set_item(handle, 0, i) + handle = list_ops.tensor_list_set_item(handle, 1, 1) + return (list_ops.tensor_list_get_item(handle, 0, dtypes.int32), + list_ops.tensor_list_get_item(handle, i, dtypes.int32), + list_ops.tensor_list_length(handle), + list_ops.tensor_list_element_shape(handle, dtypes.int32), + list_ops.tensor_list_element_shape(handle, dtypes.int64)) + + self._test_loop_fn(loop_fn, 2) + + def test_create_outside_and_scatter(self): + h = list_ops.tensor_list_reserve([2], 2, dtypes.int32) + + def loop_fn(i): + handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=h) + handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle) + handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle) + return list_ops.tensor_list_stack(handle, dtypes.int32) + + self._test_loop_fn(loop_fn, 3) + + def test_create_inside_and_scatter(self): + + def loop_fn(i): + handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32) + handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle) + handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle) + return list_ops.tensor_list_stack(handle, dtypes.int32) + + self._test_loop_fn(loop_fn, 3) + + def test_create_outside_and_gather(self): + handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32) + handle = list_ops.tensor_list_scatter([[2, 3]], [0], input_handle=handle) + handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle) + + def loop_fn(i): + return (list_ops.tensor_list_gather(handle, [0, 1], dtypes.int32), + list_ops.tensor_list_gather(handle, [i], dtypes.int32)) + + self._test_loop_fn(loop_fn, 2) + + def test_create_inside_and_gather(self): + + def loop_fn(i): + handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32) + handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle) + handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle) + return (list_ops.tensor_list_gather(handle, [0, 1], dtypes.int32), + list_ops.tensor_list_gather(handle, [i], dtypes.int32)) + + self._test_loop_fn(loop_fn, 2) + + def test_tensor_list_from_tensor(self): + t = random_ops.random_uniform([2, 3, 4]) + + def loop_fn(i): + handle = list_ops.tensor_list_from_tensor(array_ops.gather(t, i), [4]) + return list_ops.tensor_list_stack(handle, t.dtype) + + self._test_loop_fn(loop_fn, 2) + + def test_tensor_list_reserve_while_loop(self): + # Here a loop invariant TensorList is captured by a while_loop, which then + # performs loop dependent operations on it, resulting in a loop variant + # output. This forces stacking of the variant handle captured by the + # while_loop. + # We handle this particular case by forcing vectorization of + # TensorListReserve operation. + v2_enabled = control_flow_v2_toggles.control_flow_v2_enabled() + control_flow_v2_toggles.enable_control_flow_v2() + def loop_fn(i): + handle = list_ops.tensor_list_reserve([], 2, dtypes.int32) + _, out_handle = control_flow_ops.while_loop( + lambda j, _: j < 2, + lambda j, h: (j + 1, list_ops.tensor_list_set_item(h, j, i)), + (0, handle)) + return list_ops.tensor_list_stack(out_handle, dtypes.int32) + + self._test_loop_fn(loop_fn, 2) + if not v2_enabled: + control_flow_v2_toggles.disable_control_flow_v2() + + class StackTest(PForTestCase): @test_util.run_v1_only("b/122612051") @@ -1903,6 +2034,14 @@ class VariableTest(PForTestCase): ): pfor_control_flow_ops.vectorized_map(f, x) + @test_util.run_all_in_graph_and_eager_modes + def test_variable_shape(self): + v = resource_variable_ops.ResourceVariable([1, 2]) + + def loop_fn(_): + return resource_variable_ops.variable_shape(v.handle) + + self._test_loop_fn(loop_fn, 2) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index c4621758702..582bfecdc76 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -24,6 +24,7 @@ import string import sys import traceback +import numpy as np import six from tensorflow.compiler.tf2xla.python import xla @@ -52,6 +53,7 @@ from tensorflow.python.ops import gen_random_ops from tensorflow.python.ops import gen_sparse_ops from tensorflow.python.ops import gen_spectral_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -74,6 +76,19 @@ flags.DEFINE_bool( def _stack(t, length): """stacks `t` `length` times.""" + # Note that this stacking may currently be triggered, for example, when a + # loop invariant tensor with dtype variant is input to a while_loop which then + # produces a loop dependent output. Simply stacking the variants may not be + # suitable since operations on stacked handles may expect a vectorized version + # of the variant. + # Given that variant types are generic, we are currently unable to figure out + # which particular variant type is being considered here and hence it may not + # be safe to allow stacking it. + if t.dtype == dtypes.variant: + raise NotImplementedError( + "Vectorization tried to stack variant tensor %s. " + "This is likely because vectorization of that variant " + "is not fully supported yet." % t) ones = array_ops.ones_like(array_ops.shape(t)) ones = array_ops.reshape(ones, [-1]) length = array_ops.reshape(length, [-1]) @@ -93,6 +108,7 @@ def _stack(t, length): passthrough_stateful_ops = set([ "VariableV2", "VarHandleOp", + "VariableShape", "ReadVariableOp", "StackV2", "TensorArrayWriteV3", @@ -101,6 +117,15 @@ passthrough_stateful_ops = set([ ]) +# Ops which we will treat like stateful for the purpose of vectorization. +# Typically this is used to force pfor converters to run for these ops. +force_stateful_ops = set([ + # We vectorize this since we need to change the element shape set on the + # list. + "TensorListReserve", +]) + + def _is_stateful_pfor_op(op): if isinstance(op, WhileOp): return op.is_stateful @@ -109,6 +134,8 @@ def _is_stateful_pfor_op(op): return False if op.type in passthrough_stateful_ops: return False + if op.type in force_stateful_ops: + return True assert hasattr(op, "op_def") and op.op_def is not None, op return op.op_def.is_stateful @@ -964,8 +991,9 @@ def wrap(tensor, is_stacked=True, is_sparse_stacked=False): return WrappedTensor(tensor, is_stacked, is_sparse_stacked) -def _fallback_converter(pfor_input): - logging.warn("Using a while_loop for converting %s", pfor_input.op_type) +def _fallback_converter(pfor_input, warn=True): + if warn: + logging.warn("Using a while_loop for converting %s", pfor_input.op_type) output_dtypes = [x.dtype for x in pfor_input.outputs] iters = pfor_input.pfor.loop_len_vector[0] @@ -2062,7 +2090,7 @@ def _convert_diag(pfor_input): else: # It is not clear if we can do better than a while loop here with existing # kernels. - return _fallback_converter(pfor_input) + return _fallback_converter(pfor_input, warn=False) # See notes for MatrixDiagV2 @@ -2105,7 +2133,7 @@ def _convert_diag_part(pfor_input): else: # It is not clear if we can do better than a while loop here with existing # kernels. - return _fallback_converter(pfor_input) + return _fallback_converter(pfor_input, warn=False) @RegisterPFor("OneHot") @@ -3039,7 +3067,7 @@ def _convert_stateless_multinomial(pfor_input): # random numbers under vectorization. # Unfortunately, the kernels currently are not necessarily setup to do this # efficiently and hence we fallback to a sequential loop for vectorization. - return _fallback_converter(pfor_input) + return _fallback_converter(pfor_input, warn=False) # linalg_ops @@ -3471,6 +3499,241 @@ def _convert_tensor_array_grad_v3(pfor_input): return [wrap(grad_handle, False), wrap(flow_out, True)] +def _stack_tensor_list_shape(shape, pfor_input): + first_dim = pfor_input.pfor.loop_len_vector + shape_value = tensor_util.constant_value(shape) + # Note that negative values in the shape are used to signify unknown shapes + # and are handled in a special way. + if shape_value is not None: + shape_value = np.asarray(shape_value) + if -1 in shape_value: + return constant_op.constant(-1) + elif not shape_value.size: + return first_dim + else: + shape = array_ops.reshape(shape, [-1]) + return control_flow_ops.cond( + math_ops.reduce_any(shape < 0), + lambda: constant_op.constant(-1), + lambda: array_ops.concat([first_dim, shape], axis=0)) + + +def _tile_variant(t, pfor_input): + """stacks `t` `length` times.""" + t.set_shape([]) + t = array_ops.reshape(t, [-1]) + with ops.device("CPU:0"): + return array_ops.tile(t, pfor_input.pfor.loop_len_vector) + + +def _untile_variant(t): + return array_ops.gather(t, 0) + + +@RegisterPFor("TensorListReserve") +def _convert_tensor_list_reserve(pfor_input): + element_shape = pfor_input.unstacked_input(0) + num_elements = pfor_input.unstacked_input(1) + element_dtype = pfor_input.get_attr("element_dtype") + + # Prepend a dimension to element_shape. + element_shape = _stack_tensor_list_shape(element_shape, pfor_input) + handle = list_ops.tensor_list_reserve( + element_shape, num_elements, element_dtype=element_dtype) + + return wrap(_tile_variant(handle, pfor_input), True) + + +@RegisterPFor("TensorListElementShape") +def _convert_tensor_list_element_shape(pfor_input): + handle = _untile_variant(pfor_input.stacked_input(0)) + shape_type = pfor_input.get_attr("shape_type") + shape = list_ops.tensor_list_element_shape(handle, shape_type) + shape = array_ops.reshape(shape, [-1]) + shape = shape[1:] + return wrap(shape, False) + + +@RegisterPFor("TensorListLength") +def _convert_tensor_list_length(pfor_input): + handle = _untile_variant(pfor_input.stacked_input(0)) + return wrap(list_ops.tensor_list_length(handle), False) + + +def _stack_tensor_list(handle, dtype, pfor_input, element_shape=None): + if element_shape is None: + element_shape = list_ops.tensor_list_element_shape(handle, dtypes.int32) + length = list_ops.tensor_list_length(handle) + new_handle = list_ops.tensor_list_reserve( + _stack_tensor_list_shape(element_shape, pfor_input), length, dtype) + + def _body_fn(i, h): + elem = list_ops.tensor_list_get_item(handle, i, dtype, element_shape) + elem = _stack(elem, pfor_input.pfor.loop_len_vector).t + return i + 1, list_ops.tensor_list_set_item(h, i, elem) + + return control_flow_ops.while_loop(lambda i, _: i < length, _body_fn, + [0, new_handle])[1] + + +@RegisterPFor("TensorListGetItem") +def _convert_tensor_list_get_item(pfor_input): + handle, handle_stacked, _ = pfor_input.input(0) + index, index_stacked, _ = pfor_input.input(1) + element_shape = pfor_input.unstacked_input(2) + element_dtype = pfor_input.get_attr("element_dtype") + + if handle_stacked: + handle = _untile_variant(handle) + element_shape = _stack_tensor_list_shape(element_shape, pfor_input) + if index_stacked: + # We use a sequential loop since that may be more efficient than first + # gathering and concatenating all the element corresponding to `index`, + # and then doing a gather on it. + def _map_fn(i): + item_i = list_ops.tensor_list_get_item( + handle, + index[i], + element_dtype=element_dtype) + return array_ops.gather(item_i, i) + + output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices) + return wrap(output, True) + else: + output = list_ops.tensor_list_get_item( + handle, + index, + element_shape=element_shape, + element_dtype=element_dtype) + return wrap(output, True) + else: + assert index_stacked + return wrap( + list_ops.tensor_list_gather( + handle, + index, + element_shape=element_shape, + element_dtype=element_dtype), True) + + +@RegisterPFor("TensorListSetItem") +def _convert_tensor_array_set_item(pfor_input): + handle, handle_stacked, _ = pfor_input.input(0) + index, index_stacked, _ = pfor_input.input(1) + item, item_stacked, _ = pfor_input.input(2) + + if not handle_stacked: + # Special case where we can statically guarantee that the indices are + # disjoint. + if index is pfor_input.pfor.all_indices: + if not item_stacked: + item = _stack(item, pfor_input.pfor.loop_len_vector).t + return wrap( + list_ops.tensor_list_scatter(item, index, input_handle=handle), False) + else: + handle = _stack_tensor_list(handle, item.dtype, pfor_input) + else: + handle = _untile_variant(handle) + + if index_stacked: + # TODO(agarwal): handle this. + raise ValueError("Vectorizing writes to a TensorList with loop " + "variant indices is currently unsupported.") + + else: + if not item_stacked: + item = _stack(item, pfor_input.pfor.loop_len_vector).t + handle = list_ops.tensor_list_set_item(handle, index, item) + return wrap(_tile_variant(handle, pfor_input), True) + + +@RegisterPFor("TensorListStack") +def _convert_tensor_list_stack(pfor_input): + handle = pfor_input.stacked_input(0) + input_shape = pfor_input.unstacked_input(1) + element_dtype = pfor_input.get_attr("element_dtype") + num_elements = pfor_input.get_attr("num_elements") + + handle = _untile_variant(handle) + input_shape = _stack_tensor_list_shape(input_shape, pfor_input) + output = list_ops.tensor_list_stack( + handle, + element_dtype, + element_shape=input_shape, + num_elements=num_elements) + output = _transpose_first_two_dims(output) + return wrap(output, True) + + +@RegisterPFor("TensorListGather") +def _convert_tensor_list_gather(pfor_input): + handle, handle_stacked, _ = pfor_input.input(0) + index, index_stacked, _ = pfor_input.input(1) + element_shape = pfor_input.unstacked_input(2) + element_dtype = pfor_input.get_attr("element_dtype") + + if handle_stacked: + handle = _untile_variant(handle) + element_shape = _stack_tensor_list_shape(element_shape, pfor_input) + if index_stacked: + # We use a sequential loop since that may be more efficient than first + # gathering and concatenating all the element corresponding to `index`, + # and then doing a gather on it. + def _map_fn(i): + item_i = list_ops.tensor_list_gather( + handle, + index[i], + element_dtype=element_dtype) + axis = array_ops.rank(index) - 1 + return array_ops.gather(item_i, i, axis=axis) + + output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices) + return wrap(output, True) + else: + output = list_ops.tensor_list_gather( + handle, + index, + element_shape=element_shape, + element_dtype=element_dtype) + return wrap(output, True) + else: + assert index_stacked + index_shape = array_ops.shape(index) + index = array_ops.reshape(index, [-1]) + values = list_ops.tensor_list_gather( + handle, index, element_shape=element_shape, element_dtype=element_dtype) + final_shape = array_ops.concat( + [index_shape, array_ops.shape(values)[1:]], axis=0) + return wrap(array_ops.reshape(values, final_shape), True) + + +@RegisterPFor("TensorListScatterIntoExistingList") +def _convert_tensor_list_scatter(pfor_input): + pfor_input.stack_inputs([1]) + handle, handle_stacked, _ = pfor_input.input(0) + item = pfor_input.stacked_input(1) + # TODO(agarwal): handle stacked indices. + indices = pfor_input.unstacked_input(2) + if handle_stacked: + handle = _untile_variant(handle) + else: + handle = _stack_tensor_list(handle, item.dtype, pfor_input) + + item = _transpose_first_two_dims(item) + handle = list_ops.tensor_list_scatter(item, indices, input_handle=handle) + return wrap(_tile_variant(handle, pfor_input), True) + + +@RegisterPFor("TensorListFromTensor") +def _convert_tensor_list_from_tensor(pfor_input): + tensor = pfor_input.stacked_input(0) + element_shape = pfor_input.unstacked_input(1) + tensor = _transpose_first_two_dims(tensor) + element_shape = _stack_tensor_list_shape(element_shape, pfor_input) + handle = list_ops.tensor_list_from_tensor(tensor, element_shape) + return wrap(_tile_variant(handle, pfor_input), True) + + # StackV2 conversion is tricky since we don't have arrays of StackV2. So similar # to TensorArrays, we convert them by changing the dimension of the elements # inside the stack. diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py index 8e518e913be..edcae89aada 100644 --- a/tensorflow/python/ops/parsing_ops.py +++ b/tensorflow/python/ops/parsing_ops.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import parsing_config from tensorflow.python.ops.gen_parsing_ops import * # pylint: enable=wildcard-import,undefined-variable from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @@ -77,6 +78,7 @@ def _prepend_none_dimension(features): @tf_export("io.parse_example", v1=[]) +@dispatch.add_dispatch_support def parse_example_v2(serialized, features, example_names=None, name=None): # pylint: disable=line-too-long """Parses `Example` protos into a `dict` of tensors. @@ -314,6 +316,7 @@ def parse_example_v2(serialized, features, example_names=None, name=None): @tf_export(v1=["io.parse_example", "parse_example"]) +@dispatch.add_dispatch_support def parse_example(serialized, features, name=None, example_names=None): return parse_example_v2(serialized, features, example_names, name) @@ -373,6 +376,7 @@ def _parse_example_raw(serialized, names, params, name): @tf_export(v1=["io.parse_single_example", "parse_single_example"]) +@dispatch.add_dispatch_support def parse_single_example(serialized, features, name=None, example_names=None): """Parses a single `Example` proto. @@ -407,6 +411,7 @@ def parse_single_example(serialized, features, name=None, example_names=None): @tf_export("io.parse_single_example", v1=[]) +@dispatch.add_dispatch_support def parse_single_example_v2( serialized, features, example_names=None, name=None ): @@ -448,6 +453,7 @@ def parse_single_example_v2( @tf_export("io.parse_sequence_example") +@dispatch.add_dispatch_support def parse_sequence_example(serialized, context_features=None, sequence_features=None, @@ -692,6 +698,7 @@ def _parse_sequence_example_raw(serialized, @tf_export("io.parse_single_sequence_example", v1=["io.parse_single_sequence_example", "parse_single_sequence_example"]) +@dispatch.add_dispatch_support def parse_single_sequence_example( serialized, context_features=None, sequence_features=None, example_name=None, name=None): @@ -835,6 +842,7 @@ def _parse_single_sequence_example_raw(serialized, @tf_export("io.decode_raw", v1=[]) +@dispatch.add_dispatch_support def decode_raw(input_bytes, out_type, little_endian=True, @@ -877,6 +885,7 @@ def decode_raw(input_bytes, @tf_export(v1=["decode_raw", "io.decode_raw"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "bytes is deprecated, use input_bytes instead", "bytes") @@ -921,6 +930,7 @@ def decode_raw_v1( # Swap `name` and `na_value` for backward compatibility. @tf_export(v1=["io.decode_csv", "decode_csv"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("decode_csv") def decode_csv(records, record_defaults, @@ -970,6 +980,7 @@ def decode_csv(records, @tf_export("io.decode_csv", v1=[]) +@dispatch.add_dispatch_support def decode_csv_v2(records, record_defaults, field_delim=",", diff --git a/tensorflow/python/ops/proto_ops.py b/tensorflow/python/ops/proto_ops.py index 1f7300dbef9..0e19aad584c 100644 --- a/tensorflow/python/ops/proto_ops.py +++ b/tensorflow/python/ops/proto_ops.py @@ -22,10 +22,11 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.ops.gen_decode_proto_ops import decode_proto_v2 as decode_proto from tensorflow.python.ops.gen_encode_proto_ops import encode_proto +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export -tf_export("io.decode_proto")(decode_proto) -tf_export("io.encode_proto")(encode_proto) +tf_export("io.decode_proto")(dispatch.add_dispatch_support(decode_proto)) +tf_export("io.encode_proto")(dispatch.add_dispatch_support(encode_proto)) ops.NotDifferentiable("DecodeProtoV2") ops.NotDifferentiable("EncodeProto") diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index 66cac6a11d2..b2a02b82454 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -307,6 +307,7 @@ py_library( deps = [ ":segment_id_ops", "//tensorflow/python:array_ops", + "//tensorflow/python:bincount_ops", "//tensorflow/python:check_ops", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", @@ -417,6 +418,7 @@ py_library( deps = [ ":ragged_util", "//tensorflow/python:array_ops", + "//tensorflow/python:bincount_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", diff --git a/tensorflow/python/ops/ragged/ragged_array_ops.py b/tensorflow/python/ops/ragged/ragged_array_ops.py index 7f971cd558f..782902f2f71 100644 --- a/tensorflow/python/ops/ragged/ragged_array_ops.py +++ b/tensorflow/python/ops/ragged/ragged_array_ops.py @@ -32,6 +32,7 @@ from tensorflow.python.ops.ragged import ragged_math_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_util from tensorflow.python.ops.ragged import segment_id_ops +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export #=============================================================================== @@ -40,6 +41,7 @@ from tensorflow.python.util.tf_export import tf_export @tf_export('ragged.boolean_mask') +@dispatch.add_dispatch_support def boolean_mask(data, mask, name=None): """Applies a boolean mask to `data` without flattening the mask dimensions. @@ -538,6 +540,7 @@ def ragged_one_hot(indices, # ragged.stack_dynamic_partitions #=============================================================================== @tf_export('ragged.stack_dynamic_partitions') +@dispatch.add_dispatch_support def stack_dynamic_partitions(data, partitions, num_partitions, name=None): """Stacks dynamic partitions of a Tensor or RaggedTensor. @@ -699,6 +702,7 @@ def reverse(tensor, axis, name=None): @tf_export('ragged.cross') +@dispatch.add_dispatch_support def cross(inputs, name=None): """Generates feature cross from a list of tensors. @@ -725,6 +729,7 @@ def cross(inputs, name=None): @tf_export('ragged.cross_hashed') +@dispatch.add_dispatch_support def cross_hashed(inputs, num_buckets=0, hash_key=None, name=None): """Generates hashed feature cross from a list of tensors. diff --git a/tensorflow/python/ops/ragged/ragged_concat_ops.py b/tensorflow/python/ops/ragged/ragged_concat_ops.py index 9bcb1aa4765..cd710f449a6 100644 --- a/tensorflow/python/ops/ragged/ragged_concat_ops.py +++ b/tensorflow/python/ops/ragged/ragged_concat_ops.py @@ -27,6 +27,7 @@ from tensorflow.python.ops.ragged import ragged_array_ops from tensorflow.python.ops.ragged import ragged_gather_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_util +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @@ -71,6 +72,7 @@ def concat(values, axis, name=None): @tf_export('ragged.stack') +@dispatch.add_dispatch_support def stack(values, axis=0, name=None): """Stacks a list of rank-`R` tensors into one rank-`(R+1)` `RaggedTensor`. diff --git a/tensorflow/python/ops/ragged/ragged_factory_ops.py b/tensorflow/python/ops/ragged/ragged_factory_ops.py index aa148ae7fe8..3a6f6231149 100644 --- a/tensorflow/python/ops/ragged/ragged_factory_ops.py +++ b/tensorflow/python/ops/ragged/ragged_factory_ops.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor_value +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @@ -34,6 +35,7 @@ from tensorflow.python.util.tf_export import tf_export # Op to construct a constant RaggedTensor from a nested Python list. #=============================================================================== @tf_export("ragged.constant") +@dispatch.add_dispatch_support def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None, name=None, row_splits_dtype=dtypes.int64): """Constructs a constant RaggedTensor from a nested Python list. @@ -86,6 +88,7 @@ def constant(pylist, dtype=None, ragged_rank=None, inner_shape=None, @tf_export(v1=["ragged.constant_value"]) +@dispatch.add_dispatch_support def constant_value(pylist, dtype=None, ragged_rank=None, inner_shape=None, row_splits_dtype="int64"): """Constructs a RaggedTensorValue from a nested Python list. @@ -311,6 +314,7 @@ def _default_inner_shape_for_pylist(pylist, ragged_rank): @tf_export(v1=["ragged.placeholder"]) +@dispatch.add_dispatch_support def placeholder(dtype, ragged_rank, value_shape=None, name=None): """Creates a placeholder for a `tf.RaggedTensor` that will always be fed. diff --git a/tensorflow/python/ops/ragged/ragged_functional_ops.py b/tensorflow/python/ops/ragged/ragged_functional_ops.py index cc45f729e58..00b5ced6170 100644 --- a/tensorflow/python/ops/ragged/ragged_functional_ops.py +++ b/tensorflow/python/ops/ragged/ragged_functional_ops.py @@ -24,10 +24,12 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_config from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_util +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @tf_export("ragged.map_flat_values") +@dispatch.add_dispatch_support def map_flat_values(op, *args, **kwargs): """Applies `op` to the values of one or more RaggedTensors. diff --git a/tensorflow/python/ops/ragged/ragged_math_ops.py b/tensorflow/python/ops/ragged/ragged_math_ops.py index 5483cda571c..73a53583ada 100644 --- a/tensorflow/python/ops/ragged/ragged_math_ops.py +++ b/tensorflow/python/ops/ragged/ragged_math_ops.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_functional_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import segment_id_ops +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @@ -38,6 +39,7 @@ from tensorflow.python.util.tf_export import tf_export #=============================================================================== # pylint: disable=redefined-builtin @tf_export('ragged.range') +@dispatch.add_dispatch_support def range(starts, limits=None, deltas=1, dtype=None, name=None, row_splits_dtype=dtypes.int64): """Returns a `RaggedTensor` containing the specified sequences of numbers. diff --git a/tensorflow/python/ops/ragged/ragged_string_ops.py b/tensorflow/python/ops/ragged/ragged_string_ops.py index d5f21832044..0d9c4d506f3 100755 --- a/tensorflow/python/ops/ragged/ragged_string_ops.py +++ b/tensorflow/python/ops/ragged/ragged_string_ops.py @@ -29,10 +29,12 @@ from tensorflow.python.ops.ragged import ragged_math_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.util import compat as util_compat from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @tf_export("strings.bytes_split") +@dispatch.add_dispatch_support def string_bytes_split(input, name=None): # pylint: disable=redefined-builtin """Split string elements of `input` into bytes. @@ -80,6 +82,7 @@ def string_bytes_split(input, name=None): # pylint: disable=redefined-builtin # pylint: disable=redefined-builtin @tf_export("strings.unicode_encode") +@dispatch.add_dispatch_support def unicode_encode(input, output_encoding, errors="replace", @@ -177,6 +180,7 @@ def unicode_encode(input, # pylint: disable=redefined-builtin @tf_export("strings.unicode_decode") +@dispatch.add_dispatch_support def unicode_decode(input, input_encoding, errors="replace", @@ -222,6 +226,7 @@ def unicode_decode(input, @tf_export("strings.unicode_decode_with_offsets") +@dispatch.add_dispatch_support def unicode_decode_with_offsets(input, input_encoding, errors="replace", @@ -283,6 +288,7 @@ def unicode_decode_with_offsets(input, @tf_export("strings.unicode_split") +@dispatch.add_dispatch_support def unicode_split(input, input_encoding, errors="replace", @@ -330,6 +336,7 @@ def unicode_split(input, @tf_export("strings.unicode_split_with_offsets") +@dispatch.add_dispatch_support def unicode_split_with_offsets(input, input_encoding, errors="replace", @@ -453,6 +460,7 @@ def _unicode_decode(input, input_encoding, errors, replacement_char, @tf_export("strings.split", v1=[]) +@dispatch.add_dispatch_support def string_split_v2(input, sep=None, maxsplit=-1, name=None): # pylint: disable=redefined-builtin """Split elements of `input` based on `sep` into a `RaggedTensor`. @@ -514,6 +522,7 @@ def string_split_v2(input, sep=None, maxsplit=-1, name=None): # pylint: disable @tf_export(v1=["string_split"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "delimiter is deprecated, please use sep instead.", "delimiter") @@ -578,6 +587,7 @@ def string_split(source, sep=None, skip_empty=True, delimiter=None, # In TensorFlow 1.x, "tf.strings.split" uses the new signature (with maxsplit), # but we need to add the result_type argument. @tf_export(v1=["strings.split"]) +@dispatch.add_dispatch_support def strings_split_v1(input=None, sep=None, maxsplit=-1, # pylint: disable=redefined-builtin result_type="SparseTensor", source=None, name=None): """Split elements of `input` based on `sep`. @@ -651,6 +661,7 @@ def reduce_join(inputs, axis=None, keepdims=None, separator="", name=None): @tf_export("strings.ngrams") +@dispatch.add_dispatch_support def ngrams(data, ngram_width, separator=" ", diff --git a/tensorflow/python/ops/ragged/row_partition.py b/tensorflow/python/ops/ragged/row_partition.py index 133b55a53bf..e86ecc3f034 100644 --- a/tensorflow/python/ops/ragged/row_partition.py +++ b/tensorflow/python/ops/ragged/row_partition.py @@ -228,6 +228,9 @@ class RowPartition(composite_tensor.CompositeTensor): ... nrows=4)) tf.RowPartition(row_splits=tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64)) """ + # Local import bincount_ops to avoid import-cycle since bincount_ops + # imports ragged_tensor. + from tensorflow.python.ops import bincount_ops # pylint: disable=g-import-not-at-top if not isinstance(validate, bool): raise TypeError("validate must have type bool") with ops.name_scope(None, "RowPartitionFromValueRowIds", @@ -278,7 +281,7 @@ class RowPartition(composite_tensor.CompositeTensor): # cast. value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32) nrows_int32 = math_ops.cast(nrows, dtypes.int32) - row_lengths = math_ops.bincount( + row_lengths = bincount_ops.bincount( value_rowids_int32, minlength=nrows_int32, maxlength=nrows_int32, diff --git a/tensorflow/python/ops/ragged/segment_id_ops.py b/tensorflow/python/ops/ragged/segment_id_ops.py index 5329860743e..3b3809d8d56 100644 --- a/tensorflow/python/ops/ragged/segment_id_ops.py +++ b/tensorflow/python/ops/ragged/segment_id_ops.py @@ -25,12 +25,14 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_util +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export # For background on "segments" and "segment ids", see: # https://www.tensorflow.org/api_docs/python/tf/math#Segmentation @tf_export("ragged.row_splits_to_segment_ids") +@dispatch.add_dispatch_support def row_splits_to_segment_ids(splits, name=None, out_type=None): """Generates the segmentation corresponding to a RaggedTensor `row_splits`. @@ -74,6 +76,7 @@ def row_splits_to_segment_ids(splits, name=None, out_type=None): # For background on "segments" and "segment ids", see: # https://www.tensorflow.org/api_docs/python/tf/math#Segmentation @tf_export("ragged.segment_ids_to_row_splits") +@dispatch.add_dispatch_support def segment_ids_to_row_splits(segment_ids, num_segments=None, out_type=None, name=None): """Generates the RaggedTensor `row_splits` corresponding to a segmentation. @@ -95,6 +98,8 @@ def segment_ids_to_row_splits(segment_ids, num_segments=None, Returns: A sorted 1-D integer Tensor, with `shape=[num_segments + 1]`. """ + # Local import bincount_ops to avoid import-cycle. + from tensorflow.python.ops import bincount_ops # pylint: disable=g-import-not-at-top if out_type is None: if isinstance(segment_ids, ops.Tensor): out_type = segment_ids.dtype @@ -116,7 +121,7 @@ def segment_ids_to_row_splits(segment_ids, num_segments=None, dtype=dtypes.int32) num_segments.shape.assert_has_rank(0) - row_lengths = math_ops.bincount( + row_lengths = bincount_ops.bincount( segment_ids, minlength=num_segments, maxlength=num_segments, diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py index 83cb7fcc92a..1af91ed0dd3 100644 --- a/tensorflow/python/ops/random_ops.py +++ b/tensorflow/python/ops/random_ops.py @@ -36,10 +36,12 @@ from tensorflow.python.ops.gen_random_ops import * # pylint: enable=wildcard-import from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @tf_export("random.normal", v1=["random.normal", "random_normal"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("random_normal") def random_normal(shape, mean=0.0, @@ -155,6 +157,7 @@ def parameterized_truncated_normal(shape, @tf_export("random.truncated_normal", v1=["random.truncated_normal", "truncated_normal"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("truncated_normal") def truncated_normal(shape, mean=0.0, @@ -202,6 +205,7 @@ ops.NotDifferentiable("TruncatedNormal") @tf_export("random.uniform", v1=["random.uniform", "random_uniform"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("random_uniform") def random_uniform(shape, minval=0, @@ -313,6 +317,7 @@ ops.NotDifferentiable("RandomUniform") @tf_export("random.shuffle", v1=["random.shuffle", "random_shuffle"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("random_shuffle") def random_shuffle(value, seed=None, name=None): """Randomly shuffles a tensor along its first dimension. @@ -345,6 +350,7 @@ def random_shuffle(value, seed=None, name=None): @tf_export("image.random_crop", v1=["image.random_crop", "random_crop"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("random_crop") def random_crop(value, size, seed=None, name=None): """Randomly crops a tensor to a given size. @@ -389,6 +395,7 @@ def random_crop(value, size, seed=None, name=None): @tf_export(v1=["random.multinomial", "multinomial"]) +@dispatch.add_dispatch_support @deprecation.deprecated( date=None, instructions="Use `tf.random.categorical` instead.") def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None): @@ -468,6 +475,7 @@ def _maybe_set_static_shape_helper(tensor, shape, postfix_tensor): @tf_export("random.gamma", v1=["random.gamma", "random_gamma"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("random_gamma") def random_gamma(shape, alpha, @@ -561,6 +569,7 @@ def random_gamma(shape, @tf_export(v1=["random.poisson", "random_poisson"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("random_poisson") def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None): """Draws `shape` samples from each of the given Poisson distribution(s). @@ -601,6 +610,7 @@ def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None): @tf_export("random.poisson", v1=[]) +@dispatch.add_dispatch_support def random_poisson_v2(shape, lam, dtype=dtypes.float32, seed=None, name=None): """Draws `shape` samples from each of the given Poisson distribution(s). diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index b87e5d65a37..6c11ebefb1c 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -342,6 +343,7 @@ def _reverse_seq(input_seq, lengths): "keras.layers.RNN(cell))`, which is equivalent to " "this API") @tf_export(v1=["nn.bidirectional_dynamic_rnn"]) +@dispatch.add_dispatch_support def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, @@ -499,6 +501,7 @@ def bidirectional_dynamic_rnn(cell_fw, None, "Please use `keras.layers.RNN(cell)`, which is equivalent to this API") @tf_export(v1=["nn.dynamic_rnn"]) +@dispatch.add_dispatch_support def dynamic_rnn(cell, inputs, sequence_length=None, @@ -912,6 +915,7 @@ def _dynamic_rnn_loop(cell, @tf_export(v1=["nn.raw_rnn"]) +@dispatch.add_dispatch_support def raw_rnn(cell, loop_fn, parallel_iterations=None, @@ -1238,6 +1242,7 @@ def raw_rnn(cell, "Please use `keras.layers.RNN(cell, unroll=True)`, " "which is equivalent to this API") @tf_export(v1=["nn.static_rnn"]) +@dispatch.add_dispatch_support def static_rnn(cell, inputs, initial_state=None, @@ -1416,6 +1421,7 @@ def static_rnn(cell, "Please use `keras.layers.RNN(cell, stateful=True)`, " "which is equivalent to this API") @tf_export(v1=["nn.static_state_saving_rnn"]) +@dispatch.add_dispatch_support def static_state_saving_rnn(cell, inputs, state_saver, @@ -1510,6 +1516,7 @@ def static_state_saving_rnn(cell, "keras.layers.RNN(cell, unroll=True))`, which is " "equivalent to this API") @tf_export(v1=["nn.static_bidirectional_rnn"]) +@dispatch.add_dispatch_support def static_bidirectional_rnn(cell_fw, cell_bw, inputs, diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index bee85dc4a5b..7ee5a16ca9a 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -39,6 +39,7 @@ from tensorflow.python.ops import gen_script_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.util import compat from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util import lazy_loader from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect @@ -370,6 +371,7 @@ def _EagerPyFuncGrad(op, *dy): @tf_export("py_function") +@dispatch.add_dispatch_support def eager_py_func(func, inp, Tout, name=None): """Wraps a python function into a TensorFlow op that executes it eagerly. @@ -551,6 +553,7 @@ def py_func_common(func, inp, Tout, stateful=True, name=None): stateful argument making all functions stateful. """) @tf_export(v1=["py_func"]) +@dispatch.add_dispatch_support def py_func(func, inp, Tout, stateful=True, name=None): return py_func_common(func, inp, Tout, stateful, name=name) @@ -559,6 +562,7 @@ py_func.__doc__ = "%s" % py_func_common.__doc__ @tf_export("numpy_function") +@dispatch.add_dispatch_support def numpy_function(func, inp, Tout, name=None): """Wraps a python function and uses it as a TensorFlow op. diff --git a/tensorflow/python/ops/sets_impl.py b/tensorflow/python/ops/sets_impl.py index 988d437bae8..0b65033ce8c 100644 --- a/tensorflow/python/ops/sets_impl.py +++ b/tensorflow/python/ops/sets_impl.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import gen_set_ops +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @@ -32,6 +33,7 @@ _VALID_DTYPES = set([ @tf_export("sets.size", v1=["sets.size", "sets.set_size"]) +@dispatch.add_dispatch_support def set_size(a, validate_indices=True): """Compute number of unique elements along last dimension of `a`. @@ -135,6 +137,7 @@ def _set_operation(a, b, set_operation, validate_indices=True): @tf_export( "sets.intersection", v1=["sets.intersection", "sets.set_intersection"]) +@dispatch.add_dispatch_support def set_intersection(a, b, validate_indices=True): """Compute set intersection of elements in last dimension of `a` and `b`. @@ -205,6 +208,7 @@ def set_intersection(a, b, validate_indices=True): @tf_export( "sets.difference", v1=["sets.difference", "sets.set_difference"]) +@dispatch.add_dispatch_support def set_difference(a, b, aminusb=True, validate_indices=True): """Compute set difference of elements in last dimension of `a` and `b`. @@ -286,6 +290,7 @@ def set_difference(a, b, aminusb=True, validate_indices=True): @tf_export( "sets.union", v1=["sets.union", "sets.set_union"]) +@dispatch.add_dispatch_support def set_union(a, b, validate_indices=True): """Compute set union of elements in last dimension of `a` and `b`. diff --git a/tensorflow/python/ops/signal/dct_ops.py b/tensorflow/python/ops/signal/dct_ops.py index d628e54cdf9..18730743941 100644 --- a/tensorflow/python/ops/signal/dct_ops.py +++ b/tensorflow/python/ops/signal/dct_ops.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops as _array_ops from tensorflow.python.ops import math_ops as _math_ops from tensorflow.python.ops.signal import fft_ops +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @@ -50,6 +51,7 @@ def _validate_dct_arguments(input_tensor, dct_type, n, axis, norm): # TODO(rjryan): Implement `axis` parameter. @tf_export("signal.dct", v1=["signal.dct", "spectral.dct"]) +@dispatch.add_dispatch_support def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin """Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`. @@ -181,6 +183,7 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl # TODO(rjryan): Implement `n` and `axis` parameters. @tf_export("signal.idct", v1=["signal.idct", "spectral.idct"]) +@dispatch.add_dispatch_support def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`. diff --git a/tensorflow/python/ops/signal/fft_ops.py b/tensorflow/python/ops/signal/fft_ops.py index 6e9e8ef80e4..86a94cf5de7 100644 --- a/tensorflow/python/ops/signal/fft_ops.py +++ b/tensorflow/python/ops/signal/fft_ops.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import array_ops as _array_ops from tensorflow.python.ops import gen_spectral_ops from tensorflow.python.ops import manip_ops from tensorflow.python.ops import math_ops as _math_ops +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @@ -181,17 +182,23 @@ ifft2d = gen_spectral_ops.ifft2d fft3d = gen_spectral_ops.fft3d ifft3d = gen_spectral_ops.ifft3d rfft = _rfft_wrapper(gen_spectral_ops.rfft, 1, "rfft") -tf_export("signal.rfft", v1=["signal.rfft", "spectral.rfft"])(rfft) +tf_export("signal.rfft", v1=["signal.rfft", "spectral.rfft"])( + dispatch.add_dispatch_support(rfft)) irfft = _irfft_wrapper(gen_spectral_ops.irfft, 1, "irfft") -tf_export("signal.irfft", v1=["signal.irfft", "spectral.irfft"])(irfft) +tf_export("signal.irfft", v1=["signal.irfft", "spectral.irfft"])( + dispatch.add_dispatch_support(irfft)) rfft2d = _rfft_wrapper(gen_spectral_ops.rfft2d, 2, "rfft2d") -tf_export("signal.rfft2d", v1=["signal.rfft2d", "spectral.rfft2d"])(rfft2d) +tf_export("signal.rfft2d", v1=["signal.rfft2d", "spectral.rfft2d"])( + dispatch.add_dispatch_support(rfft2d)) irfft2d = _irfft_wrapper(gen_spectral_ops.irfft2d, 2, "irfft2d") -tf_export("signal.irfft2d", v1=["signal.irfft2d", "spectral.irfft2d"])(irfft2d) +tf_export("signal.irfft2d", v1=["signal.irfft2d", "spectral.irfft2d"])( + dispatch.add_dispatch_support(irfft2d)) rfft3d = _rfft_wrapper(gen_spectral_ops.rfft3d, 3, "rfft3d") -tf_export("signal.rfft3d", v1=["signal.rfft3d", "spectral.rfft3d"])(rfft3d) +tf_export("signal.rfft3d", v1=["signal.rfft3d", "spectral.rfft3d"])( + dispatch.add_dispatch_support(rfft3d)) irfft3d = _irfft_wrapper(gen_spectral_ops.irfft3d, 3, "irfft3d") -tf_export("signal.irfft3d", v1=["signal.irfft3d", "spectral.irfft3d"])(irfft3d) +tf_export("signal.irfft3d", v1=["signal.irfft3d", "spectral.irfft3d"])( + dispatch.add_dispatch_support(irfft3d)) def _fft_size_for_grad(grad, rank): @@ -363,6 +370,7 @@ def _irfft_grad_helper(rank, rfft_fn): @tf_export("signal.fftshift") +@dispatch.add_dispatch_support def fftshift(x, axes=None, name=None): """Shift the zero-frequency component to the center of the spectrum. @@ -407,6 +415,7 @@ def fftshift(x, axes=None, name=None): @tf_export("signal.ifftshift") +@dispatch.add_dispatch_support def ifftshift(x, axes=None, name=None): """The inverse of fftshift. diff --git a/tensorflow/python/ops/signal/mel_ops.py b/tensorflow/python/ops/signal/mel_ops.py index b95876bc977..cf0bed9ef1b 100644 --- a/tensorflow/python/ops/signal/mel_ops.py +++ b/tensorflow/python/ops/signal/mel_ops.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.signal import shape_ops +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @@ -90,6 +91,7 @@ def _validate_arguments(num_mel_bins, sample_rate, @tf_export('signal.linear_to_mel_weight_matrix') +@dispatch.add_dispatch_support def linear_to_mel_weight_matrix(num_mel_bins=20, num_spectrogram_bins=129, sample_rate=8000, diff --git a/tensorflow/python/ops/signal/mfcc_ops.py b/tensorflow/python/ops/signal/mfcc_ops.py index 56cbff40bca..948b78a858e 100644 --- a/tensorflow/python/ops/signal/mfcc_ops.py +++ b/tensorflow/python/ops/signal/mfcc_ops.py @@ -22,10 +22,12 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.signal import dct_ops +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @tf_export('signal.mfccs_from_log_mel_spectrograms') +@dispatch.add_dispatch_support def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None): """Computes [MFCCs][mfcc] of `log_mel_spectrograms`. diff --git a/tensorflow/python/ops/signal/reconstruction_ops.py b/tensorflow/python/ops/signal/reconstruction_ops.py index fcdcf592f14..e340e97b3e5 100644 --- a/tensorflow/python/ops/signal/reconstruction_ops.py +++ b/tensorflow/python/ops/signal/reconstruction_ops.py @@ -23,10 +23,12 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @tf_export("signal.overlap_and_add") +@dispatch.add_dispatch_support def overlap_and_add(signal, frame_step, name=None): """Reconstructs a signal from a framed representation. diff --git a/tensorflow/python/ops/signal/shape_ops.py b/tensorflow/python/ops/signal/shape_ops.py index 1c95873fc3d..7a3acce3475 100644 --- a/tensorflow/python/ops/signal/shape_ops.py +++ b/tensorflow/python/ops/signal/shape_ops.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.signal import util_ops +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @@ -55,6 +56,7 @@ def _infer_frame_shape(signal, frame_length, frame_step, pad_end, axis): @tf_export("signal.frame") +@dispatch.add_dispatch_support def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1, name=None): """Expands `signal`'s `axis` dimension into frames of `frame_length`. diff --git a/tensorflow/python/ops/signal/spectral_ops.py b/tensorflow/python/ops/signal/spectral_ops.py index d096e53e8f8..7c4c5542b84 100644 --- a/tensorflow/python/ops/signal/spectral_ops.py +++ b/tensorflow/python/ops/signal/spectral_ops.py @@ -31,10 +31,12 @@ from tensorflow.python.ops.signal import fft_ops from tensorflow.python.ops.signal import reconstruction_ops from tensorflow.python.ops.signal import shape_ops from tensorflow.python.ops.signal import window_ops +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @tf_export('signal.stft') +@dispatch.add_dispatch_support def stft(signals, frame_length, frame_step, fft_length=None, window_fn=window_ops.hann_window, pad_end=False, name=None): @@ -95,6 +97,7 @@ def stft(signals, frame_length, frame_step, fft_length=None, @tf_export('signal.inverse_stft_window_fn') +@dispatch.add_dispatch_support def inverse_stft_window_fn(frame_step, forward_window_fn=window_ops.hann_window, name=None): @@ -156,6 +159,7 @@ def inverse_stft_window_fn(frame_step, @tf_export('signal.inverse_stft') +@dispatch.add_dispatch_support def inverse_stft(stfts, frame_length, frame_step, @@ -291,6 +295,7 @@ def _enclosing_power_of_two(value): @tf_export('signal.mdct') +@dispatch.add_dispatch_support def mdct(signals, frame_length, window_fn=window_ops.vorbis_window, pad_end=False, norm=None, name=None): """Computes the [Modified Discrete Cosine Transform][mdct] of `signals`. @@ -366,6 +371,7 @@ def mdct(signals, frame_length, window_fn=window_ops.vorbis_window, @tf_export('signal.inverse_mdct') +@dispatch.add_dispatch_support def inverse_mdct(mdcts, window_fn=window_ops.vorbis_window, norm=None, diff --git a/tensorflow/python/ops/signal/window_ops.py b/tensorflow/python/ops/signal/window_ops.py index bb10bdf4be5..eb33c3f3b58 100644 --- a/tensorflow/python/ops/signal/window_ops.py +++ b/tensorflow/python/ops/signal/window_ops.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @@ -52,6 +53,7 @@ def _check_params(window_length, dtype): @tf_export('signal.kaiser_window') +@dispatch.add_dispatch_support def kaiser_window(window_length, beta=12., dtype=dtypes.float32, name=None): """Generate a [Kaiser window][kaiser]. @@ -91,6 +93,7 @@ def kaiser_window(window_length, beta=12., dtype=dtypes.float32, name=None): @tf_export('signal.kaiser_bessel_derived_window') +@dispatch.add_dispatch_support def kaiser_bessel_derived_window(window_length, beta=12., dtype=dtypes.float32, name=None): """Generate a [Kaiser Bessel derived window][kbd]. @@ -118,6 +121,7 @@ def kaiser_bessel_derived_window(window_length, beta=12., @tf_export('signal.vorbis_window') +@dispatch.add_dispatch_support def vorbis_window(window_length, dtype=dtypes.float32, name=None): """Generate a [Vorbis power complementary window][vorbis]. @@ -142,6 +146,7 @@ def vorbis_window(window_length, dtype=dtypes.float32, name=None): @tf_export('signal.hann_window') +@dispatch.add_dispatch_support def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None): """Generate a [Hann window][hann]. @@ -167,6 +172,7 @@ def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None): @tf_export('signal.hamming_window') +@dispatch.add_dispatch_support def hamming_window(window_length, periodic=True, dtype=dtypes.float32, name=None): """Generate a [Hamming][hamming] window. diff --git a/tensorflow/python/ops/sort_ops.py b/tensorflow/python/ops/sort_ops.py index 92435e6bdef..4e66a80bc01 100644 --- a/tensorflow/python/ops/sort_ops.py +++ b/tensorflow/python/ops/sort_ops.py @@ -30,10 +30,12 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export @tf_export('sort') +@dispatch.add_dispatch_support def sort(values, axis=-1, direction='ASCENDING', name=None): """Sorts a tensor. @@ -67,6 +69,7 @@ def sort(values, axis=-1, direction='ASCENDING', name=None): @tf_export('argsort') +@dispatch.add_dispatch_support def argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None): """Returns the indices of a tensor that give its sorted order along an axis. diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 844aa3c744c..cc4b1010021 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -27,6 +27,7 @@ import numbers import numpy as np +from tensorflow.python.compat import compat as tf_compat from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -569,7 +570,7 @@ def sparse_add_v2(a, b, threshold=0): @tf_export("sparse.cross") -def sparse_cross(inputs, name=None): +def sparse_cross(inputs, name=None, separator=None): """Generates sparse cross from a list of sparse and dense tensors. For example, if the inputs are @@ -590,14 +591,39 @@ def sparse_cross(inputs, name=None): [1, 0]: "b_X_e_X_g" [1, 1]: "c_X_e_X_g" + Customized separator "_Y_": + + >>> inp_0 = tf.constant([['a'], ['b']]) + >>> inp_1 = tf.constant([['c'], ['d']]) + >>> output = tf.sparse.cross([inp_0, inp_1], separator='_Y_') + >>> output.values + <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'a_Y_c', b'b_Y_d'], + dtype=object)> + + Args: inputs: An iterable of `Tensor` or `SparseTensor`. name: Optional name for the op. + separator: A string added between each string being joined. Defaults to + '_X_'. Returns: A `SparseTensor` of type `string`. """ - return _sparse_cross_internal(inputs=inputs, hashed_output=False, name=name) + if separator is None and not tf_compat.forward_compatible(2020, 6, 14): + return _sparse_cross_internal(inputs=inputs, hashed_output=False, name=name) + if separator is None: + separator = "_X_" + separator = ops.convert_to_tensor(separator, dtypes.string) + indices, values, shapes, dense_inputs = _sparse_cross_internval_v2(inputs) + indices_out, values_out, shape_out = gen_sparse_ops.sparse_cross_v2( + indices=indices, + values=values, + shapes=shapes, + dense_inputs=dense_inputs, + sep=separator, + name=name) + return sparse_tensor.SparseTensor(indices_out, values_out, shape_out) _sparse_cross = sparse_cross @@ -655,6 +681,32 @@ _sparse_cross_hashed = sparse_cross_hashed _DEFAULT_HASH_KEY = 0xDECAFCAFFE +def _sparse_cross_internval_v2(inputs): + """See gen_sparse_ops.sparse_cross_v2.""" + if not isinstance(inputs, (tuple, list)): + raise TypeError("Inputs must be a list") + if not all( + isinstance(i, sparse_tensor.SparseTensor) or isinstance(i, ops.Tensor) + for i in inputs): + raise TypeError("All inputs must be Tensor or SparseTensor.") + sparse_inputs = [ + i for i in inputs if isinstance(i, sparse_tensor.SparseTensor) + ] + dense_inputs = [ + i for i in inputs if not isinstance(i, sparse_tensor.SparseTensor) + ] + indices = [sp_input.indices for sp_input in sparse_inputs] + values = [sp_input.values for sp_input in sparse_inputs] + shapes = [sp_input.dense_shape for sp_input in sparse_inputs] + for i in range(len(values)): + if values[i].dtype != dtypes.string: + values[i] = math_ops.cast(values[i], dtypes.int64) + for i in range(len(dense_inputs)): + if dense_inputs[i].dtype != dtypes.string: + dense_inputs[i] = math_ops.cast(dense_inputs[i], dtypes.int64) + return indices, values, shapes, dense_inputs + + def _sparse_cross_internal(inputs, hashed_output=False, num_buckets=0, @@ -1065,6 +1117,7 @@ def sparse_slice(sp_input, start, size, name=None): @tf_export(v1=["sparse_to_dense"]) +@dispatch.add_dispatch_support @deprecation.deprecated( None, "Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.") @@ -1994,6 +2047,7 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None): @tf_export(v1=["io.serialize_sparse", "serialize_sparse"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("serialize_sparse") def serialize_sparse(sp_input, name=None, out_type=dtypes.string): """Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object. @@ -2014,6 +2068,7 @@ def serialize_sparse(sp_input, name=None, out_type=dtypes.string): @tf_export("io.serialize_sparse", v1=[]) +@dispatch.add_dispatch_support def serialize_sparse_v2(sp_input, out_type=dtypes.string, name=None): """Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object. @@ -2040,6 +2095,7 @@ def serialize_sparse_v2(sp_input, out_type=dtypes.string, name=None): @tf_export(v1=["io.serialize_many_sparse", "serialize_many_sparse"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("serialize_many_sparse") def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string): """Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`. @@ -2069,6 +2125,7 @@ def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string): @tf_export("io.serialize_many_sparse", v1=[]) +@dispatch.add_dispatch_support def serialize_many_sparse_v2(sp_input, out_type=dtypes.string, name=None): """Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`. @@ -2172,6 +2229,7 @@ def deserialize_sparse(serialized_sparse, dtype, rank=None, name=None): @tf_export( "io.deserialize_many_sparse", v1=["io.deserialize_many_sparse", "deserialize_many_sparse"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("deserialize_many_sparse") def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None): """Deserialize and concatenate `SparseTensors` from a serialized minibatch. diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py index a05a488408d..036346cdecd 100644 --- a/tensorflow/python/ops/special_math_ops.py +++ b/tensorflow/python/ops/special_math_ops.py @@ -42,11 +42,13 @@ from tensorflow.python.ops import gen_special_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export # TODO(b/27419586) Change docstring for required dtype of x once int allowed @tf_export('math.lbeta', v1=['math.lbeta', 'lbeta']) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints('lbeta') def lbeta(x, name=None): r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension. @@ -102,6 +104,7 @@ def lbeta(x, name=None): @tf_export('math.special.dawsn') +@dispatch.add_dispatch_support def dawsn(x, name=None): """Computes Dawson's integral of `x` element-wise. @@ -131,6 +134,7 @@ def dawsn(x, name=None): @tf_export('math.special.expint') +@dispatch.add_dispatch_support def expint(x, name=None): """Computes the Exponential integral of `x` element-wise. @@ -159,6 +163,7 @@ def expint(x, name=None): @tf_export('math.special.fresnel_cos') +@dispatch.add_dispatch_support def fresnel_cos(x, name=None): """Computes Fresnel's cosine integral of `x` element-wise. @@ -188,6 +193,7 @@ def fresnel_cos(x, name=None): @tf_export('math.special.fresnel_sin') +@dispatch.add_dispatch_support def fresnel_sin(x, name=None): """Computes Fresnel's sine integral of `x` element-wise. @@ -216,6 +222,7 @@ def fresnel_sin(x, name=None): @tf_export('math.special.spence') +@dispatch.add_dispatch_support def spence(x, name=None): """Computes Spence's integral of `x` element-wise. @@ -244,6 +251,7 @@ def spence(x, name=None): @tf_export('math.bessel_i0') +@dispatch.add_dispatch_support def bessel_i0(x, name=None): """Computes the Bessel i0 function of `x` element-wise. @@ -268,6 +276,7 @@ def bessel_i0(x, name=None): @tf_export('math.bessel_i1') +@dispatch.add_dispatch_support def bessel_i1(x, name=None): """Computes the Bessel i1 function of `x` element-wise. @@ -325,6 +334,7 @@ def _enclosing_tpu_context(): @tf_export('einsum', 'linalg.einsum') +@dispatch.add_dispatch_support def einsum(equation, *inputs, **kwargs): """Tensor contraction over specified indices and outer product. diff --git a/tensorflow/python/ops/stateless_random_ops.py b/tensorflow/python/ops/stateless_random_ops.py index 2bf53d3a0f7..0ae29ba0219 100644 --- a/tensorflow/python/ops/stateless_random_ops.py +++ b/tensorflow/python/ops/stateless_random_ops.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_stateless_random_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export ops.NotDifferentiable("StatelessMultinomial") @@ -40,6 +41,7 @@ ops.NotDifferentiable("StatelessTruncatedNormal") @tf_export("random.experimental.stateless_split") +@dispatch.add_dispatch_support def split(seed, num=2): """Splits an RNG seed into `num` new seeds by adding a leading axis. @@ -73,6 +75,7 @@ def split(seed, num=2): @tf_export("random.experimental.stateless_fold_in") +@dispatch.add_dispatch_support def fold_in(seed, data): """Folds in data to an RNG seed to form a new RNG seed. @@ -111,6 +114,7 @@ def fold_in(seed, data): @tf_export("random.stateless_uniform") +@dispatch.add_dispatch_support def stateless_random_uniform(shape, seed, minval=0, @@ -205,6 +209,7 @@ def stateless_random_uniform(shape, @tf_export("random.stateless_binomial") +@dispatch.add_dispatch_support def stateless_random_binomial(shape, seed, counts, @@ -274,6 +279,7 @@ def stateless_random_binomial(shape, @tf_export("random.stateless_gamma") +@dispatch.add_dispatch_support def stateless_random_gamma(shape, seed, alpha, @@ -372,6 +378,7 @@ def stateless_random_gamma(shape, @tf_export("random.stateless_poisson") +@dispatch.add_dispatch_support def stateless_random_poisson(shape, seed, lam, @@ -434,6 +441,7 @@ def stateless_random_poisson(shape, @tf_export("random.stateless_normal") +@dispatch.add_dispatch_support def stateless_random_normal(shape, seed, mean=0.0, @@ -474,6 +482,7 @@ def stateless_random_normal(shape, @tf_export("random.stateless_truncated_normal") +@dispatch.add_dispatch_support def stateless_truncated_normal(shape, seed, mean=0.0, @@ -520,6 +529,7 @@ def stateless_truncated_normal(shape, @tf_export(v1=["random.stateless_multinomial"]) +@dispatch.add_dispatch_support @deprecation.deprecated( date=None, instructions="Use `tf.random.stateless_categorical` instead.") def stateless_multinomial(logits, @@ -562,6 +572,7 @@ def stateless_multinomial(logits, @tf_export("random.stateless_categorical") +@dispatch.add_dispatch_support def stateless_categorical(logits, num_samples, seed, diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index 09ba078383a..dd0ae223d9d 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -73,6 +73,7 @@ regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__ @tf_export( "strings.regex_replace", v1=["strings.regex_replace", "regex_replace"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("regex_replace") @dispatch.add_dispatch_support def regex_replace(input, pattern, rewrite, replace_global=True, name=None): @@ -112,6 +113,7 @@ def regex_replace(input, pattern, rewrite, replace_global=True, name=None): @tf_export("strings.format") +@dispatch.add_dispatch_support def string_format(template, inputs, placeholder="{}", summarize=3, name=None): r"""Formats a string template using a list of tensors. @@ -300,6 +302,7 @@ def _reduce_join_reduction_dims(x, axis): @tf_export(v1=["strings.reduce_join", "reduce_join"]) +@dispatch.add_dispatch_support @deprecation.deprecated_args(None, "keep_dims is deprecated, use keepdims instead", "keep_dims") @@ -412,6 +415,7 @@ string_length_v2.__doc__ = gen_string_ops.string_length.__doc__ @tf_export(v1=["substr"]) +@dispatch.add_dispatch_support @deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.") def substr_deprecated(input, pos, len, name=None, unit="BYTE"): return substr(input, pos, len, name=name, unit=unit) @@ -476,6 +480,7 @@ def string_to_number(input, out_type=dtypes.float32, name=None): @tf_export(v1=["strings.to_number", "string_to_number"]) +@dispatch.add_dispatch_support def string_to_number_v1( string_tensor=None, out_type=dtypes.float32, @@ -519,6 +524,7 @@ def string_to_hash_bucket(input, num_buckets, name=None): @tf_export(v1=["strings.to_hash_bucket", "string_to_hash_bucket"]) +@dispatch.add_dispatch_support def string_to_hash_bucket_v1( string_tensor=None, num_buckets=None, @@ -532,6 +538,7 @@ string_to_hash_bucket_v1.__doc__ = gen_string_ops.string_to_hash_bucket.__doc__ @tf_export("strings.join", v1=["strings.join", "string_join"]) +@dispatch.add_dispatch_support @deprecation.deprecated_endpoints("string_join") @dispatch.add_dispatch_support def string_join(inputs, separator="", name=None): diff --git a/tensorflow/python/ops/structured/BUILD b/tensorflow/python/ops/structured/BUILD index e9504efdd99..64b7bd7f1d5 100644 --- a/tensorflow/python/ops/structured/BUILD +++ b/tensorflow/python/ops/structured/BUILD @@ -5,6 +5,7 @@ load("//tensorflow:tensorflow.bzl", "py_test") package( default_visibility = [ "//learning/tfx/autotfx:__subpackages__", + "//research/graph/convolutions/model/autotfx:__subpackages__", "//tensorflow:internal", ], licenses = ["notice"], # Apache 2.0 diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index d386d14b64a..58dc92084a6 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -21,10 +21,11 @@ from __future__ import print_function import contextlib -import numpy as np import traceback import weakref +import numpy as np + from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -985,22 +986,20 @@ class TensorArray(object): Example 3: A simple loop interacting with a `tf.Variable`. - # TODO(b/153898334) reenable this one flakyness is removed - # >>> v = tf.Variable(1) - # >>> - # >>> @tf.function - # ... def f(x): - # ... ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True) - # ... - # ... for i in tf.range(x): - # ... v.assign_add(i) - # ... ta = ta.write(i, v) - # ... - # ... return ta.stack() - # >>> - # >>> f(5) - # <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 1, 2, 4, 7, 11], - # dtype=int32)> + # TODO(b/153898334): Convert back to doctest once bug is resolved. + ``` + v = tf.Variable(1) + @tf.function + def f(x): + ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True) + for i in tf.range(x): + v.assign_add(i) + ta = ta.write(i, v) + return ta.stack() + f(5) + <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 1, 2, 4, 7, 11], + dtype=int32)> + ``` """ def __init__(self, diff --git a/tensorflow/python/profiler/BUILD b/tensorflow/python/profiler/BUILD index b6565f594c9..ffc090a4676 100644 --- a/tensorflow/python/profiler/BUILD +++ b/tensorflow/python/profiler/BUILD @@ -224,10 +224,8 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/python:util", + "//tensorflow/python:tf_export", "//tensorflow/python/profiler/internal:_pywrap_traceme", - "//tensorflow/python/types", - "@six_archive//:six", ], ) @@ -239,13 +237,3 @@ py_library( ":trace", ], ) - -py_library( - name = "scoped_annotation", - srcs = ["scoped_annotation.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python/profiler/internal:_pywrap_scoped_annotation", - "@six_archive//:six", - ], -) diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD index d9f93c2fb21..6f7193b3207 100644 --- a/tensorflow/python/profiler/internal/BUILD +++ b/tensorflow/python/profiler/internal/BUILD @@ -80,28 +80,28 @@ cuda_py_test( tf_python_pybind_extension( name = "_pywrap_traceme", srcs = ["traceme_wrapper.cc"], - features = ["-layering_check"], module_name = "_pywrap_traceme", visibility = [ "//perftools/accelerators/xprof/xprofilez/integration_tests:__pkg__", "//tensorflow/python/profiler:__subpackages__", ], deps = [ - "//tensorflow/core:lib", - "//tensorflow/core/profiler/lib:traceme_headers", - "@com_google_absl//absl/types:optional", + ":traceme_wrapper", "@pybind11", ], ) -tf_python_pybind_extension( - name = "_pywrap_scoped_annotation", - srcs = ["scoped_annotation_wrapper.cc"], +cc_library( + name = "traceme_wrapper", + hdrs = ["traceme_wrapper.h"], features = ["-layering_check"], - module_name = "_pywrap_scoped_annotation", + visibility = [ + "//tensorflow/compiler/xla/python:__pkg__", + ], deps = [ "//tensorflow/core:lib", - "//tensorflow/core/profiler/lib:scoped_annotation_headers", + "//tensorflow/core/profiler/lib:traceme_headers", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@pybind11", ], diff --git a/tensorflow/python/profiler/internal/scoped_annotation_wrapper.cc b/tensorflow/python/profiler/internal/scoped_annotation_wrapper.cc deleted file mode 100644 index 078ebb0966c..00000000000 --- a/tensorflow/python/profiler/internal/scoped_annotation_wrapper.cc +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include <utility> - -#include "absl/types/optional.h" -#include "pybind11/pybind11.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/lib/scoped_annotation.h" - -namespace py = pybind11; - -namespace { - -// Helper to implement ScopedAnnotation as a context manager in Python. -class ScopedAnnotationWrapper { - public: - explicit ScopedAnnotationWrapper(const tensorflow::string& name) - : name_(name) {} - - void Enter() { annotation_.emplace(std::move(name_)); } - - void Exit() { annotation_.reset(); } - - static bool IsEnabled() { - return tensorflow::profiler::ScopedAnnotation::IsEnabled(); - } - - private: - tensorflow::string name_; - absl::optional<tensorflow::profiler::ScopedAnnotation> annotation_; -}; - -} // namespace - -PYBIND11_MODULE(_pywrap_scoped_annotation, m) { - py::class_<ScopedAnnotationWrapper> scoped_annotation_class( - m, "ScopedAnnotation"); - scoped_annotation_class.def(py::init<const tensorflow::string&>()) - .def("Enter", &ScopedAnnotationWrapper::Enter) - .def("Exit", &ScopedAnnotationWrapper::Exit) - .def_static("IsEnabled", &ScopedAnnotationWrapper::IsEnabled); -}; diff --git a/tensorflow/python/profiler/internal/traceme_wrapper.cc b/tensorflow/python/profiler/internal/traceme_wrapper.cc index a1b5370836b..32a1f423918 100644 --- a/tensorflow/python/profiler/internal/traceme_wrapper.cc +++ b/tensorflow/python/profiler/internal/traceme_wrapper.cc @@ -13,46 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include <utility> +#include "tensorflow/python/profiler/internal/traceme_wrapper.h" -#include "absl/types/optional.h" +#include "pybind11/attr.h" #include "pybind11/pybind11.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/lib/traceme.h" -namespace py = pybind11; +namespace py = ::pybind11; -namespace { - -// Helper to implement TraceMe as a context manager in Python. -class TraceMeWrapper { - public: - explicit TraceMeWrapper(const tensorflow::string& name) : name_(name) {} - - void Enter() { traceme_.emplace(std::move(name_)); } - - void SetMetadata(const tensorflow::string& new_metadata) { - if (TF_PREDICT_TRUE(traceme_)) { - traceme_->SetMetadata(new_metadata); - } - } - - void Exit() { traceme_.reset(); } - - static bool IsEnabled() { return tensorflow::profiler::TraceMe::Active(); } - - private: - tensorflow::string name_; - absl::optional<tensorflow::profiler::TraceMe> traceme_; -}; - -} // namespace +using ::tensorflow::profiler::TraceMeWrapper; PYBIND11_MODULE(_pywrap_traceme, m) { - py::class_<TraceMeWrapper> traceme_class(m, "TraceMe"); - traceme_class.def(py::init<const tensorflow::string&>()) - .def("Enter", &TraceMeWrapper::Enter) - .def("Exit", &TraceMeWrapper::Exit) + py::class_<TraceMeWrapper>(m, "TraceMe", py::module_local()) + .def(py::init<const py::str&, const py::kwargs&>()) .def("SetMetadata", &TraceMeWrapper::SetMetadata) .def_static("IsEnabled", &TraceMeWrapper::IsEnabled); }; diff --git a/tensorflow/python/profiler/internal/traceme_wrapper.h b/tensorflow/python/profiler/internal/traceme_wrapper.h new file mode 100644 index 00000000000..c074e909640 --- /dev/null +++ b/tensorflow/python/profiler/internal/traceme_wrapper.h @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_PYTHON_PROFILER_INTERNAL_TRACEME_WRAPPER_ +#define TENSORFLOW_PYTHON_PROFILER_INTERNAL_TRACEME_WRAPPER_ + +#include <string> +#include <utility> + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "pybind11/pytypes.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/lib/traceme.h" + +namespace tensorflow { +namespace profiler { + +// Wraps TraceMe with an interface that takes python types. +class TraceMeWrapper { + public: + // pybind11::str and pybind11::kwargs are taken by const reference to avoid + // python reference-counting overhead. + TraceMeWrapper(const pybind11::str& name, const pybind11::kwargs& kwargs) + : traceme_([&]() { + std::string name_and_metadata(name); + if (!kwargs.empty()) { + AppendMetadata(&name_and_metadata, kwargs); + } + return name_and_metadata; + }) {} + + // pybind11::kwargs is taken by const reference to avoid python + // reference-counting overhead. + void SetMetadata(const pybind11::kwargs& kwargs) { + if (TF_PREDICT_FALSE(!kwargs.empty())) { + traceme_.AppendMetadata([&]() { + std::string metadata; + AppendMetadata(&metadata, kwargs); + return metadata; + }); + } + } + + void Stop() { traceme_.Stop(); } + + static bool IsEnabled() { return tensorflow::profiler::TraceMe::Active(); } + + private: + // Converts kwargs to strings and appends them to name encoded as TraceMe + // metadata. + static void AppendMetadata(std::string* name, + const pybind11::kwargs& kwargs) { + name->push_back('#'); + for (const auto& kv : kwargs) { + absl::StrAppend(name, std::string(pybind11::str(kv.first)), "=", + std::string(pybind11::str(kv.second)), ","); + } + name->back() = '#'; + } + + tensorflow::profiler::TraceMe traceme_; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_PYTHON_PROFILER_INTERNAL_TRACEME_WRAPPER_ diff --git a/tensorflow/python/profiler/scoped_annotation.py b/tensorflow/python/profiler/scoped_annotation.py deleted file mode 100644 index 1d7e2b024b4..00000000000 --- a/tensorflow/python/profiler/scoped_annotation.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""ScopedAnnotation allows the profiler to annotate device (e.g., GPU) events. - -Usage: - with scoped_annotation.ScopedAnnotation('name'): - ... -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import six - -from tensorflow.python.profiler.internal import _pywrap_scoped_annotation - - -class ScopedAnnotation(object): - """Context manager that generates an annotation for the profiler.""" - - def __init__(self, name, **kwargs): - if _pywrap_scoped_annotation.ScopedAnnotation.IsEnabled(): - if kwargs: - name += '#' + ','.join(key + '=' + str(value) - for key, value in six.iteritems(kwargs)) + '#' - self._scoped_annotation = _pywrap_scoped_annotation.ScopedAnnotation(name) - else: - self._scoped_annotation = None - - def __enter__(self): - if self._scoped_annotation: - self._scoped_annotation.Enter() - - def __exit__(self, exc_type, exc_val, exc_tb): - if self._scoped_annotation: - self._scoped_annotation.Exit() diff --git a/tensorflow/python/profiler/trace.py b/tensorflow/python/profiler/trace.py index 424bdd6f3fc..ea4eb060488 100644 --- a/tensorflow/python/profiler/trace.py +++ b/tensorflow/python/profiler/trace.py @@ -18,29 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six - from tensorflow.python.profiler.internal import _pywrap_traceme from tensorflow.python.util.tf_export import tf_export -def encode_metadata(metadata): - """Encodes the given metadata to a string. - - Args: - metadata: in key-value pairs. - - Returns: - The encoded string. - """ - if not metadata: - return '' - content = [] - for key, value in six.iteritems(metadata): - content.append('%s=%s'%(key, value)) - return '#' + ','.join(content) + '#' - - @tf_export('profiler.experimental.Trace', v1=[]) class Trace(object): """Context manager that generates a trace event in the profiler. @@ -92,14 +73,13 @@ class Trace(object): training step being traced. """ if _pywrap_traceme.TraceMe.IsEnabled(): - name += encode_metadata(kwargs) - self._traceme = _pywrap_traceme.TraceMe(name) + # Creating _pywrap_traceme.TraceMe starts the clock. + self._traceme = _pywrap_traceme.TraceMe(name, **kwargs) else: self._traceme = None def __enter__(self): - if self._traceme: - self._traceme.Enter() + # Starting the TraceMe clock here would require an extra Python->C++ call. return self def set_metadata(self, **kwargs): @@ -134,9 +114,8 @@ class Trace(object): to measure the entire duration of call()). """ if self._traceme and kwargs: - additional_metadata = encode_metadata(kwargs) - self._traceme.SetMetadata(additional_metadata) + self._traceme.SetMetadata(**kwargs) def __exit__(self, exc_type, exc_val, exc_tb): - if self._traceme: - self._traceme.Exit() + # Deallocating _pywrap_traceme.TraceMe stops the clock. + self._traceme = None diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 2e5db7edd27..5c30d320fb7 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -2,6 +2,8 @@ # TensorFlow SavedModel. load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "tf_py_test") package( diff --git a/tensorflow/python/tf_program/tests/mlir_gen_test.py b/tensorflow/python/tf_program/tests/mlir_gen_test.py index 5e1ca5b36e0..49737352d73 100644 --- a/tensorflow/python/tf_program/tests/mlir_gen_test.py +++ b/tensorflow/python/tf_program/tests/mlir_gen_test.py @@ -83,7 +83,7 @@ class MLIRGenTest(MLIRGenTestBase): CHECK-LABEL: func @test_fn(%arg0: i1, %arg1: i1) -> i1 CHECK: %[[r0:[0-9]+]] = "tfp.And"(%arg0, %arg0, %arg1) : (i1, i1, i1) -> tensor<*xi1> CHECK: %[[r1:[0-9]+]] = "tfp.Or"(%arg0, %arg1, %[[r0]]) : (i1, i1, tensor<*xi1>) -> tensor<*xi1> - return %[[r1]] : tensor<*xi1> + CHECK: return %[[r1]] : tensor<*xi1> """ self._check_code(mlir_code, exp_mlir_code) @@ -158,7 +158,7 @@ class MLIRGenTest(MLIRGenTestBase): mlir_code = mlir_gen(test_fn) exp_mlir_code = r""" CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>) -> i32 - + CHECK: %[[r1:[0-9]+]] = "tf.Greater"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1> CHECK-NEXT: %[[r2:[0-9]+]] = "tfp.If"(%[[r1]]) ( { CHECK: return %{{[0-9]+}} : tensor<i32> @@ -222,7 +222,7 @@ class MLIRGenTest(MLIRGenTestBase): CHECK: %[[r5:[0-9]+]] = "tf.Equal"(%arg0, %{{[0-9]+}}) {incompatible_shape_error = true} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1> CHECK: %[[r7:[0-9]+]] = "tf.Equal"(%arg0, %{{[0-9]+}}) {incompatible_shape_error = true} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1> CHECK: %[[r8:[0-9]+]] = "tfp.Or"(%[[r5]], %[[r7]]) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1> - + CHECK: %[[r9:[0-9]+]]:4 = "tfp.If"(%[[r8]]) ( { CHECK-NEXT: return %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32> CHECK-NEXT: }, { diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 836cafbd494..efcd912f430 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -210,6 +210,22 @@ TFE_OutputTensorHandles InputTFE_OutputTensorHandles( return output_tensor_handles; } +// Packs multiple `EagerTensor`s of the same dtype and shape into one +// `EagerTensor`. +py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context, + const py::handle& tensors) { + TFE_Context* ctx = tensorflow::InputTFE_Context(context); + TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(tensors); + tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); + int size = handles.size(); + TFE_TensorHandle* packed_handle = + TFE_CreatePackedTensorHandle(ctx, handles.data(), &size, status.get()); + tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); + PyObject* packed_tensor = + EagerTensorFromHandle(packed_handle, /*is_packed=*/true); + return tensorflow::PyoOrThrow(packed_tensor); +} + // This function was created from fusing the typemap logic in platform/base.i. py::object TFE_Py_ExecuteCancelable_wrapper( const py::handle& context, const char* device_name, const char* op_name, @@ -558,6 +574,10 @@ PYBIND11_MODULE(_pywrap_tfe, m) { m.def("TFE_Py_InitEagerTensor", [](const py::handle& o) { return tensorflow::PyoOrThrow(TFE_Py_InitEagerTensor(o.ptr())); }); + m.def("TFE_Py_PackEagerTensors", + [](const py::handle& context, const py::handle& handles) { + return tensorflow::TFE_Py_PackEagerTensors_wrapper(context, handles); + }); m.def("TFE_Py_SetEagerTensorProfiler", &TFE_Py_SetEagerTensorProfiler); m.def("TFE_Py_RegisterJVPFunction", [](const py::handle& o) { return tensorflow::PyoOrThrow(TFE_Py_RegisterJVPFunction(o.ptr())); diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl index 13068a8090e..03120fb8dc4 100644 --- a/tensorflow/python/tools/api/generator/api_init_files.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files.bzl @@ -67,6 +67,7 @@ TENSORFLOW_API_INIT_FILES = [ "summary/experimental/__init__.py", "sysconfig/__init__.py", "test/__init__.py", + "tpu/experimental/embedding/__init__.py", "tpu/experimental/__init__.py", "tpu/__init__.py", "train/__init__.py", diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl index e5f0f46898f..a8154c6f35c 100644 --- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl @@ -85,6 +85,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [ "summary/__init__.py", "sysconfig/__init__.py", "test/__init__.py", + "tpu/experimental/embedding/__init__.py", "tpu/experimental/__init__.py", "tpu/__init__.py", "train/__init__.py", diff --git a/tensorflow/python/tools/saved_model_aot_compile.py b/tensorflow/python/tools/saved_model_aot_compile.py index a8694454ef2..5a34d10420a 100644 --- a/tensorflow/python/tools/saved_model_aot_compile.py +++ b/tensorflow/python/tools/saved_model_aot_compile.py @@ -215,6 +215,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path, signature_def_key, cpp_class, target_triple, + target_cpu, variables_to_feed=(), enable_multithreading=False): """Compile a `MetaGraphDef` to header+object files in `output_prefix`. @@ -239,6 +240,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path, signature_def_key: String, the signature_def to use in the SavedModel. cpp_class: String, Name of output C++ class. target_triple: String, LLVM target triple. + target_cpu: String, LLVM target cpu name. variables_to_feed: A list of strings, the variables that will be fed by the user; these won't be frozen. If `None`, then we will extract all the variables in the graph and mark them as to-feed. The default behavior is @@ -367,6 +369,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path, config=config_pbtxt_location, cpp_class=cpp_class, target_triple=target_triple, + target_cpu=target_cpu, entry_point='entry_{}'.format(entry_digest), out_function_object='{}.o'.format(output_prefix), out_header='{}.h'.format(output_prefix), diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 261ee1b9e9d..0f8f68436a3 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -825,6 +825,7 @@ def aot_compile_cpu(args): variables_to_feed=variables_to_feed, output_prefix=args.output_prefix, target_triple=args.target_triple, + target_cpu=args.target_cpu, cpp_class=args.cpp_class, enable_multithreading=args.enable_multithreading) @@ -1096,6 +1097,14 @@ def add_aot_compile_cpu_subparser(subparsers): 'x86_64-none-darwin, x86_64-apple-ios, arm64-none-ios, ' 'armv7-none-android. More examples are available in tfcompile.bzl ' 'in the tensorflow codebase.')) + parser_compile.add_argument( + '--target_cpu', + type=str, + default='', + help=('Target cpu name for LLVM during AOT compilation. Examples: ' + 'x86_64, skylake, haswell, westmere, <empty> (unknown). For ' + 'a complete list of options, run (for x86 targets): ' + '`llc -march=x86 -mcpu=help`')) parser_compile.add_argument( '--checkpoint_path', type=str, diff --git a/tensorflow/python/tools/tools.bzl b/tensorflow/python/tools/tools.bzl index c6853e1fc63..79f771bbcad 100644 --- a/tensorflow/python/tools/tools.bzl +++ b/tensorflow/python/tools/tools.bzl @@ -1,6 +1,7 @@ """Definitions for using tools like saved_model_cli.""" load("//tensorflow:tensorflow.bzl", "clean_dep", "if_xla_available") +load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu") load("//tensorflow/compiler/aot:tfcompile.bzl", "target_llvm_triple") def _maybe_force_compile(args, force_compile): @@ -19,6 +20,7 @@ def saved_model_compile_aot( signature_def = "serving_default", variables_to_feed = "", target_triple = None, + target_cpu = None, force_without_xla_support_flag = True, tags = None): """Compile a SavedModel directory accessible from a filegroup. @@ -88,7 +90,9 @@ def saved_model_compile_aot( uninitialized in the compiled object (this applies to all input arguments from the signature as well). target_triple: The LLVM target triple to use (defaults to current build's - target architecture's triple). + target architecture's triple). Similar to clang's -target flag. + target_cpu: The LLVM cpu name used for compilation. Similar to clang's + -mcpu flag. force_without_xla_support_flag: Whether to compile even when `--define=with_xla_support=true` is not set. If `False`, and the define is not passed when building, then the created `cc_library` @@ -100,6 +104,7 @@ def saved_model_compile_aot( """ saved_model = "{}/saved_model.pb".format(directory) target_triple = target_triple or target_llvm_triple() + target_cpu = target_cpu or tfcompile_target_cpu() or "" variables_to_feed = variables_to_feed or "''" if checkpoint_path: checkpoint_cmd_args = ( @@ -131,6 +136,7 @@ def saved_model_compile_aot( "--variables_to_feed {} ".format(variables_to_feed) + "--signature_def_key {} ".format(signature_def) + "--target_triple " + target_triple + " " + + ("--target_cpu " + target_cpu + " " if target_cpu else "") + "--tag_set {} ".format(tag_set) ), tags = tags, diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD index ebf0a4ffc57..d398396ec2a 100644 --- a/tensorflow/python/tpu/BUILD +++ b/tensorflow/python/tpu/BUILD @@ -179,6 +179,8 @@ py_library( ":feature_column_v2", ":preempted_hook_py", ":tpu_embedding", + ":tpu_embedding_v2", + ":tpu_embedding_v2_utils", ":tpu_lib", ], ) @@ -435,6 +437,37 @@ tf_py_test( ], ) +py_library( + name = "tpu_embedding_v2_utils", + srcs = ["tpu_embedding_v2_utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:variable_scope", + "//tensorflow/python/distribute:device_util", + "//tensorflow/python/distribute:sharded_variable", + "//tensorflow/python/tpu:tpu_lib", + "//tensorflow/python/tpu:tpu_py", + "//tensorflow/python/training/saving:saveable_hook", + "@six_archive//:six", + ], +) + +py_library( + name = "tpu_embedding_v2", + srcs = ["tpu_embedding_v2.py"], + srcs_version = "PY2AND3", + deps = [ + ":tpu_embedding_v2_utils", + "//tensorflow/python:variable_scope", + "//tensorflow/python/distribute:device_util", + "//tensorflow/python/distribute:sharded_variable", + "//tensorflow/python/tpu:tpu_lib", + "//tensorflow/python/tpu:tpu_py", + "//tensorflow/python/training/saving:saveable_hook", + "@six_archive//:six", + ], +) + tf_proto_library( name = "tensor_tracer_proto", srcs = ["tensor_tracer.proto"], diff --git a/tensorflow/python/tpu/api.py b/tensorflow/python/tpu/api.py index 7296de81dfe..a7db89ec0a5 100644 --- a/tensorflow/python/tpu/api.py +++ b/tensorflow/python/tpu/api.py @@ -27,5 +27,7 @@ from tensorflow.python.tpu import bfloat16 from tensorflow.python.tpu import feature_column_v2 from tensorflow.python.tpu import tpu from tensorflow.python.tpu import tpu_embedding +from tensorflow.python.tpu import tpu_embedding_v2 +from tensorflow.python.tpu import tpu_embedding_v2_utils from tensorflow.python.tpu import tpu_optimizer # pylint: enable=unused-import diff --git a/tensorflow/python/tpu/feature_column_v2.py b/tensorflow/python/tpu/feature_column_v2.py index d9820425467..1012506c48b 100644 --- a/tensorflow/python/tpu/feature_column_v2.py +++ b/tensorflow/python/tpu/feature_column_v2.py @@ -31,15 +31,18 @@ from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.tpu import tpu from tensorflow.python.tpu.feature_column import _is_running_on_cpu from tensorflow.python.tpu.feature_column import _record_variable_scope_and_name from tensorflow.python.tpu.feature_column import _SUPPORTED_CATEGORICAL_COLUMNS_V2 +from tensorflow.python.tpu.feature_column import _SUPPORTED_SEQUENCE_COLUMNS from tensorflow.python.tpu.feature_column import _TPUBaseEmbeddingColumn from tensorflow.python.util.tf_export import tf_export # pylint: disable=protected-access _ALLOWED_DEVICES = ['cpu', 'tpu_tensor_core', 'tpu_embedding_core'] +_TENSOR_CORE_MASK_KEY_SUFFIX = '__TENSOR_CORE_MASK' class EmbeddingDevice(enum.Enum): @@ -174,10 +177,13 @@ def embedding_column_v2(categorical_column, elif embedding_lookup_device == 'tpu_embedding_core': embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE - if (embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE and - not tensor_core_shape): - raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires ' - 'tensor_core_shape to be set.') + if embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE: + if not tensor_core_shape: + raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires ' + 'tensor_core_shape to be set.') + if isinstance(categorical_column, _SUPPORTED_SEQUENCE_COLUMNS): + raise ValueError('embedding_lookup_device=tpu_tensor_core currently does ' + 'not support sequence columns.') if not embedding_lookup_device: return _TPUEmbeddingColumnV2( @@ -372,10 +378,14 @@ def shared_embedding_columns_v2(categorical_columns, elif embedding_lookup_device == 'tpu_embedding_core': embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE - if (embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE and - not tensor_core_shape): - raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires ' - 'tensor_core_shape to be set.') + if embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE: + if not tensor_core_shape: + raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires ' + 'tensor_core_shape to be set.') + for c in sorted_columns: + if isinstance(c, _SUPPORTED_SEQUENCE_COLUMNS): + raise ValueError('embedding_lookup_device=tpu_tensor_core currently ' + 'does not support sequence columns.') # Create the state (_SharedEmbeddingColumnLayer) here. for categorical_column, max_sequence_length in zip( @@ -807,7 +817,13 @@ def sparse_embedding_aggregate_slice(params, if combiner == 'sum': return aggregate_emb elif combiner == 'mean': - return aggregate_emb / math_ops.reduce_sum(values_mask_broadcast, axis=1) + # In the case we have an empty row, both aggregate_emb and + # math_ops.reduce_sum(values_mask_broadcast, axis=1) will be 0. Thus, + # we can take max it with a non-zero value to prevent NaNs. Note that + # math_ops.reduce_sum(values_mask_broadcast, axis=1) will have integer + # values so 1.0 is the smallest value. + return aggregate_emb / math_ops.maximum( + math_ops.reduce_sum(values_mask_broadcast, axis=1), 1.0) else: raise ValueError('Dense TPU Embedding does not support combiner ' 'other than sum and mean.') @@ -851,6 +867,20 @@ def pad_sparse_embedding_lookup_indices(sparse_indices, padded_size): return padded_values, padded_mask +def _check_invalid_cases(embedding_lookup_device): + """Checks for invalid embedding_lookup_device configurations.""" + if (tpu.under_tpu_inference_context() and + embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE): + raise ValueError( + 'Using embedding_lookup_device=tpu_embedding_core during inference ' + 'is not supported.') + if embedding_lookup_device == EmbeddingDevice.CPU: + if not tpu.under_tpu_inference_context(): + raise ValueError( + 'Using TPUEmbeddingColumn with embedding_lookup_device="cpu" ' + 'during training is not supported.') + + class _TPUDeviceSpecificEmbeddingColumnV2(_TPUEmbeddingColumnV2): """TPUEmbeddingColumn which allows serving on TensorCore.""" @@ -874,46 +904,108 @@ class _TPUDeviceSpecificEmbeddingColumnV2(_TPUEmbeddingColumnV2): del kwargs['embedding_lookup_device'] _TPUEmbeddingColumnV2.__init__(self, *args, **kwargs) - def create_state(self, state_manager): - if (tpu.under_tpu_inference_context() and - self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE): - raise ValueError( - 'Using embedding_lookup_device=tpu_embedding_core during inference ' - 'is not supported.') - if self._embedding_lookup_device == EmbeddingDevice.CPU: - if tpu.under_tpu_inference_context(): - return fc_lib.EmbeddingColumn.create_state(self, state_manager) - else: - raise ValueError( - 'Using TPUEmbeddingColumn with embedding_lookup_device="cpu" ' - 'during training is not supported.') + def __deepcopy__(self, memo): + return _TPUDeviceSpecificEmbeddingColumnV2( + *(copy.deepcopy(a, memo) for a in self.__getnewargs__()), + tensor_core_shape=self._tensor_core_shape, + embedding_lookup_device=self._embedding_lookup_device) - return super(_TPUDeviceSpecificEmbeddingColumnV2, - self).create_state(state_manager) + def create_state(self, state_manager): + _check_invalid_cases(self._embedding_lookup_device) + # CPU case. + is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU + is_cpu = is_cpu or _is_running_on_cpu() + if is_cpu: + return fc_lib.EmbeddingColumn.create_state(self, state_manager) + # TPU_EMBEDDING_CORE case. + elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE: + return super(_TPUDeviceSpecificEmbeddingColumnV2, + self).create_state(state_manager) + + # TPU_EMBEDDING_CORE case. + return fc_lib.EmbeddingColumn.create_state(self, state_manager) def get_dense_tensor(self, transformation_cache, state_manager): """Private method that follows get_dense_tensor.""" - - # If we aren't inferencing on TensorCore, just delegate to parent. - if not tpu.under_tpu_inference_context() or not self._tensor_core_shape: + _check_invalid_cases(self._embedding_lookup_device) + # CPU Case. + is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU + is_cpu = is_cpu or _is_running_on_cpu() + if is_cpu: + return super(_TPUDeviceSpecificEmbeddingColumnV2, + self).get_dense_tensor(transformation_cache, state_manager) + # TPU_EMBEDDING_CORE case. + elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE: return super(_TPUDeviceSpecificEmbeddingColumnV2, self).get_dense_tensor(transformation_cache, state_manager) - sparse_tensor = transformation_cache.get(self.categorical_column.name, - state_manager) - # Use outside compile to densify and pad the input tensors. - def host_computation(): - return pad_sparse_embedding_lookup_indices(sparse_tensor, - self._tensor_core_shape[1]) + # TPU_EMBEDDING_CORE cases. + if tpu.under_tpu_inference_context(): + # For inference, use outside compile to densify and pad the input tensors. + sparse_tensor = transformation_cache.get(self.categorical_column.name, + state_manager) - values, mask = tpu.outside_compilation(host_computation) + def host_computation(): + return pad_sparse_embedding_lookup_indices(sparse_tensor, + self._tensor_core_shape[1]) - # Do a dense embedding lookup on TensorCore. - embedding_weights = state_manager.get_variable(self, 'embedding_weights') - embedding = sparse_embedding_aggregate_slice(embedding_weights, - (values, mask), - self.get_combiner()) - return embedding + values, mask = tpu.outside_compilation(host_computation) + else: + # For training, the inputs should already have been densified and padded. + values = transformation_cache.get(self.categorical_column.name, + state_manager) + mask = transformation_cache.get( + self.categorical_column.name + _TENSOR_CORE_MASK_KEY_SUFFIX, + state_manager) + embedding_weights = state_manager.get_variable( + self, name='embedding_weights') + return sparse_embedding_aggregate_slice(embedding_weights, (values, mask), + self.get_combiner()) + + def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): + _check_invalid_cases(self._embedding_lookup_device) + # CPU Case. + is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU + is_cpu = is_cpu or _is_running_on_cpu() + if is_cpu: + return super(_TPUDeviceSpecificEmbeddingColumnV2, + self)._get_dense_tensor(inputs, weight_collections, + trainable) + # TPU_EMBEDDING_CORE case. + elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE: + return super(_TPUDeviceSpecificEmbeddingColumnV2, + self)._get_dense_tensor(inputs, weight_collections, + trainable) + + # TPU_EMBEDDING_CORE cases. + if tpu.under_tpu_inference_context(): + # For inference, use outside compile to densify and pad the input tensors. + sparse_tensor = inputs.get(self.get_feature_key_name()) + + def host_computation(): + return pad_sparse_embedding_lookup_indices(sparse_tensor, + self._tensor_core_shape[1]) + + values, mask = tpu.outside_compilation(host_computation) + else: + # For training, the inputs should already have been densified and padded. + values = inputs.get(self.get_feature_key_name()) + mask = inputs.get(self.get_feature_key_name() + + _TENSOR_CORE_MASK_KEY_SUFFIX) + + embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access + if (weight_collections and + ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections): + weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES) + embedding_weights = variable_scope.get_variable( + name='embedding_weights', + shape=embedding_shape, + dtype=dtypes.float32, + initializer=self.initializer, + trainable=self.trainable and trainable, + collections=weight_collections) + return sparse_embedding_aggregate_slice(embedding_weights, (values, mask), + self.get_combiner()) class _TPUSharedDeviceSpecificEmbeddingColumnV2(_TPUSharedEmbeddingColumnV2): @@ -940,34 +1032,48 @@ class _TPUSharedDeviceSpecificEmbeddingColumnV2(_TPUSharedEmbeddingColumnV2): del kwargs['embedding_lookup_device'] _TPUSharedEmbeddingColumnV2.__init__(self, *args, **kwargs) + def __deepcopy__(self, memo): + return _TPUSharedDeviceSpecificEmbeddingColumnV2( + *(copy.deepcopy(a, memo) for a in self.__getnewargs__()), + tensor_core_shape=self._tensor_core_shape, + embedding_lookup_device=self._embedding_lookup_device) + def _get_dense_tensor_internal(self, transformation_cache, state_manager): """Private method that follows _get_dense_tensor_internal.""" - if (tpu.under_tpu_inference_context() and - self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE): - raise ValueError('Using embedding_lookup_device=tpu_embedding_core ' - 'during inference is not supported.') - if self._embedding_lookup_device == EmbeddingDevice.CPU: - if tpu.under_tpu_inference_context(): - return super(_TPUSharedDeviceSpecificEmbeddingColumnV2, - self)._get_dense_tensor_internal(transformation_cache, - state_manager) - else: - raise ValueError( - 'Using TPUSharedEmbeddingColumn with ' - 'embedding_lookup_device="cpu" during training is not supported.') - sparse_tensor = transformation_cache.get(self.categorical_column.name, - state_manager) + _check_invalid_cases(self._embedding_lookup_device) + # CPU Case. + is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU + is_cpu = is_cpu or _is_running_on_cpu() + if is_cpu: + return super(_TPUSharedDeviceSpecificEmbeddingColumnV2, + self)._get_dense_tensor_internal(transformation_cache, + state_manager) + # TPU_EMBEDDING_CORE case. + if self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE: + return super(_TPUSharedDeviceSpecificEmbeddingColumnV2, + self)._get_dense_tensor_internal(transformation_cache, + state_manager) - # Use outside compile to densify and pad the input tensors. - def host_computation(): - return pad_sparse_embedding_lookup_indices(sparse_tensor, - self._tensor_core_shape[1]) + # TPU_EMBEDDING_CORE cases. + if tpu.under_tpu_inference_context(): + # For inference, use outside compile to densify and pad the input tensors. + sparse_tensor = transformation_cache.get(self.categorical_column.name, + state_manager) - values, mask = tpu.outside_compilation(host_computation) + def host_computation(): + return pad_sparse_embedding_lookup_indices(sparse_tensor, + self._tensor_core_shape[1]) + + values, mask = tpu.outside_compilation(host_computation) + else: + # For training, the inputs should already have been densified and padded. + values = transformation_cache.get(self.categorical_column.name, + state_manager) + mask = transformation_cache.get( + self.categorical_column.name + _TENSOR_CORE_MASK_KEY_SUFFIX, + state_manager) # Do a dense embedding lookup on TensorCore. embedding_weights = self.shared_embedding_column_creator.embedding_weights - embedding = sparse_embedding_aggregate_slice(embedding_weights, - (values, mask), - self.get_combiner()) - return embedding + return sparse_embedding_aggregate_slice(embedding_weights, (values, mask), + self.get_combiner()) diff --git a/tensorflow/python/tpu/feature_column_v2_test.py b/tensorflow/python/tpu/feature_column_v2_test.py index 282d176b301..932fe4e5a0a 100644 --- a/tensorflow/python/tpu/feature_column_v2_test.py +++ b/tensorflow/python/tpu/feature_column_v2_test.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test from tensorflow.python.tpu import feature_column_v2 as tpu_fc from tensorflow.python.tpu import tpu +from tensorflow.python.tpu import tpu_function def _initialized_session(): @@ -514,50 +515,119 @@ class DeviceSpecificEmbeddingColumnTestV2(test.TestCase, embedding_lookup_device='tpu_tensor_core', tensor_core_shape=[None, 3]) - # Run in TPUInferenceContext so that we hit the intended densification case. + # Run in TPUContexts so that we hit the intended densification case. context = tpu._TPUInferenceContext('tpu_inference') context.Enter() + with tpu_function.tpu_shard_context(1): + dense_features = fc_lib.DenseFeatures(embedding_column) + # Sqrtn combiner not supported for now. + if combiner == 'sqrtn': + with self.assertRaisesRegexp( + ValueError, 'Dense TPU Embedding does not support combiner'): + embedding_lookup = dense_features(input_features) + return + if combiner == 'mean': + expected_lookups = ( + # example 0: + (7., 11.), # ids [2], embedding = [7, 11] + # example 1: + (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = + # [2, 3.5] + ) + elif combiner == 'sum': + expected_lookups = ( + # example 0: + (7., 11.), # ids [2], embedding = [7, 11] + # example 1: + (4., 7), # ids [0, 1], embedding = sum([1, 2] + [3, 5]) = [4, 7] + ) - dense_features = fc_lib.DenseFeatures(embedding_column) - # Sqrtn combiner not supported for now. - if combiner == 'sqrtn': - with self.assertRaisesRegexp( - ValueError, 'Dense TPU Embedding does not support combiner'): - embedding_lookup = dense_features(input_features) - return - if combiner == 'mean': + embedding_lookup = dense_features(input_features) + + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + if shared: + self.assertCountEqual(('inp_shared_embedding:0',), + tuple([v.name for v in global_vars])) + else: + self.assertCountEqual( + ('dense_features/inp_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) + + embedding_var = global_vars[0] + with _initialized_session(): + self.assertAllEqual(embedding_values, embedding_var.eval()) + eval_res = embedding_lookup.eval() + self.assertAllEqual(expected_lookups, eval_res) + context.Exit() + + @test_util.deprecated_graph_mode_only + def test_empty_row(self): + # Inputs. + vocabulary_size = 3 + input_sparse_tensor = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [0, 1, 3] + indices=((1, 0), (1, 1), (1, 4)), + values=(0, 1, 3), + dense_shape=(2, 5)) + input_features = {'inp': input_sparse_tensor} + + # Embedding variable. + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.), # id 2 + (13., 17.) # id 3 + ) + + def _initializer(shape, dtype, partition_info=None): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + # Build columns. + categorical_column_input = fc_lib.categorical_column_with_identity( + key='inp', num_buckets=vocabulary_size) + + # Set tensor_core_shape to be [None, 20] to ensure some padding and + # dynamic batch size. + embedding_column = tpu_fc.embedding_column_v2( + categorical_column_input, + dimension=embedding_dimension, + initializer=_initializer, + combiner='mean', + embedding_lookup_device='tpu_tensor_core', + tensor_core_shape=[None, 3]) + + # Run in TPUContexts so that we hit the intended densification case. + context = tpu._TPUInferenceContext('tpu_inference') + context.Enter() + with tpu_function.tpu_shard_context(1): + dense_features = fc_lib.DenseFeatures(embedding_column) expected_lookups = ( # example 0: - (7., 11.), # ids [2], embedding = [7, 11] + (0., 0.), # ids [], embedding = [0, 0] # example 1: (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5] ) - elif combiner == 'sum': - expected_lookups = ( - # example 0: - (7., 11.), # ids [2], embedding = [7, 11] - # example 1: - (4., 7), # ids [0, 1], embedding = sum([1, 2] + [3, 5]) = [4, 7] - ) - embedding_lookup = dense_features(input_features) + embedding_lookup = dense_features(input_features) - # Assert expected embedding variable and lookups. - global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - if shared: - self.assertCountEqual(('inp_shared_embedding:0',), - tuple([v.name for v in global_vars])) - else: + # Assert expected embedding variable and lookups. + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertCountEqual( ('dense_features/inp_embedding/embedding_weights:0',), tuple([v.name for v in global_vars])) - embedding_var = global_vars[0] - with _initialized_session(): - self.assertAllEqual(embedding_values, embedding_var.eval()) - eval_res = embedding_lookup.eval() - self.assertAllEqual(expected_lookups, eval_res) - context.Exit() + embedding_var = global_vars[0] + with _initialized_session(): + self.assertAllEqual(embedding_values, embedding_var.eval()) + eval_res = embedding_lookup.eval() + self.assertAllEqual(expected_lookups, eval_res) + context.Exit() @test_util.deprecated_graph_mode_only def test_error_dense_shape_invalid(self): diff --git a/tensorflow/python/tpu/profiler/BUILD b/tensorflow/python/tpu/profiler/BUILD index b505262c6a2..84ffb4234c0 100644 --- a/tensorflow/python/tpu/profiler/BUILD +++ b/tensorflow/python/tpu/profiler/BUILD @@ -38,7 +38,8 @@ py_library( "//tensorflow/python:platform", "//tensorflow/python:versions", "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", - "//tensorflow/python/eager:profiler_client", + "//tensorflow/python/profiler:profiler_client", + "//tensorflow/python/profiler:profiler_v2", "@absl_py//absl:app", "@absl_py//absl/flags", ], diff --git a/tensorflow/python/tpu/profiler/capture_tpu_profile.py b/tensorflow/python/tpu/profiler/capture_tpu_profile.py index f0d22027e4e..0068dc402c0 100644 --- a/tensorflow/python/tpu/profiler/capture_tpu_profile.py +++ b/tensorflow/python/tpu/profiler/capture_tpu_profile.py @@ -25,7 +25,8 @@ from absl import flags from distutils.version import LooseVersion from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver -from tensorflow.python.eager import profiler_client +from tensorflow.python.profiler import profiler_client +from tensorflow.python.profiler import profiler_v2 as profiler from tensorflow.python.framework import errors from tensorflow.python.framework import versions from tensorflow.python.platform import gfile @@ -65,9 +66,10 @@ flags.DEFINE_integer('duration_ms', 0, flags.DEFINE_integer( 'num_tracing_attempts', 3, 'Automatically retry N times when no trace ' 'event is collected.') -flags.DEFINE_boolean('include_dataset_ops', True, - 'Set to false to profile longer TPU ' - 'device traces.') +flags.DEFINE_boolean('include_dataset_ops', True, 'Deprecated.') +flags.DEFINE_integer( + 'host_tracer_level', 2, 'Adjust host tracer level to control the verbosity ' + ' of the TraceMe event being collected.') # Monitoring parameters flags.DEFINE_integer( @@ -77,8 +79,7 @@ flags.DEFINE_integer( flags.DEFINE_integer( 'num_queries', 100, 'This script will run monitoring for num_queries before it stops.') -flags.DEFINE_boolean('display_timestamp', False, - 'Set to true to display timestamp in monitoring results.') +flags.DEFINE_boolean('display_timestamp', True, 'Deprecated.') def get_workers_list(cluster_resolver): @@ -111,8 +112,7 @@ def get_workers_list(cluster_resolver): return ','.join(workers_list) -def monitoring_helper(service_addr, duration_ms, monitoring_level, - display_timestamp, num_queries): +def monitoring_helper(service_addr, duration_ms, monitoring_level, num_queries): """Helper function to print monitoring results. Helper function to print monitoring results for num_queries times. @@ -122,15 +122,13 @@ def monitoring_helper(service_addr, duration_ms, monitoring_level, duration_ms: Duration of one monitoring sample in milliseconds. monitoring_level: An integer between 1 and 2. Level 2 is more verbose than level 1 and shows more metrics. - display_timestamp: Set to true to display timestamp in monitoring. num_queries: Number of monitoring samples to collect. """ if monitoring_level <= 0 or monitoring_level > 2: sys.exit('Please choose a monitoring level between 1 and 2.') for query in range(0, num_queries): - res = profiler_client.monitor(service_addr, duration_ms, monitoring_level, - display_timestamp) + res = profiler_client.monitor(service_addr, duration_ms, monitoring_level) print('Cloud TPU Monitoring Results (Sample ', query, '):\n\n', res) @@ -144,8 +142,8 @@ def main(unused_argv=None): print('TensorFlow version %s detected' % tf_version) print('Welcome to the Cloud TPU Profiler v%s' % profiler_version.__version__) - if LooseVersion(tf_version) < LooseVersion('1.14.0'): - sys.exit('You must install tensorflow >= 1.14.0 to use this plugin.') + if LooseVersion(tf_version) < LooseVersion('2.2.0'): + sys.exit('You must install tensorflow >= 2.2.0 to use this plugin.') if not FLAGS.service_addr and not FLAGS.tpu: sys.exit('You must specify either --service_addr or --tpu.') @@ -184,7 +182,7 @@ def main(unused_argv=None): FLAGS.duration_ms, ' ms and show metrics for ', FLAGS.num_queries, ' time(s).') monitoring_helper(service_addr, duration_ms, FLAGS.monitoring_level, - FLAGS.display_timestamp, FLAGS.num_queries) + FLAGS.num_queries) else: if not FLAGS.logdir: sys.exit('You must specify either --logdir or --monitoring_level.') @@ -193,11 +191,16 @@ def main(unused_argv=None): gfile.MakeDirs(FLAGS.logdir) try: - profiler_client.start_tracing(service_addr, - os.path.expanduser(FLAGS.logdir), - duration_ms, workers_list, - FLAGS.include_dataset_ops, - FLAGS.num_tracing_attempts) + if LooseVersion(tf_version) < LooseVersion('2.3.0'): + profiler_client.trace(service_addr, os.path.expanduser(FLAGS.logdir), + duration_ms, workers_list, + FLAGS.num_tracing_attempts) + else: + options = profiler.ProfilerOptions( + host_tracer_level=FLAGS.host_tracer_level) + profiler_client.trace(service_addr, os.path.expanduser(FLAGS.logdir), + duration_ms, workers_list, + FLAGS.num_tracing_attempts, options) except errors.UnavailableError: sys.exit(0) diff --git a/tensorflow/python/tpu/tensor_tracer.proto b/tensorflow/python/tpu/tensor_tracer.proto index ad5392d65fe..7b745f0f45b 100644 --- a/tensorflow/python/tpu/tensor_tracer.proto +++ b/tensorflow/python/tpu/tensor_tracer.proto @@ -21,6 +21,10 @@ message TensorTracerReport { // A map from tensor name to its TracedTensorDef. map<string, TracedTensorDef> tensordef = 3; + // The fingerprint of the TensorTracerReport (fingerprint calculation excludes + // this field and graphdef). + string fingerprint = 4; + message TensorTracerConfig { // Tensor tracer version, e.g. hostcall, outside compilation. string version = 1; diff --git a/tensorflow/python/tpu/tensor_tracer.py b/tensorflow/python/tpu/tensor_tracer.py index bd96de42f3a..b4f99897094 100644 --- a/tensorflow/python/tpu/tensor_tracer.py +++ b/tensorflow/python/tpu/tensor_tracer.py @@ -100,7 +100,7 @@ _TT_TENSORBOARD_PLUGIN_NAME = 'tensor_tracer' _TT_HOSTCALL_KEY = 'tensor_tracer_host_call' _TT_EVENT_FILE_SUFFIX = '.tensor_tracer' -_TT_SUMMARY_MAX_QUEUE = 100 +_TT_SUMMARY_MAX_QUEUE = 10 def set_parameters(tensor_tracer_params=None): @@ -206,6 +206,9 @@ def set_parameters(tensor_tracer_params=None): -> op2 -> op1 -> op0, if op0 has a NaN and trace_stack_size is 1, the result of op1 will also be printed. trace_stack_size is 2, the result of op1 and op2 will be printed. + - use_fingerprint_subdirectory: The trace directory will be chosen as + using the fingerprint of the trace metadata under the provided + trace_dir. """ flags = '--%s=1' % tensor_tracer_flags.FLAG_NAME_ENABLE if tensor_tracer_params: @@ -547,6 +550,7 @@ class TensorTracer(object): self._traced_op_names = set() self._report_proto = None self._temp_cache_var = [] + self._report_proto_path = '' def report_proto(self): """Getter for tensor_tracer.proto object for summary and full_tensor_summary modes. @@ -564,6 +568,14 @@ class TensorTracer(object): 'Report proto only exists for ' 'trace_mode=[summary|full_tensor_summary]') + def report_proto_path(self): + """Getter for path where tensor_tracer.proto object should be written. + + Returns: + A string path. + """ + return self._report_proto_path + def _get_all_cache_variables(self): return self._cache_variables @@ -1366,6 +1378,13 @@ class TensorTracer(object): self._report_proto = report_handler.create_report_proto( self._tt_config, self._parameters, tensor_trace_order, tensor_trace_points, self._signature_types()) + if self._parameters.use_fingerprint_subdir: + self._parameters.trace_dir = os.path.join( + self._parameters.trace_dir, self._report_proto.fingerprint) + logging.info('TensorTracer updating trace_dir to %s', + self._parameters.trace_dir) + self._report_proto_path = tensor_tracer_report.report_proto_path( + self._parameters.trace_dir) if self._parameters.report_file_path != _SKIP_REPORT_FILE: report_handler.write_report_proto(self._report_proto, self._parameters) else: diff --git a/tensorflow/python/tpu/tensor_tracer_flags.py b/tensorflow/python/tpu/tensor_tracer_flags.py index c5e3e88597b..4e412c46e82 100644 --- a/tensorflow/python/tpu/tensor_tracer_flags.py +++ b/tensorflow/python/tpu/tensor_tracer_flags.py @@ -74,6 +74,7 @@ FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs' FLAG_NAME_SUMMARY_SIGNATURES = 'signatures' FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core' FLAG_NAME_TEMP_CACHE_VAR = 'use_temp_cache' +FLAG_NAME_FINGERPRINT_DIR = 'use_fingerprint_subdirectory' _OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') _TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR' @@ -127,6 +128,7 @@ class TTParameters(object): self.trace_scalar_ops = self.is_flag_on(FLAG_NAME_TRACE_SCALAR_OPS) self.use_compact_trace = self.is_flag_on(FLAG_NAME_USE_COMPACT_TRACE) self.use_temp_cache_var = self.is_flag_on(FLAG_NAME_TEMP_CACHE_VAR) + self.use_fingerprint_subdir = self.is_flag_on(FLAG_NAME_FINGERPRINT_DIR) # _trace_ops_before_included and _trace_ops_after_included denotes to depth # of tracing relative to the ops given in --included_opnames or @@ -274,7 +276,7 @@ class TTParameters(object): FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, FLAG_NAME_OP_RANGE, FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS, FLAG_NAME_TRACE_LEVEL, FLAG_NAME_SUMMARY_SIGNATURES, FLAG_NAME_SUMMARY_PER_CORE, - FLAG_NAME_TEMP_CACHE_VAR + FLAG_NAME_TEMP_CACHE_VAR, FLAG_NAME_FINGERPRINT_DIR ] tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR) if not tensor_tracer_flags: diff --git a/tensorflow/python/tpu/tensor_tracer_report.py b/tensorflow/python/tpu/tensor_tracer_report.py index e8a122d981f..3270b2a2fd3 100644 --- a/tensorflow/python/tpu/tensor_tracer_report.py +++ b/tensorflow/python/tpu/tensor_tracer_report.py @@ -19,8 +19,10 @@ from __future__ import division from __future__ import print_function import collections +import hashlib import os + from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.tpu import tensor_tracer_pb2 @@ -53,6 +55,18 @@ _CURRENT_VERSION = 'use-outside-compilation' _TT_REPORT_PROTO = 'tensor_tracer_report.report_pb' +def report_proto_path(trace_dir): + """Returns the path where report proto should be written. + + Args: + trace_dir: String denoting the trace directory. + + Returns: + A string denoting the path to the report proto. + """ + return os.path.join(trace_dir, _TT_REPORT_PROTO) + + def topological_sort(g): """Performs topological sort on the given graph. @@ -206,6 +220,12 @@ class OpenReportFile(object): self._report_file.close() +def proto_fingerprint(message_proto): + serialized_message = message_proto.SerializeToString() + hasher = hashlib.sha256(serialized_message) + return hasher.hexdigest() + + class TTReportHandle(object): """Utility class responsible from creating a tensor tracer report.""" @@ -255,8 +275,6 @@ class TTReportHandle(object): key=lambda x: x[1]): report.config.signatures.append(signature_name) - tf_graph = tensor_trace_order.graph_order.graph - report.graphdef.CopyFrom(tf_graph.as_graph_def()) for tensor in tensor_trace_order.graph_order.tensors: tensor_def = tensor_tracer_pb2.TensorTracerReport.TracedTensorDef() tensor_def.name = tensor.name @@ -265,6 +283,11 @@ class TTReportHandle(object): tensor_def.cache_index = ( tensor_trace_order.tensorname_to_cache_idx[tensor.name]) else: + # To prevent small changes affecting the fingerprint calculation, avoid + # writing the untraced tensors to metadata. Fingerprints will be + # different only when the list of the traced tensors are different. + if tt_parameters.use_fingerprint_subdir: + continue tensor_def.is_traced = False if tensor.name in tensor_trace_points: @@ -274,12 +297,17 @@ class TTReportHandle(object): elif tensor.op.name in self.instrument_records: tensor_def.explanation = self.instrument_records[tensor.op.name] report.tensordef[tensor.name].CopyFrom(tensor_def) + report.fingerprint = proto_fingerprint(report) + logging.info('TensorTracerProto fingerprint is %s.', + report.fingerprint) + tf_graph = tensor_trace_order.graph_order.graph + report.graphdef.CopyFrom(tf_graph.as_graph_def()) return report def write_report_proto(self, report_proto, tt_parameters): """Writes the given report proto under trace_dir.""" gfile.MakeDirs(tt_parameters.trace_dir) - report_path = os.path.join(tt_parameters.trace_dir, _TT_REPORT_PROTO) + report_path = report_proto_path(tt_parameters.trace_dir) with gfile.GFile(report_path, 'wb') as f: f.write(report_proto.SerializeToString()) diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py new file mode 100644 index 00000000000..5a66f6ce8b9 --- /dev/null +++ b/tensorflow/python/tpu/tpu_embedding_v2.py @@ -0,0 +1,1321 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mid level API for TPU Embeddings.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import functools +from absl import logging + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 +from tensorflow.python.distribute import device_util +from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.distribute import sharded_variable +from tensorflow.python.distribute import tpu_strategy +from tensorflow.python.distribute import values as tf_values +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.tpu import tpu +from tensorflow.python.tpu import tpu_embedding_v2_utils +from tensorflow.python.tpu.ops import tpu_ops +from tensorflow.python.training.saving import saveable_hook +from tensorflow.python.training.tracking import tracking +from tensorflow.python.util import compat +from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export + + +_HOOK_KEY = "TPUEmbedding_saveable" +_NAME_KEY = "_tpu_embedding_layer" + + +# TODO(bfontain): Cleanup and remove this once there is an implementation of +# sharded variables that can be used in the PSStrategy with optimizers. +# We implement just enough of the of a tf.Variable so that this could be passed +# to an optimizer. +class TPUShardedVariable(sharded_variable.ShardedVariable): + """A ShardedVariable class for TPU.""" + + @property + def _in_graph_mode(self): + return self.variables[0]._in_graph_mode # pylint: disable=protected-access + + @property + def _unique_id(self): + return self.variables[0]._unique_id # pylint: disable=protected-access + + @property + def _distribute_strategy(self): + return self.variables[0]._distribute_strategy # pylint: disable=protected-access + + @property + def _shared_name(self): + return self._name + + +def _add_key_attr(op, name): + op._set_attr(_NAME_KEY, attr_value_pb2.AttrValue(s=compat.as_bytes(name))) # pylint: disable=protected-access + + +@tf_export("tpu.experimental.embedding.TPUEmbedding") +class TPUEmbedding(tracking.AutoTrackable): + """The TPUEmbedding mid level API. + + NOTE: When instantiated under a TPUStrategy, this class can only be created + once per call to `tf.tpu.experimental.initialize_tpu_system`. If you wish to + re-initialize the embedding engine you must re-initialize the tpu as well. + Doing this will clear any variables from TPU, so ensure you have checkpointed + before you do this. If a further instances of the class are needed, + set the `initialize_tpu_embedding` argument to `False`. + + This class can be used to support training large embeddings on TPU. When + creating an instance of this class, you must specify the complete set of + tables and features you expect to lookup in those tables. See the + documentation of `tf.tpu.experimental.embedding.TableConfig` and + `tf.tpu.experimental.embedding.FeatureConfig` for more details on the complete + set of options. We will cover the basic usage here. + + NOTE: multiple `FeatureConfig` objects can use the same `TableConfig` object, + allowing different features to share the same table: + + ```python + table_config_one = tf.tpu.experimental.embedding.TableConfig( + vocabulary_size=..., + dim=...) + table_config_two = tf.tpu.experimental.embedding.TableConfig( + vocabulary_size=..., + dim=...) + feature_config = { + 'feature_one': tf.tpu.experimental.embedding.FeatureConfig( + table=table_config_one), + 'feature_two': tf.tpu.experimental.embedding.FeatureConfig( + table=table_config_one), + 'feature_three': tf.tpu.experimental.embedding.FeatureConfig( + table=table_config_two)} + ``` + + There are two modes under which the `TPUEmbedding` class can used. This + depends on if the class was created under a `TPUStrategy` scope or not. + + Under `TPUStrategy`, we allow access to the method `enqueue`, `dequeue` and + `apply_gradients`. We will show examples below of how to use these to train + and evaluate your model. Under CPU, we only access to the `embedding_tables` + property which allow access to the embedding tables so that you can use them + to run model evaluation/prediction on CPU. + + First lets look at the `TPUStrategy` mode. Initial setup looks like: + + ```python + strategy = tf.distribute.experimental.TPUStrategy(...) + with strategy.scope(): + embedding = tf.tpu.experimental.embedding.TPUEmbedding( + feature_config=feature_config, + batch_size=1024, + optimizer=tf.tpu.experimental.embedding.SGD(0.1)) + ``` + + To use this API on TPU you should use a custom training loop. Below is an + example of a training and evaluation step: + + ```python + @tf.function + def training_step(dataset_iterator, num_steps): + def tpu_step(tpu_features): + with tf.GradientTape() as tape: + activations = embedding.dequeue() + tape.watch(activations) + model_output = model(activations) + loss = ... # some function of labels and model_output + + embedding_gradients = tape.gradient(loss, activations) + embedding.apply_gradients(embedding_gradients) + # Insert your model gradient and optimizer application here + + for _ in tf.range(num_steps): + embedding_features, tpu_features = next(dataset_iterator) + embedding.enqueue(embedding_features, training=True) + strategy.run(tpu_step, args=(embedding_features, )) + + @tf.function + def evalution_step(dataset_iterator, num_steps): + def tpu_step(tpu_features): + activations = embedding.dequeue() + model_output = model(activations) + # Insert your evaluation code here. + + for _ in tf.range(num_steps): + embedding_features, tpu_features = next(dataset_iterator) + embedding.enqueue(embedding_features, training=False) + strategy.run(tpu_step, args=(embedding_features, )) + ``` + + NOTE: The calls to `enqueue` have `training` set to `True` when + `embedding.apply_gradients` is used and set to `False` when + `embedding.apply_gradients` is not present in the function. If you don't + follow this pattern you may cause an error to be raised or the tpu may + deadlock. + + In the above examples, we assume that the user has a dataset which returns + a tuple where the first element of the tuple matches the structure of what + was passed as the `feature_config` argument to the object initializer. Also we + utilize `tf.range` to get a `tf.while_loop` in order to increase performance. + + When checkpointing your model, you should include your + `tf.tpu.experimental.embedding.TPUEmbedding` object in the checkpoint. It is a + trackable object and saving it will save the embedding tables and their + optimizer slot variables: + + ```python + checkpoint = tf.train.Checkpoint(model=model, embedding=embedding) + checkpoint.save(...) + ``` + + On CPU, only the `embedding_table` property is usable. This will allow you to + restore a checkpoint to the object and have access to the table variables: + + ```python + model = model_fn(...) + embedding = tf.tpu.experimental.embedding.TPUEmbedding( + feature_config=feature_config, + batch_size=1024, + optimizer=tf.tpu.experimental.embedding.SGD(0.1)) + checkpoint = tf.train.Checkpoint(model=model, embedding=embedding) + checkpoint.restore(...) + + tables = embedding.embedding_tables + ``` + + You can now use table in functions like `tf.nn.embedding_lookup` to perform + your embedding lookup and pass to your model. + + """ + + def __init__(self, feature_config, batch_size, optimizer, + pipeline_execution_with_tensor_core=False, + initialize_tpu_embedding=True): + """Creates the TPUEmbedding mid level API object. + + ```python + strategy = tf.distribute.experimental.TPUStrategy(...) + with strategy.scope(): + embedding = tf.tpu.experimental.embedding.TPUEmbedding( + feature_config=tf.tpu.experimental.embedding.FeatureConfig( + table=tf.tpu.experimental.embedding.TableConfig( + dim=..., + vocabulary_size=...))) + ``` + + Args: + feature_config: A nested structure of + `tf.tpu.experimental.embedding.FeatureConfig` configs. + batch_size: The global batch size that you indend to use. Note that is + fixed and the same batch size must be used for both training and + evaluation. + optimizer: An instance of one of `tf.tpu.experimental.embedding.SGD`, + `tf.tpu.experimental.embedding.Adagrad` or + `tf.tpu.experimental.embedding.Adam`. + pipeline_execution_with_tensor_core: If True, the TPU embedding + computations will overlap with the TensorCore computations (and hence + will be one step old). Set to True for improved performance. + initialize_tpu_embedding: If False, will not initialize the TPU embedding + engine. If this is set to False and another instance of this class has + not initialized the tpu embedding engine, the creation of this object + will fail. + + Raises: + ValueError: If optimizer is not one of tf.tpu.experimental.embedding.(SGD, + Adam or Adagrad). + """ + self._strategy = distribution_strategy_context.get_strategy() + self._using_tpu = isinstance(self._strategy, tpu_strategy.TPUStrategy) + self._pipeline_execution_with_tensor_core = ( + pipeline_execution_with_tensor_core) + + self._feature_config = feature_config + + # The TPU embedding ops are slightly inconsistent with how they refer to + # tables: + # * The enqueue op takes a parallel list of tensors for input, one of those + # is the table id for the feature which matches the integer index of the + # table in the proto created by _create_config_proto(). + # * The recv_tpu_embedding_activations op emits lookups per table in the + # order from the config proto. + # * The send_tpu_embedding_gradients expects input tensors to be per table + # in the same order as the config proto. + # * Per optimizer load and retrieve ops are specified per table and take the + # table name rather than the table id. + # Thus we must fix a common order to tables and ensure they have unique + # names. + + # Set table order here + self._table_config = list( + {feature.table for feature in nest.flatten(feature_config)}) + + # Ensure tables have unique names. Also error check the optimizer as we + # specifically don't do that in the TableConfig class to allow high level + # APIs that are built on this to use strings/other classes to represent + # optimizers (before they are passed to this class). + table_names = [] + for i, table in enumerate(self._table_config): + if table.optimizer is None: + # TODO(bfontain) Should we allow some sort of optimizer merging here? + table.optimizer = optimizer + if not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer): # pylint: disable=protected-access + raise ValueError("{} is an unsupported optimizer class. Please pass an " + "instance of one of the optimizer classes under " + "tf.tpu.experimental.embedding.".format( + type(table.optimizer))) + if table.name is None: + table.name = "table_{}".format(i) + if table.name in table_names: + raise ValueError("Multiple tables with name {} found.".format( + table.name)) + table_names.append(table.name) + + if self._using_tpu: + # Extract a list of callable learning rates also in fixed order. Each + # table in the confix proto will get a index into this list and we will + # pass this list in the same order after evaluation to the + # send_tpu_embedding_gradients op. + self._dynamic_learning_rates = list({ + table.optimizer.learning_rate for table in self._table_config if + callable(table.optimizer.learning_rate)}) + + # We need to list of host devices for the load/retrieve operations. + self._hosts = get_list_of_hosts(self._strategy) + + # TODO(bfontain) Remove this once we have an official way of splitting + # prefetch between host and device. + self._strategy.extended._set_prefetch_on_host(True) # pylint: disable=protected-access + + # We generally use the per core batch size, but will have the user pass + # in a global batch size. + self._batch_size = batch_size // self._strategy.num_replicas_in_sync + + self._config_proto = self._create_config_proto() + if initialize_tpu_embedding: + # This is mainly for testing purposes, sometimes we don't want to + # initialize the embedding engine, but just want a copy of the API + # which can interact with an already initialized engine. + logging.info("Initializing TPU Embedding engine with config: %s", + self._config_proto) + @def_function.function + def load_config(): + tpu.initialize_system_for_tpu_embedding(self._config_proto) + + load_config() + logging.info("Done initializing TPU Embedding engine.") + + # Create and load variables and slot variables into the TPU. + # Note that this is a dict of dicts. Keys to the first dict are table names. + # We would prefer to use TableConfigs, but then these variables won't be + # properly tracked by the tracking API. + self._variables = self._create_variables_and_slots() + if self._using_tpu: + self._load_variables() + + @property + def embedding_tables(self): + """Returns a dict of embedding tables, keyed by `TableConfig`. + + This property only works when the `TPUEmbedding` object is created under a + non-TPU strategy. This is intended to be used to for CPU based lookup when + creating a serving checkpoint. + + Returns: + A dict of embedding tables, keyed by `TableConfig`. + + Raises: + RuntimeError: If object was created under a `TPUStrategy`. + """ + # We don't support returning tables on TPU due to their sharded nature and + # the fact that when using a TPUStrategy: + # 1. Variables are stale and are only updated when a checkpoint is made. + # 2. Updating the variables won't affect the actual tables on the TPU. + if self._using_tpu: + raise RuntimeError("Unable to retrieve embedding tables when using a TPU " + "strategy. If you need access, save your model, " + "create this object under a CPU strategy and restore.") + + # Only return the tables and not the slot variables. On CPU this are honest + # tf.Variables. + return {table: self._variables[table.name]["parameters"] + for table in self._table_config} + + def _create_config_proto(self): + """Creates the TPUEmbeddingConfiguration proto. + + This proto is used to initialize the TPU embedding engine. + + Returns: + A TPUEmbeddingConfiguration proto. + """ + + config_proto = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration() + + # There are several things that need to be computed here: + # 1. Each table has a num_features, which corresponds to the number of + # output rows per example for this table. Sequence features count for + # their maximum sequence length. + # 2. Learning rate index: the index of the dynamic learning rate for this + # table (if it exists) in the list we created at initialization. + # We don't simply create one learning rate index per table as this has + # extremely bad performance characteristics. The more separate + # optimization configurations we have, the worse the performance will be. + num_features = {table: 0 for table in self._table_config} + for feature in nest.flatten(self._feature_config): + num_features[feature.table] += (1 if feature.max_sequence_length == 0 + else feature.max_sequence_length) + + # Map each callable dynamic learning rate to its in index in the list. + learning_rate_index = {r: i for i, r in enumerate( + self._dynamic_learning_rates)} + + for table in self._table_config: + table_descriptor = config_proto.table_descriptor.add() + table_descriptor.name = table.name + + # For small tables, we pad to the number of hosts so that at least one + # id will be assigned to each host. + table_descriptor.vocabulary_size = max(table.vocabulary_size, + self._strategy.extended.num_hosts) + table_descriptor.dimension = table.dim + + table_descriptor.num_features = num_features[table] + + parameters = table_descriptor.optimization_parameters + + # We handle the learning rate separately here and don't allow the + # optimization class to handle this, as it doesn't know about dynamic + # rates. + if callable(table.optimizer.learning_rate): + parameters.learning_rate.dynamic.tag = ( + learning_rate_index[table.optimizer.learning_rate]) + else: + parameters.learning_rate.constant = table.optimizer.learning_rate + + # Use optimizer to handle the rest of the parameters. + table.optimizer._set_optimization_parameters(parameters) # pylint: disable=protected-access + + # Always set mode to training, we override the mode during enqueue. + config_proto.mode = ( + tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.TRAINING) + + config_proto.batch_size_per_tensor_core = self._batch_size + config_proto.num_hosts = self._strategy.extended.num_hosts + config_proto.num_tensor_cores = self._strategy.num_replicas_in_sync + + # TODO(bfontain): Allow users to pick MOD for the host sharding. + config_proto.sharding_strategy = ( + tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration.DIV_DEFAULT) + config_proto.pipeline_execution_with_tensor_core = ( + self._pipeline_execution_with_tensor_core) + + return config_proto + + def _compute_per_table_gradients(self, gradients): + """Computes a dict of lists of gradients, keyed by table name. + + Args: + gradients: A nested structure of Tensors (and Nones) with the same + structure as the feature config. + + Returns: + A dict of lists of tensors, keyed by the table names, containing the + gradients in the correct order with None gradients repalaced by zeros. + """ + + nest.assert_same_structure(self._feature_config, gradients) + + per_table_gradients = {table: [] for table in self._table_config} + for (path, gradient), feature in zip( + nest.flatten_with_joined_string_paths(gradients), + nest.flatten(self._feature_config)): + if gradient is not None and not isinstance(gradient, ops.Tensor): + raise ValueError( + "Found {} at path {} in gradients. Expected Tensor.".format( + type(gradient), path)) + + # Expected tensor shape differs for sequence and non-sequence features. + if feature.max_sequence_length > 0: + shape = [self._batch_size, feature.max_sequence_length, + feature.table.dim] + else: + shape = [self._batch_size, feature.table.dim] + + if gradient is not None: + if gradient.shape != shape: + raise ValueError("Found gradient of shape {} at path {}. Expected " + "shape {}.".format(gradient.shape, path, shape)) + + # We expand dims on non-sequence features so that all features are + # of rank 3 and we can concat on axis=1. + if len(shape) == 2: + gradient = array_ops.expand_dims(gradient, axis=1) + else: + # No gradient for this feature, since we must give a gradient for all + # features, pass in a zero tensor here. Note that this is not correct + # for all optimizers. + logging.warn("No gradient passed for feature %s, sending zero " + "gradient. This may not be correct behavior for certain " + "optimizers like Adam.", path) + # Create a shape to mimic the expand_dims above for non-sequence + # features. + if len(shape) == 2: + shape = [shape[0], 1, shape[1]] + gradient = array_ops.zeros(shape, dtype=dtypes.float32) + per_table_gradients[feature.table].append(gradient) + + return per_table_gradients + + def apply_gradients(self, gradients, name=None): + """Applies the gradient update to the embedding tables. + + If a gradient of `None` is passed in any position of the nested structure, + then an gradient update with a zero gradient is applied for that feature. + For optimizers like SGD or Adagrad, this is the same as applying no update + at all. For lazy Adam and other sparsely applied optimizers with decay, + ensure you understand the effect of applying a zero gradient. + + ```python + strategy = tf.distribute.experimental.TPUStrategy(...) + with strategy.scope(): + embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) + + distributed_dataset = strategy.experimental_distribute_dataset(...) + dataset_iterator = iter(distributed_dataset) + + @tf.function + def training_step(): + def tpu_step(tpu_features): + with tf.GradientTape() as tape: + activations = embedding.dequeue() + tape.watch(activations) + + loss = ... # some computation involving activations + + embedding_gradients = tape.gradient(loss, activations) + embedding.apply_gradients(embedding_gradients) + + embedding_features, tpu_features = next(dataset_iterator) + embedding.enqueue(embedding_features, training=True) + strategy.run(tpu_step, args=(embedding_features, )) + + training_step() + ``` + + Args: + gradients: A nested structure of gradients, with structure matching the + `feature_config` passed to this object. + name: A name for the underlying op. + + Raises: + RuntimeError: If called when object wasn't created under a `TPUStrategy`. + ValueError: If a non-`tf.Tensor` non-`None` gradient is passed in, or a + `tf.Tensor` of the incorrect shape is passed in. Also if + the size of any sequence in `gradients` does not match corresponding + sequence in `feature_config`. + TypeError: If the type of any sequence in `gradients` does not match + corresponding sequence in `feature_config`. + """ + if not self._using_tpu: + raise RuntimeError("apply_gradients is not valid when TPUEmbedding " + "object is not created under a TPUStrategy.") + + # send_tpu_embedding_gradients requires per table gradient, if we only have + # one feature per table this isn't an issue. When multiple features share + # the same table, the order of the features in per table tensor returned by + # recv_tpu_embedding_activations matches the order in which they were passed + # to enqueue. + # In all three places, we use the fixed order given by nest.flatten to have + # a consistent feature order. + + # First construct a dict of tensors one for each table. + per_table_gradients = self._compute_per_table_gradients(gradients) + + # Now that we have a list of gradients we can compute a list of gradients + # in the fixed order of self._table_config which interleave the gradients of + # the individual features. We concat on axis 1 and then reshape into a 2d + # tensor. The send gradients op expects a tensor of shape + # [num_features*batch_size, dim] for each table. + interleaved_gradients = [] + for table in self._table_config: + interleaved_gradients.append(array_ops.reshape( + array_ops.concat(per_table_gradients[table], axis=1), + [-1, table.dim])) + op = tpu_ops.send_tpu_embedding_gradients( + inputs=interleaved_gradients, + learning_rates=[math_ops.cast(fn(), dtype=dtypes.float32) + for fn in self._dynamic_learning_rates], + config=self._config_proto.SerializeToString()) + + # Apply the name tag to the op. + if name is not None: + _add_key_attr(op, name) + + def dequeue(self, name=None): + """Get the embedding results. + + Returns a nested structure of `tf.Tensor` objects, matching the structure of + the `feature_config` argument to the `TPUEmbedding` class. The output shape + of the tensors is `(batch_size, dim)`, where `batch_size` is the per core + batch size, `dim` is the dimension of the corresponding `TableConfig`. If + the feature's corresponding `FeatureConfig` has `max_sequence_length` + greater than 0, the output will be a sequence of shape + `(batch_size, max_sequence_length, dim)` instead. + + ```python + strategy = tf.distribute.experimental.TPUStrategy(...) + with strategy.scope(): + embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) + + distributed_dataset = strategy.experimental_distribute_dataset(...) + dataset_iterator = iter(distributed_dataset) + + @tf.function + def training_step(): + def tpu_step(tpu_features): + with tf.GradientTape() as tape: + activations = embedding.dequeue() + tape.watch(activations) + + loss = ... # some computation involving activations + + embedding_gradients = tape.gradient(loss, activations) + embedding.apply_gradients(embedding_gradients) + + embedding_features, tpu_features = next(dataset_iterator) + embedding.enqueue(embedding_features, training=True) + strategy.run(tpu_step, args=(embedding_features, )) + + training_step() + ``` + + Args: + name: A name for the underlying op. + + Returns: + A nested structure of tensors, with the same structure as `feature_config` + passed to this instance of the `TPUEmbedding` object. + + Raises: + RuntimeError: If called when object wasn't created under a `TPUStrategy`. + """ + if not self._using_tpu: + raise RuntimeError("dequeue is not valid when TPUEmbedding object is not " + "created under a TPUStrategy.") + + # The activations returned by this op are per table. So we must separate + # them out into per feature activations. The activations are interleaved: + # for each table, we expect a [num_features*batch_size, dim] tensor. + # E.g. we expect the slice [:num_features, :] to contain the lookups for the + # first example of all features using this table. + activations = tpu_ops.recv_tpu_embedding_activations( + num_outputs=len(self._table_config), + config=self._config_proto.SerializeToString()) + + # Apply the name tag to the op. + if name is not None: + _add_key_attr(activations[0].op, name) + + # Compute the number of features for this table. + num_features = {table: 0 for table in self._table_config} + for feature in nest.flatten(self._feature_config): + num_features[feature.table] += (1 if feature.max_sequence_length == 0 + else feature.max_sequence_length) + + # Activations are reshaped so that they are indexed by batch size and then + # by the 'feature' index within the batch. The final dimension should equal + # the dimension of the table. + table_to_activation = { + table: array_ops.reshape(activation, + [self._batch_size, num_features[table], -1]) + for table, activation in zip(self._table_config, activations)} + + # We process the features in the same order we enqueued them. + # For each feature we take the next slice of the activations, so need to + # track the activations and the current position we are in. + table_to_position = {table: 0 for table in self._table_config} + + per_feature_activations = [] + for feature in nest.flatten(self._feature_config): + activation = table_to_activation[feature.table] + feature_index = table_to_position[feature.table] + # We treat non-sequence and sequence features differently here as sequence + # features have rank 3 while non-sequence features have rank 2. + if feature.max_sequence_length == 0: + per_feature_activations.append( + activation[:, feature_index, :]) + table_to_position[feature.table] += 1 + else: + per_feature_activations.append( + activation[:, feature_index:( + feature_index+feature.max_sequence_length), :]) + table_to_position[feature.table] += feature.max_sequence_length + + # Pack the list back into the same nested structure as the features. + return nest.pack_sequence_as(self._feature_config, per_feature_activations) + + def _create_variables_and_slots(self): + """Create variables for TPU embeddings. + + Note under TPUStrategy this will ensure that all creations happen within a + variable creation scope of the sharded variable creator. + + Returns: + A dict of dicts. The outer dict is keyed by the table names and the inner + dicts are keyed by 'parameters' and the slot variable names. + """ + + def create_variables(table): + """Create all variables.""" + shape = (table.vocabulary_size, table.dim) + + # We use functools.partial here for the initial_value so that we have a + # variable creation that is compatible with both the sharded variable + # creator and the normal variable creator. The sharded variable creator + # will extract the shape of the tensor from the functool.partial object to + # decide on the sharding. + parameters = tf_variables.Variable( + name=table.name, + initial_value=functools.partial( + table.initializer, shape=shape, dtype=dtypes.float32), + trainable=not self._using_tpu) + slot_vars = table.optimizer._create_slots(parameters) # pylint: disable=protected-access + slot_vars["parameters"] = parameters + return slot_vars + + # Store tables based on name rather than TableConfig as we can't track + # through dicts with non-string keys, i.e. we won't be able to save. + variables = {} + for table in self._table_config: + if not self._using_tpu: + variables[table.name] = create_variables(table) + else: + with variable_scope.variable_creator_scope( + make_sharded_variable_creator(self._hosts)): + variables[table.name] = create_variables(table) + + return variables + + @def_function.function + def _load_variables(self): + """Load embedding tables to onto TPU for each table and host.""" + + def select_fn(host_id): + return lambda x: x.variables[host_id] + + num_hosts = self._strategy.extended.num_hosts + config = self._config_proto.SerializeToString() + for host_id, host in enumerate(self._hosts): + variables = nest.map_structure(select_fn(host_id), self._variables) + with ops.device(host): + for table in self._table_config: + table.optimizer._load()( # pylint: disable=protected-access + table_name=table.name, + num_shards=num_hosts, + shard_id=host_id, + config=config, + **variables[table.name]) + # Ensure that only the first table/first host gets a config so that we + # don't bloat graph by attaching this large string to each op. + # We have num tables * num hosts of these so for models with a large + # number of tables training on a large slice, this can be an issue. + config = None + + @def_function.function + def _retrieve_variables(self): + """Retrieve embedding tables from TPU to host memory.""" + num_hosts = self._strategy.extended.num_hosts + config = self._config_proto.SerializeToString() + for host_id, host in enumerate(self._hosts): + with ops.device(host): + for table in self._table_config: + retrieved = table.optimizer._retrieve()( # pylint: disable=protected-access + table_name=table.name, + num_shards=num_hosts, + shard_id=host_id, + config=config) + # When there are no slot variables (e.g with SGD) this returns a + # single tensor rather than a tuple. In this case we put the tensor in + # a list to make the following code easier to write. + if not isinstance(retrieved, tuple): + retrieved = (retrieved,) + + for i, slot in enumerate(["parameters"] + + table.optimizer._slot_names()): # pylint: disable=protected-access + # We must assign the CPU variables the values of tensors that were + # returned from the TPU. + self._variables[table.name][slot].variables[host_id].assign( + retrieved[i]) + # Ensure that only the first table/first host gets a config so that we + # don't bloat graph by attaching this large string to each op. + # We have num tables * num hosts of these so for models with a large + # number of tables training on a large slice, this can be an issue. + config = None + + def _gather_saveables_for_checkpoint(self): + """Overrides default Trackable implementation to add load/retrieve hook.""" + # This saveable should be here in both TPU and CPU checkpoints, so when on + # CPU, we add the hook with no functions. + # TODO(bfontain): Update restore logic in saver so that these hooks are + # always executed. Once that is done, we can output an empty list when on + # CPU. + def factory(name=_HOOK_KEY): + return TPUEmbeddingSaveable( + name, + self._load_variables if self._using_tpu else None, + self._retrieve_variables if self._using_tpu else None) + return {_HOOK_KEY: factory} + + # Some helper functions for the below enqueue function. + def _add_data_for_tensor(self, tensor, weight, indices, values, weights, + int_zeros, float_zeros, path): + if weight is not None: + raise ValueError( + "Weight specified for dense input {}, which is not allowed. " + "Weight will always be 1 in this case.".format(path)) + # For tensors, there are no indices and no weights. + indices.append(int_zeros) + values.append(math_ops.cast(tensor, dtypes.int32)) + weights.append(float_zeros) + + def _add_data_for_sparse_tensor(self, tensor, weight, indices, values, + weights, int_zeros, float_zeros, path): + indices.append(math_ops.cast(tensor.indices, dtypes.int32)) + values.append(math_ops.cast(tensor.values, dtypes.int32)) + # If we have weights they must be a SparseTensor. + if weight is not None: + if not isinstance(weight, sparse_tensor.SparseTensor): + raise ValueError("Weight for {} is type {} which does not match " + "type input which is SparseTensor.".format( + path, type(weight))) + weights.append(math_ops.cast(weight.values, dtypes.float32)) + else: + weights.append(float_zeros) + + def _add_data_for_ragged_tensor(self, tensor, weight, indices, values, + weights, int_zeros, float_zeros, path): + indices.append(math_ops.cast(tensor.row_splits, dtypes.int32)) + values.append(math_ops.cast(tensor.values, dtypes.int32)) + # If we have weights they must be a RaggedTensor. + if weight is not None: + if not isinstance(weight, ragged_tensor.RaggedTensor): + raise ValueError("Weight for {} is type {} which does not match " + "type input which is RaggedTensor.".format( + path, type(weight))) + weights.append(math_ops.cast(weight.values, dtypes.float32)) + else: + weights.append(float_zeros) + + def _generate_enqueue_op(self, flat_inputs, flat_weights, flat_features, + device_ordinal, mode_override): + """Outputs a the enqueue op given the inputs and weights. + + Args: + flat_inputs: A list of input tensors. + flat_weights: A list of input weights (or None) of the same length as + flat_inputs. + flat_features: A list of FeatureConfigs of the same length as flat_inputs. + device_ordinal: The device to create the enqueue op for. + mode_override: A tensor containing the string "train" or "inference". + + Returns: + The enqueue op. + """ + + # First we need to understand which op to use. This depends on if sparse + # or ragged tensors are in the flat_inputs. + sparse = False + ragged = False + for inp in flat_inputs: + if isinstance(inp, sparse_tensor.SparseTensor): + sparse = True + elif isinstance(inp, ragged_tensor.RaggedTensor): + ragged = True + if sparse and ragged: + raise ValueError( + "Found both SparseTensors and RaggedTensors in the input to the " + "enqueue operation. Please ensure that your data does not include " + "both SparseTensors and RaggedTensors. It is ok to have Tensors in " + "combination with one of the previous types.") + + # Combiners are per table, list in the same order as the table order. + combiners = [table.combiner for table in self._table_config] + + # Reverse mapping of self._table_config, so that we can lookup the table + # index. + table_to_id = {table: i for i, table in enumerate(self._table_config)} + + # These parallel arrays will be the inputs to the enqueue op. + indices = [] # sample_indices for sparse, sample_splits for ragged. + values = [] + weights = [] + table_ids = [] + max_sequence_lengths = [] + + # We have to supply a empty/zero tensor in a list position where we don't + # have data (e.g. indices for standard Tensor input, weight when no weight + # is specified). We create one op here per call, so that we reduce the + # graph size. + int_zeros = array_ops.zeros((0,), dtype=dtypes.int32) + float_zeros = array_ops.zeros((0,), dtype=dtypes.float32) + + # In the following loop we insert casts so that everything is either int32 + # or float32. This is because op inputs which are lists of tensors must be + # of the same type within the list. Moreover the CPU implementions of these + # ops cast to these types anyway, so we don't lose any data by casting + # early. + for inp, weight, (path, feature) in zip( + flat_inputs, flat_weights, flat_features): + table_ids.append(table_to_id[feature.table]) + max_sequence_lengths.append(feature.max_sequence_length) + if isinstance(inp, ops.Tensor): + self._add_data_for_tensor(inp, weight, indices, values, weights, + int_zeros, float_zeros, path) + elif isinstance(inp, sparse_tensor.SparseTensor): + self._add_data_for_sparse_tensor(inp, weight, indices, values, weights, + int_zeros, float_zeros, path) + elif isinstance(inp, ragged_tensor.RaggedTensor): + self._add_data_for_ragged_tensor(inp, weight, indices, values, weights, + int_zeros, float_zeros, path) + else: + raise ValueError("Input {} is of unknown type {}. Please only pass " + "Tensor, SparseTensor or RaggedTensor as input to " + "enqueue.".format(path, type(inp))) + + if ragged: + return tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch( + sample_splits=indices, + embedding_indices=values, + aggregation_weights=weights, + mode_override=mode_override, + device_ordinal=device_ordinal, + combiners=combiners, + table_ids=table_ids, + max_sequence_lengths=max_sequence_lengths) + return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch( + sample_indices=indices, + embedding_indices=values, + aggregation_weights=weights, + mode_override=mode_override, + device_ordinal=device_ordinal, + combiners=combiners, + table_ids=table_ids, + max_sequence_lengths=max_sequence_lengths) + + def _raise_error_for_incorrect_control_flow_context(self): + """Raises an error if we are not in the TPUReplicateContext.""" + # Do not allow any XLA control flow (i.e. control flow in between a + # TPUStrategy's run call and the call to this function), as we can't + # extract the enqueue from the head when in XLA control flow. + graph = ops.get_default_graph() + in_tpu_ctx = False + while graph is not None: + ctx = graph._get_control_flow_context() # pylint: disable=protected-access + while ctx is not None: + if isinstance(ctx, tpu.TPUReplicateContext): + in_tpu_ctx = True + break + ctx = ctx.outer_context + if in_tpu_ctx: + break + graph = getattr(graph, "outer_graph", None) + if graph != ops.get_default_graph() and in_tpu_ctx: + raise RuntimeError( + "Current graph {} does not match graph which contains " + "TPUReplicateContext {}. This is most likely due to the fact that " + "enqueueing embedding data is called inside control flow or a " + "nested function inside `strategy.run`. This is not supported " + "because outside compilation fails to extract the enqueue ops as " + "head of computation.".format(ops.get_default_graph(), graph)) + return in_tpu_ctx + + def _raise_error_for_non_direct_inputs(self, features): + """Checks all tensors in features to see if they are a direct input.""" + + # expand_composites here is important: as composite tensors pass through + # tpu.replicate, they get 'flattened' into their component tensors and then + # repacked before being passed to the tpu function. In means that it is the + # component tensors which are produced by an op with the + # "_tpu_input_identity" attribute. + for path, input_tensor in nest.flatten_with_joined_string_paths( + features, expand_composites=True): + if input_tensor.op.type == "Placeholder": + continue + try: + is_input = input_tensor.op.get_attr("_tpu_input_identity") + except ValueError: + is_input = False + if not is_input: + raise ValueError( + "Received input tensor {} which is the output of op {} (type {}) " + "which does not have the `_tpu_input_identity` attr. Please " + "ensure that the inputs to this layer are taken directly from " + "the arguments of the function called by " + "strategy.run. Two possible causes are: dynamic batch size " + "support or you are using a keras layer and are not passing " + "tensors which match the dtype of the `tf.keras.Input`s." + "If you are triggering dynamic batch size support, you can " + "disable it by passing tf.distribute.RunOptions(" + "experimental_enable_dynamic_batch_size=False) to the options " + "argument of strategy.run().".format(path, + input_tensor.op.name, + input_tensor.op.type)) + + def enqueue(self, features, weights=None, training=True, name=None): + """Enqueues id tensors for embedding lookup. + + This function enqueues a structure of features to be looked up in the + embedding tables. We expect that the batch size of each of the tensors in + features matches the per core batch size. This will automatically happen if + your input dataset is batched to the global batch size and you use + `tf.distribute.experimental.TPUStrategy`'s `experimental_distribute_dataset` + or if you use `experimental_distribute_datasets_from_function` and batch + to the per core batch size computed by the context passed to your input + function. + + ```python + strategy = tf.distribute.experimental.TPUStrategy(...) + with strategy.scope(): + embedding = tf.tpu.experimental.embedding.TPUEmbedding(...) + + distributed_dataset = strategy.experimental_distribute_dataset(...) + dataset_iterator = iter(distributed_dataset) + + @tf.function + def training_step(): + def tpu_step(tpu_features): + with tf.GradientTape() as tape: + activations = embedding.dequeue() + tape.watch(activations) + + loss = ... # some computation involving activations + + embedding_gradients = tape.gradient(loss, activations) + embedding.apply_gradients(embedding_gradients) + + embedding_features, tpu_features = next(dataset_iterator) + embedding.enqueue(embedding_features, training=True) + strategy.run(tpu_step, args=(embedding_features,)) + + training_step() + ``` + + NOTE: You should specify `training=True` when using + `embedding.apply_gradients` as above and `training=False` when not using + `embedding.apply_gradients` (e.g. for frozen embeddings or when doing + evaluation). + + Args: + features: A nested structure of `tf.Tensor`s, `tf.SparseTensor`s or + `tf.RaggedTensor`s, with the same structure as `feature_config`. Inputs + will be downcast to `tf.int32`. Only one type out of `tf.SparseTensor` + or `tf.RaggedTensor` is supported per call. + weights: If not `None`, a nested structure of `tf.Tensor`s, + `tf.SparseTensor`s or `tf.RaggedTensor`s, matching the above, except + that the tensors should be of float type (and they will be downcast to + `tf.float32`). For `tf.SparseTensor`s we assume the `indices` are the + same for the parallel entries from `features` and similarly for + `tf.RaggedTensor`s we assume the row_splits are the same. + training: Defaults to `True`. If `False`, enqueue the batch as inference + batch (forward pass only). Do not call `apply_gradients` when this is + `False` as this may lead to a deadlock. + name: A name for the underlying op. + + Raises: + ValueError: When called inside a strategy.run call and input is not + directly taken from the args of the `strategy.run` call. Also if + the size of any sequence in `features` does not match corresponding + sequence in `feature_config`. Similarly for `weights`, if not `None`. + RuntimeError: When called inside a strategy.run call and inside XLA + control flow. + TypeError: If the type of any sequence in `features` does not match + corresponding sequence in `feature_config`. Similarly for `weights`, if + not `None`. + """ + if not self._using_tpu: + raise RuntimeError("enqueue is not valid when TPUEmbedding object is not " + "created under a TPUStrategy.") + + nest.assert_same_structure(self._feature_config, features) + + # TODO(bfontain): Add a check that the input batch_size matches the per core + # batch size that this instance of the API was initialized with. + + flat_inputs = nest.flatten(features) + flat_weights = [None] * len(flat_inputs) + if weights is not None: + nest.assert_same_structure(self._feature_config, weights) + flat_weights = nest.flatten(weights) + flat_features = nest.flatten_with_joined_string_paths(self._feature_config) + + in_tpu_context = self._raise_error_for_incorrect_control_flow_context() + # If we are in a tpu_context, automatically apply outside compilation. + if in_tpu_context: + self._raise_error_for_non_direct_inputs(features) + + def generate_enqueue_ops(): + """Generate enqueue ops for outside compilation.""" + # Note that we put array_ops.where_v2 rather than a python if so that + # the op is explicitly create and the constant ops are both in the graph + # even though we don't expect training to be a tensor (and thus generate + # control flow automatically). This need to make it easier to re-write + # the graph later if we need to fix which mode needs to be used. + mode_override = array_ops.where_v2(training, + constant_op.constant("train"), + constant_op.constant("inference")) + + # Device ordinal is -1 here, a later rewrite will fix this once the op + # is expanded by outside compilation. + enqueue_op = self._generate_enqueue_op( + flat_inputs, flat_weights, flat_features, device_ordinal=-1, + mode_override=mode_override) + + # Apply the name tag to the op. + if name is not None: + _add_key_attr(enqueue_op, name) + + # Ensure that this op has outbound control flow, otherwise it won't be + # executed. + ops.get_default_graph().control_outputs.append(enqueue_op) + + tpu.outside_compilation(generate_enqueue_ops) + + else: + mode_override = "train" if training else "inference" + # We generate enqueue ops per device, so we need to gather the all + # features for a single device in to a dict. + # We rely here on the fact that the devices in the PerReplica value occur + # in the same (standard) order as self._strategy.extended.worker_devices. + enqueue_ops = [] + for replica_id in range(self._strategy.num_replicas_in_sync): + replica_inputs = tf_values.select_replica(replica_id, flat_inputs) + replica_weights = tf_values.select_replica(replica_id, flat_weights) + tpu_device = self._strategy.extended.worker_devices[replica_id] + # TPU devices string are like /job:worker/replica:0/task:0/device:TPU:0 + # the device ordinal is the last number + device_ordinal = int(tpu_device.rsplit(":", 1)[1]) + with ops.device(device_util.get_host_for_device(tpu_device)): + enqueue_op = self._generate_enqueue_op( + replica_inputs, replica_weights, flat_features, + device_ordinal=device_ordinal, mode_override=mode_override) + + # Apply the name tag to the op. + if name is not None: + _add_key_attr(enqueue_op, name) + enqueue_ops.append(enqueue_op) + ops.get_default_graph().control_outputs.extend(enqueue_ops) + + +class TPUEmbeddingSaveable(saveable_hook.SaveableHook): + """Save/Restore hook to Retrieve/Load TPUEmbedding variables.""" + + def __init__(self, name, load, retrieve): + self._load = load + self._retrieve = retrieve + super(TPUEmbeddingSaveable, self).__init__(name=name) + + def before_save(self): + if self._retrieve is not None: + self._retrieve() + + def after_restore(self): + if self._load is not None: + self._load() + + +def _ragged_embedding_lookup_with_reduce(table, ragged, weights, combiner): + """Compute a ragged lookup followed by a reduce on axis 1. + + Args: + table: The embedding table. + ragged: A RaggedTensor of ids to look up. + weights: A RaggedTensor of weights (or None). + combiner: One of "mean", "sum", "sqrtn". + + Returns: + A Tensor. + """ + if weights is None: + weights = array_ops.ones_like(ragged) + weights = array_ops.expand_dims(weights, axis=2) + ragged_result = embedding_ops.embedding_lookup_ragged(table, ragged) + ragged_result = math_ops.reduce_sum(ragged_result * weights, axis=1) + if combiner == "mean": + ragged_result = ragged_result / math_ops.reduce_sum(weights, axis=1) + elif combiner == "sqrtn": + ragged_result = ragged_result, math_ops.sqrt(math_ops.reduce_sum( + weights*weights, axis=1)) + return ragged_result + + +def cpu_embedding_lookup(inputs, weights, tables, feature_config): + """Uses CPU embedding lookup for embedding ids in features. + + Args: + inputs: a nested structure of Tensors, SparseTensors or RaggedTensors. + weights: a nested structure of Tensors, SparseTensors or RaggedTensors or + None for no weights. + tables: a dict of mapping TableConfig objects to Variables. + feature_config: a nested structure of FeatureConfig objects with the same + structure as inputs. + + Returns: + A nested structure of Tensors with the same structure as inputs. + """ + + nest.assert_same_structure(inputs, feature_config) + + flat_inputs = nest.flatten(inputs) + flat_weights = [None] * len(flat_inputs) + if weights is not None: + nest.assert_same_structure(inputs, weights) + flat_weights = nest.flatten(weights) + flat_features = nest.flatten_with_joined_string_paths(feature_config) + + outputs = [] + for inp, weight, (path, feature) in zip( + flat_inputs, flat_weights, flat_features): + table = tables[feature.table] + if feature.max_sequence_length > 0: + raise ValueError("Sequence features unsupported at this time.") + + if weight is not None: + if isinstance(inp, ops.Tensor): + raise ValueError( + "Weight specified for {}, but input is dense.".format(path)) + elif type(weight) is not type(inp): + raise ValueError( + "Weight for {} is of type {} but it does not match type of the " + "input which is {}.".format(path, type(weight), type(inp))) + + if isinstance(inp, ops.Tensor): + outputs.append(embedding_ops.embedding_lookup_v2(table, inp)) + + elif isinstance(inp, sparse_tensor.SparseTensor): + outputs.append(embedding_ops.safe_embedding_lookup_sparse_v2( + table, inp, sparse_weights=weight, combiner=feature.table.combiner)) + + elif isinstance(inp, ragged_tensor.RaggedTensor): + outputs.append(_ragged_embedding_lookup_with_reduce( + table, inp, weight, feature.table.combiner)) + + else: + raise ValueError("Input {} is type {}. Tensor, SparseTensor or " + "RaggedTensor expected.".format(path, type(inp))) + return nest.pack_sequence_as(feature_config, outputs) + + +def get_list_of_hosts(strategy): + """Returns a sorted list of CPU devices for the remote jobs. + + Args: + strategy: A TPUStrategy object. + + Returns: + A sort list of device strings. + """ + list_of_hosts = [] + # Assume this is sorted by task + for tpu_device in strategy.extended.worker_devices: + host = device_util.get_host_for_device(tpu_device) + if host not in list_of_hosts: + list_of_hosts.append(host) + assert len(list_of_hosts) == strategy.extended.num_hosts + return list_of_hosts + + +def extract_variable_info(kwargs): + """Extracts the variable creation attributes from the kwargs. + + Args: + kwargs: a dict of keyword arguments that were passed to a variable creator + scope. + + Returns: + A tuple of variable name, initialization function, shape, and dtype. + """ + if (isinstance(kwargs["initial_value"], functools.partial) and ( + "shape" in kwargs["initial_value"].keywords or + kwargs["initial_value"].args)): + # Sometimes shape is passed positionally, sometimes it's passed as a kwarg. + if "shape" in kwargs["initial_value"].keywords: + shape = kwargs["initial_value"].keywords["shape"] + else: + shape = kwargs["initial_value"].args[0] + return (kwargs["name"], shape, + kwargs["initial_value"].keywords.get("dtype", kwargs["dtype"]), + kwargs["initial_value"].func) + elif "shape" not in kwargs or kwargs["shape"] is None: + raise ValueError( + "Unable to extract initializer function and shape from {}. Please " + "either pass a function that expects a shape and dtype as the " + "initial value for your variable or functools.partial object with " + "the shape and dtype kwargs set. This is needed so that we can " + "initialize the shards of the ShardedVariable locally.".format( + kwargs["initial_value"])) + else: + return (kwargs["name"], kwargs["shape"], kwargs["dtype"], + kwargs["initial_value"]) + + +def make_sharded_variable_creator(hosts): + """Makes a sharded variable creator given a list of hosts. + + Args: + hosts: a list of tensorflow devices on which to shard the tensors. + + Returns: + A variable creator function. + """ + + def sharded_variable_creator(next_creator, *args, **kwargs): + """The sharded variable creator.""" + kwargs["skip_mirrored_creator"] = True + + num_hosts = len(hosts) + name, shape, dtype, initial_value = extract_variable_info(kwargs) + rows = shape[0] + cols = shape[1] + missing = rows % num_hosts + # we partition as if we were using MOD sharding. + partitions = ([rows // num_hosts + 1] * missing + [rows // num_hosts] * + (num_hosts - missing)) + variables = [] + newkwargs = kwargs + newkwargs["dtype"] = dtype + for i, p in enumerate(partitions): + with ops.device(hosts[i]): + newkwargs["shape"] = (p, cols) + newkwargs["name"] = "{}_{}".format(name, i) + newkwargs["initial_value"] = ( + lambda: initial_value(newkwargs["shape"], dtype=dtype)) + variables.append(next_creator(*args, **kwargs)) + return TPUShardedVariable(variables, name=name) + return sharded_variable_creator diff --git a/tensorflow/python/tpu/tpu_embedding_v2_utils.py b/tensorflow/python/tpu/tpu_embedding_v2_utils.py new file mode 100644 index 00000000000..bba0d10a62f --- /dev/null +++ b/tensorflow/python/tpu/tpu_embedding_v2_utils.py @@ -0,0 +1,624 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Companion classes for mid level API for TPU Embeddings in TF2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import abc +import functools +import math +import six + +from tensorflow.core.protobuf.tpu import optimization_parameters_pb2 +from tensorflow.python.ops import init_ops_v2 +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.tpu.ops import tpu_ops +from tensorflow.python.util.tf_export import tf_export + + +@six.add_metaclass(abc.ABCMeta) +class _Optimizer(object): + """Base class for all optimizers, with common parameters.""" + + def __init__(self, learning_rate, use_gradient_accumulation, clip_weight_min, + clip_weight_max, weight_decay_factor, + multiply_weight_decay_factor_by_learning_rate, + slot_variable_creation_fn=None): + self.learning_rate = learning_rate + self.use_gradient_accumulation = use_gradient_accumulation + self.clip_weight_min = clip_weight_min + self.clip_weight_max = clip_weight_max + self.weight_decay_factor = weight_decay_factor + self.multiply_weight_decay_factor_by_learning_rate = ( + multiply_weight_decay_factor_by_learning_rate) + + if (slot_variable_creation_fn is not None and + not callable(slot_variable_creation_fn)): + raise ValueError("slot_variable_creation_fn must be either None or a " + "callable.") + self.slot_variable_creation_fn = slot_variable_creation_fn + + @abc.abstractmethod + def _slot_names(self): + """Returns the name of all the slot variables. + + This does not include the 'parameters' variable and these names must match + the names of the slots variables as used in the corresponding + `tpu_ops.load_tpu_embedding_*` ops. + """ + raise NotImplementedError + + @abc.abstractmethod + def _slot_initializers(self): + """Returns initializers for slot variables. + + This returns a parallel list to self._slot_names(). + """ + raise NotImplementedError + + def _set_optimization_parameters(self, parameters): + """Sets the optimizer fields in the OptimizationParameters.""" + if self.use_gradient_accumulation: + parameters.gradient_accumulation_status = ( + optimization_parameters_pb2.GradientAccumulationStatus.ENABLED) + else: + parameters.gradient_accumulation_status = ( + optimization_parameters_pb2.GradientAccumulationStatus.DISABLED) + + if self.clip_weight_min is not None: + parameters.clipping_limits.lower.value = self.clip_weight_min + + if self.clip_weight_max is not None: + parameters.clipping_limits.upper.value = self.clip_weight_max + + if self.weight_decay_factor: + parameters.weight_decay_factor = self.weight_decay_factor + if self.multiply_weight_decay_factor_by_learning_rate: + parameters.multiply_weight_decay_factor_by_learning_rate = True + + @abc.abstractmethod + def _load(self): + """Returns the load function for the optimizer.""" + raise NotImplementedError + + @abc.abstractmethod + def _retrieve(self): + """Returns the retrieve function for the optimizer.""" + raise NotImplementedError + + def _create_slots(self, table): + """Creates slot variables for table. + + Uses shape of table to create parallel slot variables. + + Args: + table: A Variable or equivalent. + + Returns: + A dict of variables, keyed by self._slot_names(). + """ + if self.slot_variable_creation_fn is not None: + return self.slot_variable_creation_fn(table, self._slot_names()) + else: + slots = {} + for slot, initializer in zip(self._slot_names(), + self._slot_initializers()): + slots[slot] = tf_variables.Variable( + name=table.name + "/" + slot, + initial_value=functools.partial( + initializer, shape=table.shape, dtype=table.dtype), + trainable=False) + return slots + + +@tf_export("tpu.experimental.embedding.SGD") +class SGD(_Optimizer): + """Optimization parameters for stochastic gradient descent for TPU embeddings. + + Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer` + argument to set the global optimizer and its parameters: + + ``` + embedding = tf.tpu.experimental.embedding.TPUEmbedding( + ... + optimizer=tf.tpu.experimental.embedding.SGD(0.1)) + ``` + + This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the + optimizer parameter to set a table specific optimizer. This will override the + optimizer and parameters for global embedding optimizer defined above: + + ``` + table_one = tf.tpu.experimental.embedding.TableConfig( + vocabulary_size=..., + dim=..., + optimizer=tf.tpu.experimental.embedding.SGD(0.2)) + table_two = tf.tpu.experimental.embedding.TableConfig( + vocabulary_size=..., + dim=...) + + feature_config = ( + tf.tpu.experimental.embedding.FeatureConfig( + table=table_one), + tf.tpu.experimental.embedding.FeatureConfig( + table=table_two)) + + embedding = tf.tpu.experimental.embedding.TPUEmbedding( + feature_config=feature_config, + batch_size=... + optimizer=tf.tpu.experimental.embedding.SGD(0.1)) + ``` + + In the above example, the first feature will be looked up in a table that has + a learning rate of 0.2 while the second feature will be looked up in a table + that has a learning rate of 0.1. + + See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a + complete description of these parameters and their impacts on the optimizer + algorithm. + """ + + def __init__(self, + learning_rate=0.01, + clip_weight_min=None, + clip_weight_max=None, + weight_decay_factor=None, + multiply_weight_decay_factor_by_learning_rate=None): + """Optimization parameters for stochastic gradient descent. + + Args: + learning_rate: The learning rate. It should be a floating point value or a + callable taking no arguments for a dynamic learning rate. + clip_weight_min: the minimum value to clip by; None means -infinity. + clip_weight_max: the maximum value to clip by; None means +infinity. + weight_decay_factor: amount of weight decay to apply; None means that the + weights are not decayed. Weights are decayed by multiplying the weight + by this factor each step. + multiply_weight_decay_factor_by_learning_rate: if true, + `weight_decay_factor` is multiplied by the current learning rate. + """ + super(SGD, self).__init__( + learning_rate, False, clip_weight_min, clip_weight_max, + weight_decay_factor, multiply_weight_decay_factor_by_learning_rate) + + def _slot_names(self): + return [] + + def _slot_initializers(self): + return [] + + def _set_optimization_parameters(self, parameters): + super(SGD, self)._set_optimization_parameters(parameters) + parameters.stochastic_gradient_descent.SetInParent() + + def _load(self): + return tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters + + def _retrieve(self): + return tpu_ops.retrieve_tpu_embedding_stochastic_gradient_descent_parameters + + +@tf_export("tpu.experimental.embedding.Adagrad") +class Adagrad(_Optimizer): + """Optimization parameters for Adagrad with TPU embeddings. + + Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer` + argument to set the global optimizer and its parameters: + + ```python + embedding = tf.tpu.experimental.embedding.TPUEmbedding( + ... + optimizer=tf.tpu.experimental.embedding.Adagrad(0.1)) + ``` + + This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the + optimizer parameter to set a table specific optimizer. This will override the + optimizer and parameters for global embedding optimizer defined above: + + ```python + table_one = tf.tpu.experimental.embedding.TableConfig( + vocabulary_size=..., + dim=..., + optimizer=tf.tpu.experimental.embedding.Adagrad(0.2)) + table_two = tf.tpu.experimental.embedding.TableConfig( + vocabulary_size=..., + dim=...) + + feature_config = ( + tf.tpu.experimental.embedding.FeatureConfig( + table=table_one), + tf.tpu.experimental.embedding.FeatureConfig( + table=table_two)) + + embedding = tf.tpu.experimental.embedding.TPUEmbedding( + feature_config=feature_config, + batch_size=... + optimizer=tf.tpu.experimental.embedding.Adagrad(0.1)) + ``` + + In the above example, the first feature will be looked up in a table that has + a learning rate of 0.2 while the second feature will be looked up in a table + that has a learning rate of 0.1. + + See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a + complete description of these parameters and their impacts on the optimizer + algorithm. + """ + + def __init__(self, + learning_rate=0.001, + initial_accumulator_value=0.1, + use_gradient_accumulation=True, + clip_weight_min=None, + clip_weight_max=None, + weight_decay_factor=None, + multiply_weight_decay_factor_by_learning_rate=None, + slot_variable_creation_fn=None): + """Optimization parameters for Adagrad. + + Args: + learning_rate: The learning rate. It should be a floating point value or a + callable taking no arguments for a dynamic learning rate. + initial_accumulator_value: initial accumulator for Adagrad. + use_gradient_accumulation: setting this to `False` makes embedding + gradients calculation less accurate but faster. + clip_weight_min: the minimum value to clip by; None means -infinity. + clip_weight_max: the maximum value to clip by; None means +infinity. + weight_decay_factor: amount of weight decay to apply; None means that the + weights are not decayed. + multiply_weight_decay_factor_by_learning_rate: if true, + `weight_decay_factor` is multiplied by the current learning rate. + slot_variable_creation_fn: Defaults to `None`. If you wish do directly + control the creation of the slot variables, set this to a callable + taking two parameters, a variable and a list of slot names to create for + it. This function should return a dict with the slot names as keys and + the created variables as values. When set to None (the default), uses + the built-in variable creation. + """ + super(Adagrad, self).__init__( + learning_rate, use_gradient_accumulation, clip_weight_min, + clip_weight_max, weight_decay_factor, + multiply_weight_decay_factor_by_learning_rate, + slot_variable_creation_fn) + if initial_accumulator_value <= 0: + raise ValueError("Adagrad initial_accumulator_value must be positive") + self.initial_accumulator_value = initial_accumulator_value + + def _slot_names(self): + return ["accumulators"] + + def _slot_initializers(self): + return [init_ops_v2.Constant(self.initial_accumulator_value)] + + def _set_optimization_parameters(self, parameters): + super(Adagrad, self)._set_optimization_parameters(parameters) + parameters.adagrad.SetInParent() + + def _load(self): + return tpu_ops.load_tpu_embedding_adagrad_parameters + + def _retrieve(self): + return tpu_ops.retrieve_tpu_embedding_adagrad_parameters + + +@tf_export("tpu.experimental.embedding.Adam") +class Adam(_Optimizer): + """Optimization parameters for Adam with TPU embeddings. + + Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer` + argument to set the global optimizer and its parameters: + + NOTE: By default this optimizer is lazy, i.e. it will not apply the gradient + update of zero to rows that were not looked up. You can change this behavior + by setting `lazy_adam` to `False`. + + ```python + embedding = tf.tpu.experimental.embedding.TPUEmbedding( + ... + optimizer=tf.tpu.experimental.embedding.Adam(0.1)) + ``` + + This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the + optimizer parameter to set a table specific optimizer. This will override the + optimizer and parameters for global embedding optimizer defined above: + + ```python + table_one = tf.tpu.experimental.embedding.TableConfig( + vocabulary_size=..., + dim=..., + optimizer=tf.tpu.experimental.embedding.Adam(0.2)) + table_two = tf.tpu.experimental.embedding.TableConfig( + vocabulary_size=..., + dim=...) + + feature_config = ( + tf.tpu.experimental.embedding.FeatureConfig( + table=table_one), + tf.tpu.experimental.embedding.FeatureConfig( + table=table_two)) + + embedding = tf.tpu.experimental.embedding.TPUEmbedding( + feature_config=feature_config, + batch_size=... + optimizer=tf.tpu.experimental.embedding.Adam(0.1)) + ``` + + In the above example, the first feature will be looked up in a table that has + a learning rate of 0.2 while the second feature will be looked up in a table + that has a learning rate of 0.1. + + See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a + complete description of these parameters and their impacts on the optimizer + algorithm. + """ + + def __init__(self, + learning_rate=0.001, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-07, + lazy_adam=True, + sum_inside_sqrt=True, + use_gradient_accumulation=True, + clip_weight_min=None, + clip_weight_max=None, + weight_decay_factor=None, + multiply_weight_decay_factor_by_learning_rate=None, + slot_variable_creation_fn=None): + """Optimization parameters for Adam. + + See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a + complete description of these parameters and their impacts on the optimizer + algorithm. + + Args: + learning_rate: The learning rate. It should be a floating point value or a + callable taking no arguments for a dynamic learning rate. + beta_1: A float value. + The exponential decay rate for the 1st moment estimates. + beta_2: A float value. + The exponential decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. + lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. + sum_inside_sqrt: When this is true, the Adam update formula is changed + from `m / (sqrt(v) + epsilon)` to `m / sqrt(v + epsilon**2)`. This + option improves the performance of TPU training and is not expected to + harm model quality. + use_gradient_accumulation: Setting this to `False` makes embedding + gradients calculation less accurate but faster. + clip_weight_min: the minimum value to clip by; None means -infinity. + clip_weight_max: the maximum value to clip by; None means +infinity. + weight_decay_factor: amount of weight decay to apply; None means that the + weights are not decayed. + multiply_weight_decay_factor_by_learning_rate: if true, + `weight_decay_factor` is multiplied by the current learning rate. + slot_variable_creation_fn: a callable taking two parameters, a variable + and a list of slot names to create for it. This function should return + a dict with the slot names as keys and the created variables as values. + When set to None (the default), uses the built-in variable creation. + """ + super(Adam, self).__init__( + learning_rate, use_gradient_accumulation, clip_weight_min, + clip_weight_max, weight_decay_factor, + multiply_weight_decay_factor_by_learning_rate, + slot_variable_creation_fn) + if beta_1 < 0. or beta_1 >= 1.: + raise ValueError("beta1 must be in the range [0, 1), but received {}." + .format(beta_1)) + if beta_2 < 0. or beta_2 >= 1.: + raise ValueError("beta2 must be in the range [0, 1), but received {}." + .format(beta_2)) + if epsilon <= 0.: + raise ValueError("epsilon must be positive; got {}.".format(epsilon)) + if not use_gradient_accumulation and not lazy_adam: + raise ValueError( + "When disabling Lazy Adam, gradient accumulation must be used.") + + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.epsilon = epsilon + self.lazy_adam = lazy_adam + self.sum_inside_sqrt = sum_inside_sqrt + + def _slot_names(self): + return ["momenta", "velocities"] + + def _slot_initializers(self): + return [init_ops_v2.Constant(), init_ops_v2.Constant()] + + def _set_optimization_parameters(self, parameters): + super(Adam, self)._set_optimization_parameters(parameters) + parameters.adam.beta1 = self.beta_1 + parameters.adam.beta2 = self.beta_2 + parameters.adam.epsilon = self.epsilon + parameters.adam.use_non_lazy_adam = not self.lazy_adam + parameters.adam.use_sum_inside_sqrt = self.sum_inside_sqrt + + def _load(self): + return tpu_ops.load_tpu_embedding_adam_parameters + + def _retrieve(self): + return tpu_ops.retrieve_tpu_embedding_adam_parameters + + +@tf_export("tpu.experimental.embedding.TableConfig") +class TableConfig(object): + """Configuration data for one embedding table. + + This class holds the configuration data for a single embedding table. It is + used as the `table` parameter of a + `tf.tpu.experimental.embedding.FeatureConfig`. Multiple + `tf.tpu.experimental.embedding.FeatureConfig` objects can use the same + `tf.tpu.experimental.embedding.TableConfig` object. In this case a shared + table will be created for those feature lookups. + + ```python + table_config_one = tf.tpu.experimental.embedding.TableConfig( + vocabulary_size=..., + dim=...) + table_config_two = tf.tpu.experimental.embedding.TableConfig( + vocabulary_size=..., + dim=...) + feature_config = { + 'feature_one': tf.tpu.experimental.embedding.FeatureConfig( + table=table_config_one), + 'feature_two': tf.tpu.experimental.embedding.FeatureConfig( + table=table_config_one), + 'feature_three': tf.tpu.experimental.embedding.FeatureConfig( + table=table_config_two)} + embedding = tf.tpu.experimental.embedding.TPUEmbedding( + feature_config=feature_config, + batch_size=... + optimizer=tf.tpu.experimental.embedding.Adam(0.1)) + ``` + + The above configuration has 2 tables, and three features. The first two + features will be looked up in the first table and the third feature will be + looked up in the second table. + + """ + + def __init__(self, vocabulary_size, dim, initializer, optimizer=None, + combiner="mean", name=None): + """Embedding table configuration. + + Args: + vocabulary_size: Size of the table's vocabulary (number of rows). + dim: The embedding dimension (width) of the table. + initializer: A callable initializer taking one parameter, the shape of the + variable that will be initialized. Will be called once per task, to + initialize that task's shard of the embedding table. If not specified, + defaults to `truncated_normal_initializer` with mean `0.0` and standard + deviation `1/sqrt(dim)`. + optimizer: An optional instance of an optimizer parameters class, instance + of one of `tf.tpu.experimental.embedding.SGD`, + `tf.tpu.experimental.embedding.Adagrad` or + `tf.tpu.experimental.embedding.Adam`. It set will override the global + optimizer passed to `tf.tpu.experimental.embedding.TPUEmbedding`. + combiner: A string specifying how to reduce if there are multiple entries + in a single row. Currently 'mean', 'sqrtn', 'sum' are + supported, with 'mean' the default. 'sqrtn' often achieves good + accuracy, in particular with bag-of-words columns. For more information, + see `tf.nn.embedding_lookup_sparse`. + name: An optional string used to name the table. Useful for debugging. + + Returns: + `TableConfig`. + + Raises: + ValueError: if `vocabulary_size` is not a positive integer. + ValueError: if `dim` is not a positive integer. + ValueError: if `initializer` is specified and is not callable. + ValueError: if `combiner` is not supported. + """ + if not isinstance(vocabulary_size, int) or vocabulary_size < 1: + raise ValueError("Invalid vocabulary_size {}.".format(vocabulary_size)) + + if not isinstance(dim, int) or dim < 1: + raise ValueError("Invalid dim {}.".format(dim)) + + if (initializer is not None) and (not callable(initializer)): + raise ValueError("initializer must be callable if specified.") + if initializer is None: + initializer = init_ops_v2.TruncatedNormal(mean=0.0, + stddev=1/math.sqrt(dim)) + + if combiner not in ("mean", "sum", "sqrtn"): + raise ValueError("Invalid combiner {}".format(combiner)) + + self.vocabulary_size = vocabulary_size + self.dim = dim + self.initializer = initializer + self.optimizer = optimizer + self.combiner = combiner + self.name = name + + +@tf_export("tpu.experimental.embedding.FeatureConfig") +class FeatureConfig(object): + """Configuration data for one embedding feature. + + This class holds the configuration data for a single embedding feature. The + main use is to assign features to `tf.tpu.experimental.embedding.TableConfig`s + via the table parameter: + + ```python + table_config_one = tf.tpu.experimental.embedding.TableConfig( + vocabulary_size=..., + dim=...) + table_config_two = tf.tpu.experimental.embedding.TableConfig( + vocabulary_size=..., + dim=...) + feature_config = { + 'feature_one': tf.tpu.experimental.embedding.FeatureConfig( + table=table_config_one), + 'feature_two': tf.tpu.experimental.embedding.FeatureConfig( + table=table_config_one), + 'feature_three': tf.tpu.experimental.embedding.FeatureConfig( + table=table_config_two)} + embedding = tf.tpu.experimental.embedding.TPUEmbedding( + feature_config=feature_config, + batch_size=... + optimizer=tf.tpu.experimental.embedding.Adam(0.1)) + ``` + + The above configuration has 2 tables, and three features. The first two + features will be looked up in the first table and the third feature will be + looked up in the second table. + + When feeding features into `embedding.enqueue` they can be `tf.Tensor`s, + `tf.SparseTensor`s or `tf.RaggedTensor`s. When the argument + `max_sequence_length` is 0, the default, you should expect a output of + `embedding.dequeue` for this feature of shape `(batch_size, dim)`. If + `max_sequence_length` is greater than 0, the feature is embedded as a sequence + and padded up to the given length. The shape of the output for this feature + will be `(batch_size, max_sequence_length, dim)`. + """ + + def __init__(self, table, max_sequence_length=0, name=None): + """Feature configuration. + + Args: + table: An instance of `tf.tpu.experimental.embedding.TableConfig`, + describing the table in which this feature should be looked up. + max_sequence_length: If positive, the feature is a sequence feature with + the corresponding maximum sequence length. If the sequence is longer + than this, it will be truncated. If 0, the feature is not a sequence + feature. + name: An optional name for the feature, useful for debugging. + + Returns: + `FeatureConfig`. + + Raises: + ValueError: if `table` is not an instance of + `tf.tpu.experimental.embedding.TableConfig`. + ValueError: if `max_sequence_length` not an integer or is negative. + """ + if not isinstance(table, TableConfig): + raise ValueError("table is type {}, expected " + "`tf.tpu.experimental.embedding.TableConfig`".format( + type(table))) + + if not isinstance(max_sequence_length, int) or max_sequence_length < 0: + raise ValueError("Invalid max_sequence_length {}.".format( + max_sequence_length)) + + self.table = table + self.max_sequence_length = max_sequence_length + self.name = name diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 9732ea04f26..1fe8a8c729b 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -768,7 +768,7 @@ class Optimizer( # pylint: enable=protected-access mirrored_slot = named_slots.get(key, None) if mirrored_slot is None: return None - return mirrored_slot._get_closest() # pylint: disable=protected-access + return mirrored_slot._get_on_device_or_primary() # pylint: disable=protected-access return named_slots.get(_var_key(var), None) diff --git a/tensorflow/python/util/dispatch.py b/tensorflow/python/util/dispatch.py index 3868da14b44..51dfe3793ae 100644 --- a/tensorflow/python/util/dispatch.py +++ b/tensorflow/python/util/dispatch.py @@ -99,7 +99,7 @@ class GlobalOpDispatcher(object): _GLOBAL_DISPATCHERS.append(self) -def dispatch(op, *args, **kwargs): +def dispatch(op, args, kwargs): """Returns the result from the first successful dispatcher for a given op. Calls the `handle` method of each `OpDispatcher` that has been registered @@ -107,8 +107,8 @@ def dispatch(op, *args, **kwargs): Args: op: Python function: the operation to dispatch for. - *args: The arguments to the operation. - **kwargs: They keyword arguments to the operation. + args: The arguments to the operation. + kwargs: They keyword arguments to the operation. Returns: The result of the operation, or `NOT_SUPPORTED` if no registered @@ -202,7 +202,7 @@ def add_dispatch_support(target): except (TypeError, ValueError): # Note: convert_to_eager_tensor currently raises a ValueError, not a # TypeError, when given unexpected types. So we need to catch both. - result = dispatch(wrapper, *args, **kwargs) + result = dispatch(wrapper, args, kwargs) if result is not OpDispatcher.NOT_SUPPORTED: return result else: diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 695cc4cc909..b4736bee142 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -215,7 +215,15 @@ def _yield_sorted_items(iterable): Yields: The iterable's (key, value) pairs, in order of sorted keys. """ - if isinstance(iterable, _collections_abc.Mapping): + # Ordered to check common structure types (list, tuple, dict) first. + if isinstance(iterable, list): + for item in enumerate(iterable): + yield item + # namedtuples handled separately to avoid expensive namedtuple check. + elif type(iterable) == tuple: # pylint: disable=unidiomatic-typecheck + for item in enumerate(iterable): + yield item + elif isinstance(iterable, (dict, _collections_abc.Mapping)): # Iterate through dictionaries in a deterministic order by sorting the # keys. Notice this means that we ignore the original order of `OrderedDict` # instances. This is intentional, to avoid potential bugs caused by mixing diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py index 0c11b08131c..9ba4b7520e5 100644 --- a/tensorflow/python/util/tf_should_use.py +++ b/tensorflow/python/util/tf_should_use.py @@ -19,6 +19,7 @@ from __future__ import print_function import copy import sys +import textwrap import traceback import six # pylint: disable=unused-import @@ -231,20 +232,27 @@ def should_use_result(fn=None, warn_in_eager=False, error_in_function=False): The wrapped function. """ def decorated(fn): + """Decorates the input function.""" def wrapped(*args, **kwargs): return _add_should_use_warning(fn(*args, **kwargs), warn_in_eager=warn_in_eager, error_in_function=error_in_function) + fn_doc = fn.__doc__ or '' + split_doc = fn_doc.split('\n', 1) + if len(split_doc) == 1: + updated_doc = fn_doc + else: + brief, rest = split_doc + updated_doc = '\n'.join([brief, textwrap.dedent(rest)]) + + note = ('\n\nNote: The output of this function should be used. If it is ' + 'not, a warning will be logged or an error may be raised. ' + 'To mark the output as used, call its .mark_used() method.') return tf_decorator.make_decorator( target=fn, decorator_func=wrapped, decorator_name='should_use_result', - decorator_doc=( - (fn.__doc__ or '') + - ('\n\n ' - '**NOTE** The output of this function should be used. If it is ' - 'not, a warning will be logged or an error may be raised. ' - 'To mark the output as used, call its .mark_used() method.'))) + decorator_doc=updated_doc + note) if fn is not None: return decorated(fn) diff --git a/tensorflow/stream_executor/cuda/cublas_10_2.inc b/tensorflow/stream_executor/cuda/cublas_10_2.inc index 42c4e5fef3b..067ba675288 100644 --- a/tensorflow/stream_executor/cuda/cublas_10_2.inc +++ b/tensorflow/stream_executor/cuda/cublas_10_2.inc @@ -2,29 +2,31 @@ extern "C" { -cublasStatus_t CUBLASWINAPI cublasCreate_v2 (cublasHandle_t *handle) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t *); +cublasStatus_t CUBLASWINAPI cublasCreate_v2(cublasHandle_t *handle) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCreate_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cublasStatus_t CUBLASWINAPI cublasDestroy_v2 (cublasHandle_t handle) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t); +cublasStatus_t CUBLASWINAPI cublasDestroy_v2(cublasHandle_t handle) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDestroy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cublasStatus_t CUBLASWINAPI cublasGetVersion_v2(cublasHandle_t handle, int *version) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int *); +cublasStatus_t CUBLASWINAPI cublasGetVersion_v2(cublasHandle_t handle, + int *version) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetVersion_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, version); } -cublasStatus_t CUBLASWINAPI cublasGetProperty(libraryPropertyType type, int *value) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(libraryPropertyType, int *); +cublasStatus_t CUBLASWINAPI cublasGetProperty(libraryPropertyType type, + int *value) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(libraryPropertyType, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetProperty"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(type, value); @@ -37,57 +39,71 @@ size_t CUBLASWINAPI cublasGetCudartVersion(void) { return func_ptr(); } -cublasStatus_t CUBLASWINAPI cublasSetStream_v2 (cublasHandle_t handle, cudaStream_t streamId) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cudaStream_t); +cublasStatus_t CUBLASWINAPI cublasSetStream_v2(cublasHandle_t handle, + cudaStream_t streamId) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetStream_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); } -cublasStatus_t CUBLASWINAPI cublasGetStream_v2 (cublasHandle_t handle, cudaStream_t *streamId) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cudaStream_t *); +cublasStatus_t CUBLASWINAPI cublasGetStream_v2(cublasHandle_t handle, + cudaStream_t *streamId) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cudaStream_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetStream_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); } -cublasStatus_t CUBLASWINAPI cublasGetPointerMode_v2 (cublasHandle_t handle, cublasPointerMode_t *mode) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasPointerMode_t *); +cublasStatus_t CUBLASWINAPI cublasGetPointerMode_v2(cublasHandle_t handle, + cublasPointerMode_t *mode) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasPointerMode_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetPointerMode_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode); } -cublasStatus_t CUBLASWINAPI cublasSetPointerMode_v2 (cublasHandle_t handle, cublasPointerMode_t mode) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasPointerMode_t); +cublasStatus_t CUBLASWINAPI cublasSetPointerMode_v2(cublasHandle_t handle, + cublasPointerMode_t mode) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasPointerMode_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetPointerMode_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode); } -cublasStatus_t CUBLASWINAPI cublasGetAtomicsMode(cublasHandle_t handle, cublasAtomicsMode_t *mode) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasAtomicsMode_t *); +cublasStatus_t CUBLASWINAPI cublasGetAtomicsMode(cublasHandle_t handle, + cublasAtomicsMode_t *mode) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasAtomicsMode_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetAtomicsMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode); } -cublasStatus_t CUBLASWINAPI cublasSetAtomicsMode(cublasHandle_t handle, cublasAtomicsMode_t mode) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasAtomicsMode_t); +cublasStatus_t CUBLASWINAPI cublasSetAtomicsMode(cublasHandle_t handle, + cublasAtomicsMode_t mode) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasAtomicsMode_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetAtomicsMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode); } -cublasStatus_t CUBLASWINAPI cublasGetMathMode(cublasHandle_t handle, cublasMath_t *mode) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasMath_t *); +cublasStatus_t CUBLASWINAPI cublasGetMathMode(cublasHandle_t handle, + cublasMath_t *mode) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasMath_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetMathMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode); } -cublasStatus_t CUBLASWINAPI cublasSetMathMode(cublasHandle_t handle, cublasMath_t mode) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasMath_t); +cublasStatus_t CUBLASWINAPI cublasSetMathMode(cublasHandle_t handle, + cublasMath_t mode) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasMath_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetMathMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode); @@ -118,399 +134,384 @@ cublasGetLoggerCallback(cublasLogCallback *userCallback) { return func_ptr(userCallback); } -cublasStatus_t CUBLASWINAPI cublasSetVector (int n, int elemSize, const void *x, - int incx, void *devicePtr, int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, const void *, int, void *, int); +cublasStatus_t CUBLASWINAPI cublasSetVector(int n, int elemSize, const void *x, + int incx, void *devicePtr, + int incy) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(int, int, const void *, int, void *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetVector"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(n, elemSize, x, incx, devicePtr, incy); } -cublasStatus_t CUBLASWINAPI cublasGetVector (int n, int elemSize, const void *x, - int incx, void *y, int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, const void *, int, void *, int); +cublasStatus_t CUBLASWINAPI cublasGetVector(int n, int elemSize, const void *x, + int incx, void *y, int incy) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(int, int, const void *, int, void *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetVector"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(n, elemSize, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasSetMatrix (int rows, int cols, int elemSize, - const void *A, int lda, void *B, - int ldb) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, int, const void *, int, void *, int); +cublasStatus_t CUBLASWINAPI cublasSetMatrix(int rows, int cols, int elemSize, + const void *A, int lda, void *B, + int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, int, const void *, + int, void *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetMatrix"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rows, cols, elemSize, A, lda, B, ldb); } -cublasStatus_t CUBLASWINAPI cublasGetMatrix (int rows, int cols, int elemSize, - const void *A, int lda, void *B, - int ldb) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, int, const void *, int, void *, int); +cublasStatus_t CUBLASWINAPI cublasGetMatrix(int rows, int cols, int elemSize, + const void *A, int lda, void *B, + int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, int, const void *, + int, void *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetMatrix"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rows, cols, elemSize, A, lda, B, ldb); } -cublasStatus_t CUBLASWINAPI cublasSetVectorAsync (int n, int elemSize, - const void *hostPtr, int incx, - void *devicePtr, int incy, - cudaStream_t stream) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, const void *, int, void *, int, cudaStream_t); +cublasStatus_t CUBLASWINAPI cublasSetVectorAsync(int n, int elemSize, + const void *hostPtr, int incx, + void *devicePtr, int incy, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, const void *, int, + void *, int, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetVectorAsync"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(n, elemSize, hostPtr, incx, devicePtr, incy, stream); } -cublasStatus_t CUBLASWINAPI cublasGetVectorAsync (int n, int elemSize, - const void *devicePtr, int incx, - void *hostPtr, int incy, - cudaStream_t stream) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, const void *, int, void *, int, cudaStream_t); +cublasStatus_t CUBLASWINAPI cublasGetVectorAsync(int n, int elemSize, + const void *devicePtr, + int incx, void *hostPtr, + int incy, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, const void *, int, + void *, int, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetVectorAsync"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(n, elemSize, devicePtr, incx, hostPtr, incy, stream); } -cublasStatus_t CUBLASWINAPI cublasSetMatrixAsync (int rows, int cols, int elemSize, - const void *A, int lda, void *B, - int ldb, cudaStream_t stream) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, int, const void *, int, void *, int, cudaStream_t); +cublasStatus_t CUBLASWINAPI cublasSetMatrixAsync(int rows, int cols, + int elemSize, const void *A, + int lda, void *B, int ldb, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + int, int, int, const void *, int, void *, int, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetMatrixAsync"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rows, cols, elemSize, A, lda, B, ldb, stream); } -cublasStatus_t CUBLASWINAPI cublasGetMatrixAsync (int rows, int cols, int elemSize, - const void *A, int lda, void *B, - int ldb, cudaStream_t stream) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, int, const void *, int, void *, int, cudaStream_t); +cublasStatus_t CUBLASWINAPI cublasGetMatrixAsync(int rows, int cols, + int elemSize, const void *A, + int lda, void *B, int ldb, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + int, int, int, const void *, int, void *, int, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetMatrixAsync"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rows, cols, elemSize, A, lda, B, ldb, stream); } -void CUBLASWINAPI cublasXerbla (const char *srName, int info) { - using FuncPtr = void (CUBLASWINAPI *)(const char *, int); +void CUBLASWINAPI cublasXerbla(const char *srName, int info) { + using FuncPtr = void(CUBLASWINAPI *)(const char *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasXerbla"); if (!func_ptr) LogFatalSymbolNotFound("cublasXerbla"); return func_ptr(srName, info); } -cublasStatus_t CUBLASWINAPI cublasNrm2Ex(cublasHandle_t handle, - int n, - const void *x, - cudaDataType xType, - int incx, - void *result, - cudaDataType resultType, - cudaDataType executionType) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const void *, cudaDataType, int, void *, cudaDataType, cudaDataType); +cublasStatus_t CUBLASWINAPI cublasNrm2Ex(cublasHandle_t handle, int n, + const void *x, cudaDataType xType, + int incx, void *result, + cudaDataType resultType, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, void *, + cudaDataType, cudaDataType); static auto func_ptr = LoadSymbol<FuncPtr>("cublasNrm2Ex"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, xType, incx, result, resultType, executionType); } -cublasStatus_t CUBLASWINAPI cublasSnrm2_v2(cublasHandle_t handle, - int n, - const float *x, - int incx, - float *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, float *); +cublasStatus_t CUBLASWINAPI cublasSnrm2_v2(cublasHandle_t handle, int n, + const float *x, int incx, + float *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const float *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSnrm2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasDnrm2_v2(cublasHandle_t handle, - int n, - const double *x, - int incx, - double *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, double *); +cublasStatus_t CUBLASWINAPI cublasDnrm2_v2(cublasHandle_t handle, int n, + const double *x, int incx, + double *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const double *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDnrm2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasScnrm2_v2(cublasHandle_t handle, - int n, - const cuComplex *x, - int incx, - float *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, float *); +cublasStatus_t CUBLASWINAPI cublasScnrm2_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + float *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasScnrm2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasDznrm2_v2(cublasHandle_t handle, - int n, - const cuDoubleComplex *x, - int incx, - double *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, double *); +cublasStatus_t CUBLASWINAPI cublasDznrm2_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + double *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDznrm2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasDotEx (cublasHandle_t handle, - int n, - const void *x, - cudaDataType xType, - int incx, - const void *y, - cudaDataType yType, - int incy, - void *result, - cudaDataType resultType, - cudaDataType executionType) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const void *, cudaDataType, int, const void *, cudaDataType, int, void *, cudaDataType, cudaDataType); +cublasStatus_t CUBLASWINAPI cublasDotEx(cublasHandle_t handle, int n, + const void *x, cudaDataType xType, + int incx, const void *y, + cudaDataType yType, int incy, + void *result, cudaDataType resultType, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, const void *, + cudaDataType, int, void *, cudaDataType, cudaDataType); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDotEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, n, x, xType, incx, y, yType, incy, result, resultType, executionType); + return func_ptr(handle, n, x, xType, incx, y, yType, incy, result, resultType, + executionType); } -cublasStatus_t CUBLASWINAPI cublasDotcEx (cublasHandle_t handle, - int n, - const void *x, - cudaDataType xType, - int incx, - const void *y, - cudaDataType yType, - int incy, - void *result, - cudaDataType resultType, - cudaDataType executionType) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const void *, cudaDataType, int, const void *, cudaDataType, int, void *, cudaDataType, cudaDataType); +cublasStatus_t CUBLASWINAPI cublasDotcEx(cublasHandle_t handle, int n, + const void *x, cudaDataType xType, + int incx, const void *y, + cudaDataType yType, int incy, + void *result, cudaDataType resultType, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, const void *, + cudaDataType, int, void *, cudaDataType, cudaDataType); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDotcEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, n, x, xType, incx, y, yType, incy, result, resultType, executionType); + return func_ptr(handle, n, x, xType, incx, y, yType, incy, result, resultType, + executionType); } -cublasStatus_t CUBLASWINAPI cublasSdot_v2 (cublasHandle_t handle, - int n, - const float *x, - int incx, - const float *y, - int incy, - float *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, const float *, int, float *); +cublasStatus_t CUBLASWINAPI cublasSdot_v2(cublasHandle_t handle, int n, + const float *x, int incx, + const float *y, int incy, + float *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const float *, int, const float *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSdot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, result); } -cublasStatus_t CUBLASWINAPI cublasDdot_v2 (cublasHandle_t handle, - int n, - const double *x, - int incx, - const double *y, - int incy, - double *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, const double *, int, double *); +cublasStatus_t CUBLASWINAPI cublasDdot_v2(cublasHandle_t handle, int n, + const double *x, int incx, + const double *y, int incy, + double *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const double *, int, const double *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDdot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, result); } -cublasStatus_t CUBLASWINAPI cublasCdotu_v2 (cublasHandle_t handle, - int n, - const cuComplex *x, - int incx, - const cuComplex *y, - int incy, - cuComplex *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, const cuComplex *, int, cuComplex *); +cublasStatus_t CUBLASWINAPI cublasCdotu_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + const cuComplex *y, int incy, + cuComplex *result) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, + int, const cuComplex *, int, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCdotu_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, result); } -cublasStatus_t CUBLASWINAPI cublasCdotc_v2 (cublasHandle_t handle, - int n, - const cuComplex *x, - int incx, - const cuComplex *y, - int incy, - cuComplex *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, const cuComplex *, int, cuComplex *); +cublasStatus_t CUBLASWINAPI cublasCdotc_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + const cuComplex *y, int incy, + cuComplex *result) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, + int, const cuComplex *, int, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCdotc_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, result); } -cublasStatus_t CUBLASWINAPI cublasZdotu_v2 (cublasHandle_t handle, - int n, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *y, - int incy, - cuDoubleComplex *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *); +cublasStatus_t CUBLASWINAPI cublasZdotu_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdotu_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, result); } -cublasStatus_t CUBLASWINAPI cublasZdotc_v2 (cublasHandle_t handle, - int n, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *y, - int incy, - cuDoubleComplex *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *); +cublasStatus_t CUBLASWINAPI cublasZdotc_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdotc_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, result); } -cublasStatus_t CUBLASWINAPI cublasScalEx(cublasHandle_t handle, - int n, - const void *alpha, /* host or device pointer */ - cudaDataType alphaType, - void *x, - cudaDataType xType, - int incx, - cudaDataType executionType) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const void *, cudaDataType, void *, cudaDataType, int, cudaDataType); +cublasStatus_t CUBLASWINAPI +cublasScalEx(cublasHandle_t handle, int n, + const void *alpha, /* host or device pointer */ + cudaDataType alphaType, void *x, cudaDataType xType, int incx, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, void *, cudaDataType, + int, cudaDataType); static auto func_ptr = LoadSymbol<FuncPtr>("cublasScalEx"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, alphaType, x, xType, incx, executionType); } -cublasStatus_t CUBLASWINAPI cublasSscal_v2(cublasHandle_t handle, - int n, - const float *alpha, /* host or device pointer */ - float *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSscal_v2(cublasHandle_t handle, int n, + const float *alpha, /* host or device pointer */ + float *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSscal_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx); } -cublasStatus_t CUBLASWINAPI cublasDscal_v2(cublasHandle_t handle, - int n, - const double *alpha, /* host or device pointer */ - double *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDscal_v2(cublasHandle_t handle, int n, + const double *alpha, /* host or device pointer */ + double *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDscal_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCscal_v2(cublasHandle_t handle, - int n, - const cuComplex *alpha, /* host or device pointer */ - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCscal_v2(cublasHandle_t handle, int n, + const cuComplex *alpha, /* host or device pointer */ + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCscal_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCsscal_v2(cublasHandle_t handle, - int n, - const float *alpha, /* host or device pointer */ - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCsscal_v2(cublasHandle_t handle, int n, + const float *alpha, /* host or device pointer */ + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const float *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsscal_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZscal_v2(cublasHandle_t handle, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZscal_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZscal_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZdscal_v2(cublasHandle_t handle, - int n, - const double *alpha, /* host or device pointer */ - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZdscal_v2(cublasHandle_t handle, int n, + const double *alpha, /* host or device pointer */ + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const double *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdscal_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx); } -cublasStatus_t CUBLASWINAPI cublasAxpyEx (cublasHandle_t handle, - int n, - const void *alpha, /* host or device pointer */ - cudaDataType alphaType, - const void *x, - cudaDataType xType, - int incx, - void *y, - cudaDataType yType, - int incy, - cudaDataType executiontype) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const void *, cudaDataType, const void *, cudaDataType, int, void *, cudaDataType, int, cudaDataType); +cublasStatus_t CUBLASWINAPI cublasAxpyEx( + cublasHandle_t handle, int n, + const void *alpha, /* host or device pointer */ + cudaDataType alphaType, const void *x, cudaDataType xType, int incx, + void *y, cudaDataType yType, int incy, cudaDataType executiontype) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, const void *, + cudaDataType, int, void *, cudaDataType, int, cudaDataType); static auto func_ptr = LoadSymbol<FuncPtr>("cublasAxpyEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, n, alpha, alphaType, x, xType, incx, y, yType, incy, executiontype); + return func_ptr(handle, n, alpha, alphaType, x, xType, incx, y, yType, incy, + executiontype); } -cublasStatus_t CUBLASWINAPI cublasSaxpy_v2 (cublasHandle_t handle, - int n, - const float *alpha, /* host or device pointer */ - const float *x, - int incx, - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSaxpy_v2(cublasHandle_t handle, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const float *, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSaxpy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasDaxpy_v2 (cublasHandle_t handle, - int n, - const double *alpha, /* host or device pointer */ - const double *x, - int incx, - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDaxpy_v2(cublasHandle_t handle, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const double *, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDaxpy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasCaxpy_v2 (cublasHandle_t handle, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCaxpy_v2(cublasHandle_t handle, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, cuComplex *y, int incy) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCaxpy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZaxpy_v2 (cublasHandle_t handle, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZaxpy_v2( + cublasHandle_t handle, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, const cuDoubleComplex *, + int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZaxpy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx, y, incy); @@ -528,97 +529,82 @@ cublasStatus_t CUBLASWINAPI cublasCopyEx(cublasHandle_t handle, int n, return func_ptr(handle, n, x, xType, incx, y, yType, incy); } -cublasStatus_t CUBLASWINAPI cublasScopy_v2 (cublasHandle_t handle, - int n, - const float *x, - int incx, - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasScopy_v2(cublasHandle_t handle, int n, + const float *x, int incx, float *y, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasScopy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasDcopy_v2 (cublasHandle_t handle, - int n, - const double *x, - int incx, - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDcopy_v2(cublasHandle_t handle, int n, + const double *x, int incx, double *y, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDcopy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasCcopy_v2 (cublasHandle_t handle, - int n, - const cuComplex *x, - int incx, - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCcopy_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCcopy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZcopy_v2 (cublasHandle_t handle, - int n, - const cuDoubleComplex *x, - int incx, - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZcopy_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZcopy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasSswap_v2 (cublasHandle_t handle, - int n, - float *x, - int incx, - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasSswap_v2(cublasHandle_t handle, int n, + float *x, int incx, float *y, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, float *, + int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSswap_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasDswap_v2 (cublasHandle_t handle, - int n, - double *x, - int incx, - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDswap_v2(cublasHandle_t handle, int n, + double *x, int incx, double *y, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, double *, + int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDswap_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasCswap_v2 (cublasHandle_t handle, - int n, - cuComplex *x, - int incx, - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCswap_v2(cublasHandle_t handle, int n, + cuComplex *x, int incx, cuComplex *y, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCswap_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZswap_v2 (cublasHandle_t handle, - int n, - cuDoubleComplex *x, - int incx, - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZswap_v2(cublasHandle_t handle, int n, + cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZswap_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); @@ -635,45 +621,41 @@ cublasStatus_t CUBLASWINAPI cublasSwapEx(cublasHandle_t handle, int n, void *x, return func_ptr(handle, n, x, xType, incx, y, yType, incy); } -cublasStatus_t CUBLASWINAPI cublasIsamax_v2(cublasHandle_t handle, - int n, - const float *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIsamax_v2(cublasHandle_t handle, int n, + const float *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const float *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIsamax_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasIdamax_v2(cublasHandle_t handle, - int n, - const double *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIdamax_v2(cublasHandle_t handle, int n, + const double *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const double *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIdamax_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasIcamax_v2(cublasHandle_t handle, - int n, - const cuComplex *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIcamax_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const cuComplex *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIcamax_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasIzamax_v2(cublasHandle_t handle, - int n, - const cuDoubleComplex *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIzamax_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIzamax_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); @@ -690,45 +672,41 @@ cublasStatus_t CUBLASWINAPI cublasIamaxEx( return func_ptr(handle, n, x, xType, incx, result); } -cublasStatus_t CUBLASWINAPI cublasIsamin_v2(cublasHandle_t handle, - int n, - const float *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIsamin_v2(cublasHandle_t handle, int n, + const float *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const float *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIsamin_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasIdamin_v2(cublasHandle_t handle, - int n, - const double *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIdamin_v2(cublasHandle_t handle, int n, + const double *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const double *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIdamin_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasIcamin_v2(cublasHandle_t handle, - int n, - const cuComplex *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIcamin_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const cuComplex *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIcamin_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasIzamin_v2(cublasHandle_t handle, - int n, - const cuDoubleComplex *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIzamin_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIzamin_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); @@ -757,129 +735,113 @@ cublasStatus_t CUBLASWINAPI cublasAsumEx( return func_ptr(handle, n, x, xType, incx, result, resultType, executiontype); } -cublasStatus_t CUBLASWINAPI cublasSasum_v2(cublasHandle_t handle, - int n, - const float *x, - int incx, - float *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, float *); +cublasStatus_t CUBLASWINAPI cublasSasum_v2(cublasHandle_t handle, int n, + const float *x, int incx, + float *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const float *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSasum_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasDasum_v2(cublasHandle_t handle, - int n, - const double *x, - int incx, - double *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, double *); +cublasStatus_t CUBLASWINAPI cublasDasum_v2(cublasHandle_t handle, int n, + const double *x, int incx, + double *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const double *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDasum_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasScasum_v2(cublasHandle_t handle, - int n, - const cuComplex *x, - int incx, - float *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, float *); +cublasStatus_t CUBLASWINAPI cublasScasum_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + float *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasScasum_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasDzasum_v2(cublasHandle_t handle, - int n, - const cuDoubleComplex *x, - int incx, - double *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, double *); +cublasStatus_t CUBLASWINAPI cublasDzasum_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + double *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDzasum_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasSrot_v2 (cublasHandle_t handle, - int n, - float *x, - int incx, - float *y, - int incy, - const float *c, /* host or device pointer */ - const float *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, float *, int, float *, int, const float *, const float *); +cublasStatus_t CUBLASWINAPI +cublasSrot_v2(cublasHandle_t handle, int n, float *x, int incx, float *y, + int incy, const float *c, /* host or device pointer */ + const float *s) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, float *, int, float *, + int, const float *, const float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, c, s); } -cublasStatus_t CUBLASWINAPI cublasDrot_v2 (cublasHandle_t handle, - int n, - double *x, - int incx, - double *y, - int incy, - const double *c, /* host or device pointer */ - const double *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, double *, int, double *, int, const double *, const double *); +cublasStatus_t CUBLASWINAPI +cublasDrot_v2(cublasHandle_t handle, int n, double *x, int incx, double *y, + int incy, const double *c, /* host or device pointer */ + const double *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, double *, int, double *, int, const double *, + const double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, c, s); } -cublasStatus_t CUBLASWINAPI cublasCrot_v2 (cublasHandle_t handle, - int n, - cuComplex *x, - int incx, - cuComplex *y, - int incy, - const float *c, /* host or device pointer */ - const cuComplex *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuComplex *, int, cuComplex *, int, const float *, const cuComplex *); +cublasStatus_t CUBLASWINAPI cublasCrot_v2( + cublasHandle_t handle, int n, cuComplex *x, int incx, cuComplex *y, + int incy, const float *c, /* host or device pointer */ + const cuComplex *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuComplex *, int, cuComplex *, int, const float *, + const cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCrot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, c, s); } -cublasStatus_t CUBLASWINAPI cublasCsrot_v2(cublasHandle_t handle, - int n, - cuComplex *x, - int incx, - cuComplex *y, - int incy, - const float *c, /* host or device pointer */ - const float *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuComplex *, int, cuComplex *, int, const float *, const float *); +cublasStatus_t CUBLASWINAPI cublasCsrot_v2( + cublasHandle_t handle, int n, cuComplex *x, int incx, cuComplex *y, + int incy, const float *c, /* host or device pointer */ + const float *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuComplex *, int, cuComplex *, int, const float *, + const float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsrot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, c, s); } -cublasStatus_t CUBLASWINAPI cublasZrot_v2 (cublasHandle_t handle, - int n, - cuDoubleComplex *x, - int incx, - cuDoubleComplex *y, - int incy, - const double *c, /* host or device pointer */ - const cuDoubleComplex *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, const double *, const cuDoubleComplex *); +cublasStatus_t CUBLASWINAPI cublasZrot_v2( + cublasHandle_t handle, int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy, const double *c, /* host or device pointer */ + const cuDoubleComplex *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, + const double *, const cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZrot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, c, s); } -cublasStatus_t CUBLASWINAPI cublasZdrot_v2(cublasHandle_t handle, - int n, - cuDoubleComplex *x, - int incx, - cuDoubleComplex *y, - int incy, - const double *c, /* host or device pointer */ - const double *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, const double *, const double *); +cublasStatus_t CUBLASWINAPI cublasZdrot_v2( + cublasHandle_t handle, int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy, const double *c, /* host or device pointer */ + const double *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, + const double *, const double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdrot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, c, s); @@ -899,45 +861,50 @@ cublasRotEx(cublasHandle_t handle, int n, void *x, cudaDataType xType, int incx, executiontype); } -cublasStatus_t CUBLASWINAPI cublasSrotg_v2(cublasHandle_t handle, - float *a, /* host or device pointer */ - float *b, /* host or device pointer */ - float *c, /* host or device pointer */ - float *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, float *, float *, float *, float *); +cublasStatus_t CUBLASWINAPI +cublasSrotg_v2(cublasHandle_t handle, float *a, /* host or device pointer */ + float *b, /* host or device pointer */ + float *c, /* host or device pointer */ + float *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, float *, + float *, float *, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotg_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, a, b, c, s); } -cublasStatus_t CUBLASWINAPI cublasDrotg_v2(cublasHandle_t handle, - double *a, /* host or device pointer */ - double *b, /* host or device pointer */ - double *c, /* host or device pointer */ - double *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, double *, double *, double *, double *); +cublasStatus_t CUBLASWINAPI +cublasDrotg_v2(cublasHandle_t handle, double *a, /* host or device pointer */ + double *b, /* host or device pointer */ + double *c, /* host or device pointer */ + double *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, double *, + double *, double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotg_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, a, b, c, s); } -cublasStatus_t CUBLASWINAPI cublasCrotg_v2(cublasHandle_t handle, - cuComplex *a, /* host or device pointer */ - cuComplex *b, /* host or device pointer */ - float *c, /* host or device pointer */ - cuComplex *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cuComplex *, cuComplex *, float *, cuComplex *); +cublasStatus_t CUBLASWINAPI +cublasCrotg_v2(cublasHandle_t handle, cuComplex *a, /* host or device pointer */ + cuComplex *b, /* host or device pointer */ + float *c, /* host or device pointer */ + cuComplex *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cuComplex *, cuComplex *, float *, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCrotg_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, a, b, c, s); } -cublasStatus_t CUBLASWINAPI cublasZrotg_v2(cublasHandle_t handle, - cuDoubleComplex *a, /* host or device pointer */ - cuDoubleComplex *b, /* host or device pointer */ - double *c, /* host or device pointer */ - cuDoubleComplex *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cuDoubleComplex *, cuDoubleComplex *, double *, cuDoubleComplex *); +cublasStatus_t CUBLASWINAPI cublasZrotg_v2( + cublasHandle_t handle, cuDoubleComplex *a, /* host or device pointer */ + cuDoubleComplex *b, /* host or device pointer */ + double *c, /* host or device pointer */ + cuDoubleComplex *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cuDoubleComplex *, cuDoubleComplex *, double *, + cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZrotg_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, a, b, c, s); @@ -959,27 +926,21 @@ cublasStatus_t CUBLASWINAPI cublasRotgEx(cublasHandle_t handle, return func_ptr(handle, a, b, abType, c, s, csType, executiontype); } -cublasStatus_t CUBLASWINAPI cublasSrotm_v2(cublasHandle_t handle, - int n, - float *x, - int incx, - float *y, - int incy, - const float* param) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, float *, int, float *, int, const float *); +cublasStatus_t CUBLASWINAPI cublasSrotm_v2(cublasHandle_t handle, int n, + float *x, int incx, float *y, + int incy, const float *param) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, float *, int, float *, int, const float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, param); } -cublasStatus_t CUBLASWINAPI cublasDrotm_v2(cublasHandle_t handle, - int n, - double *x, - int incx, - double *y, - int incy, - const double* param) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, double *, int, double *, int, const double *); +cublasStatus_t CUBLASWINAPI cublasDrotm_v2(cublasHandle_t handle, int n, + double *x, int incx, double *y, + int incy, const double *param) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, double *, int, double *, int, const double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, param); @@ -999,25 +960,27 @@ cublasRotmEx(cublasHandle_t handle, int n, void *x, cudaDataType xType, executiontype); } -cublasStatus_t CUBLASWINAPI cublasSrotmg_v2(cublasHandle_t handle, - float *d1, /* host or device pointer */ - float *d2, /* host or device pointer */ - float *x1, /* host or device pointer */ - const float *y1, /* host or device pointer */ - float *param) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, float *, float *, float *, const float *, float *); +cublasStatus_t CUBLASWINAPI +cublasSrotmg_v2(cublasHandle_t handle, float *d1, /* host or device pointer */ + float *d2, /* host or device pointer */ + float *x1, /* host or device pointer */ + const float *y1, /* host or device pointer */ + float *param) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, float *, float *, float *, const float *, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotmg_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, d1, d2, x1, y1, param); } -cublasStatus_t CUBLASWINAPI cublasDrotmg_v2(cublasHandle_t handle, - double *d1, /* host or device pointer */ - double *d2, /* host or device pointer */ - double *x1, /* host or device pointer */ - const double *y1, /* host or device pointer */ - double *param) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, double *, double *, double *, const double *, double *); +cublasStatus_t CUBLASWINAPI +cublasDrotmg_v2(cublasHandle_t handle, double *d1, /* host or device pointer */ + double *d2, /* host or device pointer */ + double *x1, /* host or device pointer */ + const double *y1, /* host or device pointer */ + double *param) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, double *, double *, double *, const double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotmg_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, d1, d2, x1, y1, param); @@ -1040,2031 +1003,1701 @@ cublasRotmgEx(cublasHandle_t handle, void *d1, /* host or device pointer */ paramType, executiontype); } -cublasStatus_t CUBLASWINAPI cublasSgemv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *x, - int incx, - const float *beta, /* host or device pointer */ - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSgemv_v2(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *x, int incx, + const float *beta, /* host or device pointer */ + float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const float *, const float *, + int, const float *, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasDgemv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *x, - int incx, - const double *beta, /* host or device pointer */ - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDgemv_v2(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *x, int incx, + const double *beta, /* host or device pointer */ + double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const double *, + const double *, int, const double *, int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasCgemv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *x, - int incx, - const cuComplex *beta, /* host or device pointer */ - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCgemv_v2(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, const cuComplex *, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZgemv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZgemv_v2( + cublasHandle_t handle, cublasOperation_t trans, int m, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasSgbmv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - int kl, - int ku, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *x, - int incx, - const float *beta, /* host or device pointer */ - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSgbmv_v2(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int kl, int ku, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *x, int incx, + const float *beta, /* host or device pointer */ + float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, int, const float *, + const float *, int, const float *, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, + incy); } -cublasStatus_t CUBLASWINAPI cublasDgbmv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - int kl, - int ku, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *x, - int incx, - const double *beta, /* host or device pointer */ - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDgbmv_v2(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int kl, int ku, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *x, int incx, + const double *beta, /* host or device pointer */ + double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, int, const double *, + const double *, int, const double *, int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, + incy); } -cublasStatus_t CUBLASWINAPI cublasCgbmv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - int kl, - int ku, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *x, - int incx, - const cuComplex *beta, /* host or device pointer */ - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCgbmv_v2( + cublasHandle_t handle, cublasOperation_t trans, int m, int n, int kl, + int ku, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, const cuComplex *, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, + incy); } -cublasStatus_t CUBLASWINAPI cublasZgbmv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - int kl, - int ku, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZgbmv_v2( + cublasHandle_t handle, cublasOperation_t trans, int m, int n, int kl, + int ku, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, + incy); } -cublasStatus_t CUBLASWINAPI cublasStrmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const float *A, - int lda, - float *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasStrmv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const float *A, int lda, float *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasDtrmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const double *A, - int lda, - double *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtrmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const double *A, int lda, double *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCtrmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuComplex *A, - int lda, - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtrmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuComplex *A, int lda, + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZtrmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuDoubleComplex *A, - int lda, - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtrmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasStbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const float *A, - int lda, - float *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasStbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const float *A, int lda, float *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasDtbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const double *A, - int lda, - double *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const double *A, int lda, double *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCtbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const cuComplex *A, - int lda, - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const cuComplex *A, int lda, + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZtbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const cuDoubleComplex *A, - int lda, - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasStpmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const float *AP, - float *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasStpmv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const float *AP, float *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStpmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasDtpmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const double *AP, - double *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtpmv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const double *AP, double *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtpmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCtpmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuComplex *AP, - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtpmv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const cuComplex *AP, cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtpmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZtpmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuDoubleComplex *AP, - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtpmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuDoubleComplex *AP, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtpmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasStrsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const float *A, - int lda, - float *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasStrsv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const float *A, int lda, float *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasDtrsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const double *A, - int lda, - double *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtrsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const double *A, int lda, double *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCtrsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuComplex *A, - int lda, - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtrsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuComplex *A, int lda, + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZtrsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuDoubleComplex *A, - int lda, - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtrsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasStpsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const float *AP, - float *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasStpsv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const float *AP, float *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStpsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasDtpsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const double *AP, - double *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtpsv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const double *AP, double *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtpsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCtpsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuComplex *AP, - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtpsv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const cuComplex *AP, cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtpsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZtpsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuDoubleComplex *AP, - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtpsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuDoubleComplex *AP, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtpsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasStbsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const float *A, - int lda, - float *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasStbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const float *A, int lda, float *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStbsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasDtbsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const double *A, - int lda, - double *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const double *A, int lda, double *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtbsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCtbsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const cuComplex *A, - int lda, - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const cuComplex *A, int lda, + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtbsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZtbsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const cuDoubleComplex *A, - int lda, - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtbsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasSsymv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *x, - int incx, - const float *beta, /* host or device pointer */ - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSsymv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *x, int incx, + const float *beta, /* host or device pointer */ + float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, + const float *, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsymv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasDsymv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *x, - int incx, - const double *beta, /* host or device pointer */ - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDsymv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *x, int incx, + const double *beta, /* host or device pointer */ + double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + int, const double *, int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsymv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasCsymv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *x, - int incx, - const cuComplex *beta, /* host or device pointer */ - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCsymv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, const cuComplex *, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsymv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZsymv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZsymv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsymv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasChemv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *x, - int incx, - const cuComplex *beta, /* host or device pointer */ - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasChemv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, const cuComplex *, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChemv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZhemv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZhemv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhemv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasSsbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - int k, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *x, - int incx, - const float *beta, /* host or device pointer */ - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSsbmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, int k, + const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *x, int incx, + const float *beta, /* host or device pointer */ + float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, int, const float *, const float *, + int, const float *, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasDsbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - int k, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *x, - int incx, - const double *beta, /* host or device pointer */ - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDsbmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, int k, + const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *x, int incx, + const double *beta, /* host or device pointer */ + double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, int, const double *, + const double *, int, const double *, int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasChbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *x, - int incx, - const cuComplex *beta, /* host or device pointer */ - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasChbmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, int k, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, const cuComplex *, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZhbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZhbmv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasSspmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const float *AP, - const float *x, - int incx, - const float *beta, /* host or device pointer */ - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSspmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *AP, const float *x, int incx, + const float *beta, /* host or device pointer */ + float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, + const float *, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasDspmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const double *AP, - const double *x, - int incx, - const double *beta, /* host or device pointer */ - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDspmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *AP, const double *x, int incx, + const double *beta, /* host or device pointer */ + double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + const double *, int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasChpmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *AP, - const cuComplex *x, - int incx, - const cuComplex *beta, /* host or device pointer */ - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasChpmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *AP, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, const cuComplex *, int, const cuComplex *, cuComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZhpmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *AP, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZhpmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *AP, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasSger_v2 (cublasHandle_t handle, - int m, - int n, - const float *alpha, /* host or device pointer */ - const float *x, - int incx, - const float *y, - int incy, - float *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const float *, const float *, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasSger_v2( + cublasHandle_t handle, int m, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, const float *y, int incy, float *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const float *, const float *, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSger_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasDger_v2 (cublasHandle_t handle, - int m, - int n, - const double *alpha, /* host or device pointer */ - const double *x, - int incx, - const double *y, - int incy, - double *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const double *, const double *, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDger_v2( + cublasHandle_t handle, int m, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, const double *y, int incy, double *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const double *, const double *, int, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDger_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasCgeru_v2 (cublasHandle_t handle, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - const cuComplex *y, - int incy, - cuComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCgeru_v2(cublasHandle_t handle, int m, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const cuComplex *, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgeru_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasCgerc_v2 (cublasHandle_t handle, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - const cuComplex *y, - int incy, - cuComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCgerc_v2(cublasHandle_t handle, int m, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const cuComplex *, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgerc_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasZgeru_v2 (cublasHandle_t handle, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *y, - int incy, - cuDoubleComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZgeru_v2(cublasHandle_t handle, int m, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgeru_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasZgerc_v2 (cublasHandle_t handle, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *y, - int incy, - cuDoubleComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZgerc_v2(cublasHandle_t handle, int m, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgerc_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasSsyr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const float *x, - int incx, - float *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSsyr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, float *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, + float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); } -cublasStatus_t CUBLASWINAPI cublasDsyr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const double *x, - int incx, - double *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDsyr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, double *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); } -cublasStatus_t CUBLASWINAPI cublasCsyr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - cuComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCsyr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); } -cublasStatus_t CUBLASWINAPI cublasZsyr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - cuDoubleComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZsyr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); } -cublasStatus_t CUBLASWINAPI cublasCher_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - cuComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCher_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const cuComplex *x, int incx, cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const cuComplex *, + int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); } -cublasStatus_t CUBLASWINAPI cublasZher_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - cuDoubleComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZher_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, + const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); } -cublasStatus_t CUBLASWINAPI cublasSspr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const float *x, - int incx, - float *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, float *); +cublasStatus_t CUBLASWINAPI +cublasSspr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, float *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, + float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, AP); } -cublasStatus_t CUBLASWINAPI cublasDspr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const double *x, - int incx, - double *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, int, double *); +cublasStatus_t CUBLASWINAPI +cublasDspr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, double *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, AP); } -cublasStatus_t CUBLASWINAPI cublasChpr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - cuComplex *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const cuComplex *, int, cuComplex *); +cublasStatus_t CUBLASWINAPI +cublasChpr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const cuComplex *x, int incx, cuComplex *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const cuComplex *, + int, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, AP); } -cublasStatus_t CUBLASWINAPI cublasZhpr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - cuDoubleComplex *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const cuDoubleComplex *, int, cuDoubleComplex *); +cublasStatus_t CUBLASWINAPI +cublasZhpr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, cuDoubleComplex *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, + const cuDoubleComplex *, int, cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, AP); } -cublasStatus_t CUBLASWINAPI cublasSsyr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const float *x, - int incx, - const float *y, - int incy, - float *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasSsyr2_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, const float *y, int incy, float *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasDsyr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const double *x, - int incx, - const double *y, - int incy, - double *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDsyr2_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, const double *y, int incy, double *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + int, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasCsyr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - const cuComplex *y, - int incy, - cuComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCsyr2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasZsyr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *y, - int incy, - cuDoubleComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZsyr2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasCher2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - const cuComplex *y, - int incy, - cuComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCher2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasZher2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *y, - int incy, - cuDoubleComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZher2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasSspr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const float *x, - int incx, - const float *y, - int incy, - float *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, const float *, int, float *); +cublasStatus_t CUBLASWINAPI +cublasSspr2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, const float *y, int incy, float *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, + const float *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); } -cublasStatus_t CUBLASWINAPI cublasDspr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const double *x, - int incx, - const double *y, - int incy, - double *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, int, const double *, int, double *); +cublasStatus_t CUBLASWINAPI cublasDspr2_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, const double *y, int incy, double *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + int, const double *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); } -cublasStatus_t CUBLASWINAPI cublasChpr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - const cuComplex *y, - int incy, - cuComplex *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *); +cublasStatus_t CUBLASWINAPI cublasChpr2_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, const cuComplex *y, int incy, cuComplex *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); } -cublasStatus_t CUBLASWINAPI cublasZhpr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *y, - int incy, - cuDoubleComplex *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *); +cublasStatus_t CUBLASWINAPI +cublasZhpr2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); } -cublasStatus_t CUBLASWINAPI cublasSgemm_v2 (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *B, - int ldb, - const float *beta, /* host or device pointer */ - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasSgemm_v2( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *B, int ldb, + const float *beta, /* host or device pointer */ + float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const float *, const float *, int, const float *, int, const float *, + float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasDgemm_v2 (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *B, - int ldb, - const double *beta, /* host or device pointer */ - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDgemm_v2( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *B, int ldb, + const double *beta, /* host or device pointer */ + double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const double *, const double *, int, const double *, int, const double *, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasCgemm_v2 (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCgemm_v2( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasCgemm3m (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCgemm3m( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm3m"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasCgemm3mEx (cublasHandle_t handle, - cublasOperation_t transa, cublasOperation_t transb, - int m, int n, int k, - const cuComplex *alpha, - const void *A, - cudaDataType Atype, - int lda, - const void *B, - cudaDataType Btype, - int ldb, - const cuComplex *beta, - void *C, - cudaDataType Ctype, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const void *, cudaDataType, int, const void *, cudaDataType, int, const cuComplex *, void *, cudaDataType, int); +cublasStatus_t CUBLASWINAPI cublasCgemm3mEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, const void *A, + cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, + const cuComplex *beta, void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const void *, cudaDataType, int, const void *, + cudaDataType, int, const cuComplex *, void *, cudaDataType, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm3mEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, + Btype, ldb, beta, C, Ctype, ldc); } -cublasStatus_t CUBLASWINAPI cublasZgemm_v2 (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZgemm_v2( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasZgemm3m (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZgemm3m(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, + int ldb, const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemm3m"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasSgemmEx (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host or device pointer */ - const void *A, - cudaDataType Atype, - int lda, - const void *B, - cudaDataType Btype, - int ldb, - const float *beta, /* host or device pointer */ - void *C, - cudaDataType Ctype, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const float *, const void *, cudaDataType, int, const void *, cudaDataType, int, const float *, void *, cudaDataType, int); +cublasStatus_t CUBLASWINAPI cublasSgemmEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const float *alpha, /* host or device pointer */ + const void *A, cudaDataType Atype, int lda, const void *B, + cudaDataType Btype, int ldb, const float *beta, /* host or device pointer */ + void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const float *, const void *, cudaDataType, int, const void *, + cudaDataType, int, const float *, void *, cudaDataType, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, + Btype, ldb, beta, C, Ctype, ldc); } -cublasStatus_t CUBLASWINAPI cublasGemmEx (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const void *alpha, /* host or device pointer */ - const void *A, - cudaDataType Atype, - int lda, - const void *B, - cudaDataType Btype, - int ldb, - const void *beta, /* host or device pointer */ - void *C, - cudaDataType Ctype, - int ldc, - cudaDataType computeType, - cublasGemmAlgo_t algo) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const void *, const void *, cudaDataType, int, const void *, cudaDataType, int, const void *, void *, cudaDataType, int, cudaDataType, cublasGemmAlgo_t); +cublasStatus_t CUBLASWINAPI cublasGemmEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const void *alpha, /* host or device pointer */ + const void *A, cudaDataType Atype, int lda, const void *B, + cudaDataType Btype, int ldb, const void *beta, /* host or device pointer */ + void *C, cudaDataType Ctype, int ldc, cudaDataType computeType, + cublasGemmAlgo_t algo) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const void *, const void *, cudaDataType, int, const void *, cudaDataType, + int, const void *, void *, cudaDataType, int, cudaDataType, + cublasGemmAlgo_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGemmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, + Btype, ldb, beta, C, Ctype, ldc, computeType, algo); } -cublasStatus_t CUBLASWINAPI cublasCgemmEx (cublasHandle_t handle, - cublasOperation_t transa, cublasOperation_t transb, - int m, int n, int k, - const cuComplex *alpha, - const void *A, - cudaDataType Atype, - int lda, - const void *B, - cudaDataType Btype, - int ldb, - const cuComplex *beta, - void *C, - cudaDataType Ctype, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const void *, cudaDataType, int, const void *, cudaDataType, int, const cuComplex *, void *, cudaDataType, int); +cublasStatus_t CUBLASWINAPI cublasCgemmEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, const void *A, + cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, + const cuComplex *beta, void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const void *, cudaDataType, int, const void *, + cudaDataType, int, const cuComplex *, void *, cudaDataType, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, + Btype, ldb, beta, C, Ctype, ldc); } -cublasStatus_t CUBLASWINAPI cublasUint8gemmBias (cublasHandle_t handle, - cublasOperation_t transa, cublasOperation_t transb, cublasOperation_t transc, - int m, int n, int k, - const unsigned char *A, int A_bias, int lda, - const unsigned char *B, int B_bias, int ldb, - unsigned char *C, int C_bias, int ldc, - int C_mult, int C_shift) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, cublasOperation_t, int, int, int, const unsigned char *, int, int, const unsigned char *, int, int, unsigned char *, int, int, int, int); +cublasStatus_t CUBLASWINAPI cublasUint8gemmBias( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + cublasOperation_t transc, int m, int n, int k, const unsigned char *A, + int A_bias, int lda, const unsigned char *B, int B_bias, int ldb, + unsigned char *C, int C_bias, int ldc, int C_mult, int C_shift) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, cublasOperation_t, + int, int, int, const unsigned char *, int, int, const unsigned char *, + int, int, unsigned char *, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasUint8gemmBias"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, transc, m, n, k, A, A_bias, lda, B, B_bias, ldb, C, C_bias, ldc, C_mult, C_shift); + return func_ptr(handle, transa, transb, transc, m, n, k, A, A_bias, lda, B, + B_bias, ldb, C, C_bias, ldc, C_mult, C_shift); } -cublasStatus_t CUBLASWINAPI cublasSsyrk_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *beta, /* host or device pointer */ - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasSsyrk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *beta, /* host or device pointer */ + float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const float *, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyrk_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -cublasStatus_t CUBLASWINAPI cublasDsyrk_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *beta, /* host or device pointer */ - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const double *, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDsyrk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *beta, /* host or device pointer */ + double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const double *, const double *, int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyrk_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -cublasStatus_t CUBLASWINAPI cublasCsyrk_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCsyrk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, cuComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyrk_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -cublasStatus_t CUBLASWINAPI cublasZsyrk_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZsyrk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyrk_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -cublasStatus_t CUBLASWINAPI cublasCsyrkEx ( cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const void *A, - cudaDataType Atype, - int lda, - const cuComplex *beta, /* host or device pointer */ - void *C, - cudaDataType Ctype, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const void *, cudaDataType, int, const cuComplex *, void *, cudaDataType, int); +cublasStatus_t CUBLASWINAPI cublasCsyrkEx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const void *A, cudaDataType Atype, int lda, + const cuComplex *beta, /* host or device pointer */ + void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const void *, cudaDataType, int, const cuComplex *, + void *, cudaDataType, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyrkEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, + Ctype, ldc); } -cublasStatus_t CUBLASWINAPI cublasCsyrk3mEx(cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuComplex *alpha, - const void *A, - cudaDataType Atype, - int lda, - const cuComplex *beta, - void *C, - cudaDataType Ctype, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const void *, cudaDataType, int, const cuComplex *, void *, cudaDataType, int); +cublasStatus_t CUBLASWINAPI cublasCsyrk3mEx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, const void *A, cudaDataType Atype, + int lda, const cuComplex *beta, void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const void *, cudaDataType, int, const cuComplex *, + void *, cudaDataType, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyrk3mEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, + Ctype, ldc); } -cublasStatus_t CUBLASWINAPI cublasCherk_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const float *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const float *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const cuComplex *, int, const float *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCherk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const float *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const cuComplex *, int, const float *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCherk_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -cublasStatus_t CUBLASWINAPI cublasZherk_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const double *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const double *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const double *, const cuDoubleComplex *, int, const double *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZherk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const double *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, + const double *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const double *, const cuDoubleComplex *, int, const double *, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZherk_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -cublasStatus_t CUBLASWINAPI cublasCherkEx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const float *alpha, /* host or device pointer */ - const void *A, - cudaDataType Atype, - int lda, - const float *beta, /* host or device pointer */ - void *C, - cudaDataType Ctype, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const void *, cudaDataType, int, const float *, void *, cudaDataType, int); +cublasStatus_t CUBLASWINAPI cublasCherkEx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, /* host or device pointer */ + const void *A, cudaDataType Atype, int lda, + const float *beta, /* host or device pointer */ + void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const void *, cudaDataType, int, const float *, void *, + cudaDataType, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCherkEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, + Ctype, ldc); } -cublasStatus_t CUBLASWINAPI cublasCherk3mEx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const float *alpha, - const void *A, cudaDataType Atype, - int lda, - const float *beta, - void *C, - cudaDataType Ctype, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const void *, cudaDataType, int, const float *, void *, cudaDataType, int); +cublasStatus_t CUBLASWINAPI cublasCherk3mEx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, const void *A, cudaDataType Atype, + int lda, const float *beta, void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const void *, cudaDataType, int, const float *, void *, + cudaDataType, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCherk3mEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, + Ctype, ldc); } -cublasStatus_t CUBLASWINAPI cublasSsyr2k_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *B, - int ldb, - const float *beta, /* host or device pointer */ - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasSsyr2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *B, int ldb, + const float *beta, /* host or device pointer */ + float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const float *, int, const float *, int, const float *, + float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr2k_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasDsyr2k_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *B, - int ldb, - const double *beta, /* host or device pointer */ - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDsyr2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *B, int ldb, + const double *beta, /* host or device pointer */ + double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const double *, const double *, int, const double *, int, const double *, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr2k_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasCsyr2k_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCsyr2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyr2k_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasZsyr2k_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZsyr2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyr2k_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasCher2k_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const float *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const float *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCher2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const float *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const float *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher2k_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasZher2k_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const double *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const double *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZher2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const double *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const double *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher2k_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasSsyrkx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *B, - int ldb, - const float *beta, /* host or device pointer */ - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasSsyrkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *B, int ldb, + const float *beta, /* host or device pointer */ + float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const float *, int, const float *, int, const float *, + float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyrkx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasDsyrkx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *B, - int ldb, - const double *beta, /* host or device pointer */ - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDsyrkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *B, int ldb, + const double *beta, /* host or device pointer */ + double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const double *, const double *, int, const double *, int, const double *, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyrkx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasCsyrkx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCsyrkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyrkx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasZsyrkx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZsyrkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyrkx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasCherkx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const float *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const float *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCherkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const float *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const float *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCherkx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasZherkx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const double *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const double *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZherkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const double *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const double *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZherkx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasSsymm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - int m, - int n, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *B, - int ldb, - const float *beta, /* host or device pointer */ - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasSsymm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *B, int ldb, + const float *beta, /* host or device pointer */ + float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const float *, const float *, int, const float *, int, const float *, + float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsymm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasDsymm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - int m, - int n, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *B, - int ldb, - const double *beta, /* host or device pointer */ - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDsymm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *B, int ldb, + const double *beta, /* host or device pointer */ + double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const double *, const double *, int, const double *, int, const double *, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsymm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasCsymm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCsymm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsymm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasZsymm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZsymm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsymm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasChemm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasChemm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChemm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasZhemm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZhemm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhemm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasStrsm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - float *B, - int ldb) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const float *, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasStrsm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const float *alpha, /* host or device pointer */ + const float *A, int lda, float *B, int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const float *, const float *, int, float *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrsm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); } -cublasStatus_t CUBLASWINAPI cublasDtrsm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - double *B, - int ldb) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const double *, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtrsm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const double *alpha, /* host or device pointer */ + const double *A, int lda, double *B, int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const double *, const double *, int, double *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrsm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); } -cublasStatus_t CUBLASWINAPI cublasCtrsm_v2(cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - cuComplex *B, - int ldb) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuComplex *, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtrsm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, cuComplex *B, int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuComplex *, const cuComplex *, int, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrsm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); } -cublasStatus_t CUBLASWINAPI cublasZtrsm_v2(cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - cuDoubleComplex *B, - int ldb) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtrsm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, cuDoubleComplex *B, int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrsm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); } -cublasStatus_t CUBLASWINAPI cublasStrmm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *B, - int ldb, - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const float *, const float *, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasStrmm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *B, int ldb, float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const float *, const float *, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrmm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, C, ldc); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasDtrmm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *B, - int ldb, - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const double *, const double *, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtrmm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *B, int ldb, double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const double *, const double *, int, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrmm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, C, ldc); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasCtrmm_v2(cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtrmm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuComplex *, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrmm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, C, ldc); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasZtrmm_v2(cublasHandle_t handle, cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtrmm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrmm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, C, ldc); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + C, ldc); } cublasStatus_t CUBLASWINAPI cublasSgemmBatched( @@ -3079,7 +2712,8 @@ cublasStatus_t CUBLASWINAPI cublasSgemmBatched( const float *, float *const[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, + ldb, beta, Carray, ldc, batchCount); } cublasStatus_t CUBLASWINAPI cublasDgemmBatched( @@ -3094,7 +2728,8 @@ cublasStatus_t CUBLASWINAPI cublasDgemmBatched( const double *, double *const[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, + ldb, beta, Carray, ldc, batchCount); } cublasStatus_t CUBLASWINAPI cublasCgemmBatched( @@ -3110,7 +2745,8 @@ cublasStatus_t CUBLASWINAPI cublasCgemmBatched( int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, + ldb, beta, Carray, ldc, batchCount); } cublasStatus_t CUBLASWINAPI cublasCgemm3mBatched( @@ -3126,7 +2762,8 @@ cublasStatus_t CUBLASWINAPI cublasCgemm3mBatched( int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm3mBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, + ldb, beta, Carray, ldc, batchCount); } cublasStatus_t CUBLASWINAPI @@ -3144,7 +2781,8 @@ cublasZgemmBatched(cublasHandle_t handle, cublasOperation_t transa, cuDoubleComplex *const[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, + ldb, beta, Carray, ldc, batchCount); } cublasStatus_t CUBLASWINAPI cublasGemmBatchedEx( @@ -3188,200 +2826,155 @@ cublasStatus_t CUBLASWINAPI cublasGemmStridedBatchedEx( batchCount, computeType, algo); } -cublasStatus_t CUBLASWINAPI cublasSgemmStridedBatched (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - long long int strideA, /* purposely signed */ - const float *B, - int ldb, - long long int strideB, - const float *beta, /* host or device pointer */ - float *C, - int ldc, - long long int strideC, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const float *, const float *, int, long long, const float *, int, long long, const float *, float *, int, long long, int); +cublasStatus_t CUBLASWINAPI cublasSgemmStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const float *alpha, /* host or device pointer */ + const float *A, int lda, long long int strideA, /* purposely signed */ + const float *B, int ldb, long long int strideB, + const float *beta, /* host or device pointer */ + float *C, int ldc, long long int strideC, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const float *, const float *, int, long long, const float *, int, + long long, const float *, float *, int, long long, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemmStridedBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, + ldb, strideB, beta, C, ldc, strideC, batchCount); } -cublasStatus_t CUBLASWINAPI cublasDgemmStridedBatched (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - long long int strideA, /* purposely signed */ - const double *B, - int ldb, - long long int strideB, - const double *beta, /* host or device pointer */ - double *C, - int ldc, - long long int strideC, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const double *, const double *, int, long long, const double *, int, long long, const double *, double *, int, long long, int); +cublasStatus_t CUBLASWINAPI cublasDgemmStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const double *alpha, /* host or device pointer */ + const double *A, int lda, long long int strideA, /* purposely signed */ + const double *B, int ldb, long long int strideB, + const double *beta, /* host or device pointer */ + double *C, int ldc, long long int strideC, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const double *, const double *, int, long long, const double *, int, + long long, const double *, double *, int, long long, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemmStridedBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, + ldb, strideB, beta, C, ldc, strideC, batchCount); } -cublasStatus_t CUBLASWINAPI cublasCgemmStridedBatched (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - long long int strideA, /* purposely signed */ - const cuComplex *B, - int ldb, - long long int strideB, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc, - long long int strideC, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const cuComplex *, int, long long, const cuComplex *, int, long long, const cuComplex *, cuComplex *, int, long long, int); +cublasStatus_t CUBLASWINAPI cublasCgemmStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, long long int strideA, /* purposely signed */ + const cuComplex *B, int ldb, long long int strideB, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc, long long int strideC, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *, int, long long, const cuComplex *, + int, long long, const cuComplex *, cuComplex *, int, long long, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemmStridedBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, + ldb, strideB, beta, C, ldc, strideC, batchCount); } -cublasStatus_t CUBLASWINAPI cublasCgemm3mStridedBatched (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - long long int strideA, /* purposely signed */ - const cuComplex *B, - int ldb, - long long int strideB, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc, - long long int strideC, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const cuComplex *, int, long long, const cuComplex *, int, long long, const cuComplex *, cuComplex *, int, long long, int); +cublasStatus_t CUBLASWINAPI cublasCgemm3mStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, long long int strideA, /* purposely signed */ + const cuComplex *B, int ldb, long long int strideB, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc, long long int strideC, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *, int, long long, const cuComplex *, + int, long long, const cuComplex *, cuComplex *, int, long long, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm3mStridedBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, + ldb, strideB, beta, C, ldc, strideC, batchCount); } -cublasStatus_t CUBLASWINAPI cublasZgemmStridedBatched (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - long long int strideA, /* purposely signed */ - const cuDoubleComplex *B, - int ldb, - long long int strideB, - const cuDoubleComplex *beta, /* host or device poi */ - cuDoubleComplex *C, - int ldc, - long long int strideC, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, long long, const cuDoubleComplex *, int, long long, const cuDoubleComplex *, cuDoubleComplex *, int, long long, int); +cublasStatus_t CUBLASWINAPI cublasZgemmStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, + long long int strideA, /* purposely signed */ + const cuDoubleComplex *B, int ldb, long long int strideB, + const cuDoubleComplex *beta, /* host or device poi */ + cuDoubleComplex *C, int ldc, long long int strideC, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, long long, + const cuDoubleComplex *, int, long long, const cuDoubleComplex *, + cuDoubleComplex *, int, long long, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemmStridedBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, + ldb, strideB, beta, C, ldc, strideC, batchCount); } -cublasStatus_t CUBLASWINAPI cublasSgeam(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *beta , /* host or device pointer */ - const float *B, - int ldb, - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, const float *, const float *, int, const float *, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasSgeam( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *beta, /* host or device pointer */ + const float *B, int ldb, float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + const float *, const float *, int, const float *, const float *, int, + float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgeam"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasDgeam(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *beta, /* host or device pointer */ - const double *B, - int ldb, - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, const double *, const double *, int, const double *, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDgeam( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *beta, /* host or device pointer */ + const double *B, int ldb, double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + const double *, const double *, int, const double *, const double *, int, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgeam"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasCgeam(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *beta, /* host or device pointer */ - const cuComplex *B, - int ldb, - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCgeam( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, + const cuComplex *beta, /* host or device pointer */ + const cuComplex *B, int ldb, cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgeam"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasZgeam(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *beta, /* host or device pointer */ - const cuDoubleComplex *B, - int ldb, - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZgeam( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *beta, /* host or device pointer */ + const cuDoubleComplex *B, int ldb, cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgeam"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, + ldc); } cublasStatus_t CUBLASWINAPI cublasSgetrfBatched( @@ -3494,7 +3087,8 @@ cublasStatus_t CUBLASWINAPI cublasSgetrsBatched( const int *, float *const[], int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgetrsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, info, batchSize); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, + info, batchSize); } cublasStatus_t CUBLASWINAPI cublasDgetrsBatched( @@ -3506,7 +3100,8 @@ cublasStatus_t CUBLASWINAPI cublasDgetrsBatched( const int *, double *const[], int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgetrsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, info, batchSize); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, + info, batchSize); } cublasStatus_t CUBLASWINAPI cublasCgetrsBatched( @@ -3518,7 +3113,8 @@ cublasStatus_t CUBLASWINAPI cublasCgetrsBatched( int, const int *, cuComplex *const[], int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgetrsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, info, batchSize); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, + info, batchSize); } cublasStatus_t CUBLASWINAPI cublasZgetrsBatched( @@ -3531,7 +3127,8 @@ cublasStatus_t CUBLASWINAPI cublasZgetrsBatched( cuDoubleComplex *const[], int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgetrsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, info, batchSize); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, + info, batchSize); } cublasStatus_t CUBLASWINAPI cublasStrsmBatched( @@ -3546,7 +3143,8 @@ cublasStatus_t CUBLASWINAPI cublasStrsmBatched( float *const[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrsmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + batchCount); } cublasStatus_t CUBLASWINAPI cublasDtrsmBatched( @@ -3561,7 +3159,8 @@ cublasStatus_t CUBLASWINAPI cublasDtrsmBatched( double *const[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrsmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + batchCount); } cublasStatus_t CUBLASWINAPI cublasCtrsmBatched( @@ -3576,7 +3175,8 @@ cublasStatus_t CUBLASWINAPI cublasCtrsmBatched( int, cuComplex *const[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrsmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + batchCount); } cublasStatus_t CUBLASWINAPI cublasZtrsmBatched( @@ -3591,7 +3191,8 @@ cublasStatus_t CUBLASWINAPI cublasZtrsmBatched( const cuDoubleComplex *const[], int, cuDoubleComplex *const[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrsmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + batchCount); } cublasStatus_t CUBLASWINAPI cublasSmatinvBatched( @@ -3710,7 +3311,8 @@ cublasSgelsBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, float *const[], int, int *, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgelsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, devInfoArray, batchSize); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, + devInfoArray, batchSize); } cublasStatus_t CUBLASWINAPI @@ -3724,7 +3326,8 @@ cublasDgelsBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, double *const[], int, int *, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgelsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, devInfoArray, batchSize); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, + devInfoArray, batchSize); } cublasStatus_t CUBLASWINAPI @@ -3737,7 +3340,8 @@ cublasCgelsBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, cuComplex *const[], int, int *, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgelsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, devInfoArray, batchSize); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, + devInfoArray, batchSize); } cublasStatus_t CUBLASWINAPI @@ -3751,1467 +3355,1666 @@ cublasZgelsBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgelsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, devInfoArray, batchSize); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, + devInfoArray, batchSize); } cublasStatus_t CUBLASWINAPI cublasSdgmm(cublasHandle_t handle, - cublasSideMode_t mode, - int m, - int n, - const float *A, - int lda, - const float *x, - int incx, - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, int, int, const float *, int, const float *, int, float *, int); + cublasSideMode_t mode, int m, int n, + const float *A, int lda, const float *x, + int incx, float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, int, int, const float *, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSdgmm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); } cublasStatus_t CUBLASWINAPI cublasDdgmm(cublasHandle_t handle, - cublasSideMode_t mode, - int m, - int n, - const double *A, - int lda, - const double *x, - int incx, - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, int, int, const double *, int, const double *, int, double *, int); + cublasSideMode_t mode, int m, int n, + const double *A, int lda, + const double *x, int incx, double *C, + int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, int, int, const double *, int, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDdgmm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); } cublasStatus_t CUBLASWINAPI cublasCdgmm(cublasHandle_t handle, - cublasSideMode_t mode, - int m, - int n, - const cuComplex *A, - int lda, - const cuComplex *x, - int incx, - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, int, int, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); + cublasSideMode_t mode, int m, int n, + const cuComplex *A, int lda, + const cuComplex *x, int incx, + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, int, int, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCdgmm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); } cublasStatus_t CUBLASWINAPI cublasZdgmm(cublasHandle_t handle, - cublasSideMode_t mode, - int m, - int n, - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *x, - int incx, - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, int, int, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + cublasSideMode_t mode, int m, int n, + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *x, int incx, + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, int, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdgmm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); } -cublasStatus_t CUBLASWINAPI cublasStpttr ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *AP, - float *A, - int lda ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasStpttr(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const float *AP, float *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStpttr"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, AP, A, lda); } -cublasStatus_t CUBLASWINAPI cublasDtpttr ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *AP, - double *A, - int lda ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtpttr(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const double *AP, double *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtpttr"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, AP, A, lda); } -cublasStatus_t CUBLASWINAPI cublasCtpttr ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuComplex *AP, - cuComplex *A, - int lda ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtpttr(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const cuComplex *AP, cuComplex *A, + int lda) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, + const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtpttr"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, AP, A, lda); } -cublasStatus_t CUBLASWINAPI cublasZtpttr ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *AP, - cuDoubleComplex *A, - int lda ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtpttr(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const cuDoubleComplex *AP, + cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtpttr"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, AP, A, lda); } -cublasStatus_t CUBLASWINAPI cublasStrttp ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *A, - int lda, - float *AP ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, int, float *); +cublasStatus_t CUBLASWINAPI cublasStrttp(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const float *A, int lda, float *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrttp"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, A, lda, AP); } -cublasStatus_t CUBLASWINAPI cublasDtrttp ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *A, - int lda, - double *AP ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, int, double *); +cublasStatus_t CUBLASWINAPI cublasDtrttp(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const double *A, int lda, double *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrttp"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, A, lda, AP); } -cublasStatus_t CUBLASWINAPI cublasCtrttp ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuComplex *A, - int lda, - cuComplex *AP ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, int, cuComplex *); +cublasStatus_t CUBLASWINAPI cublasCtrttp(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const cuComplex *A, int lda, + cuComplex *AP) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, + const cuComplex *, int, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrttp"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, A, lda, AP); } -cublasStatus_t CUBLASWINAPI cublasZtrttp ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *A, - int lda, - cuDoubleComplex *AP ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, int, cuDoubleComplex *); +cublasStatus_t CUBLASWINAPI cublasZtrttp(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, int, + cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrttp"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, A, lda, AP); } -cublasStatus CUBLASWINAPI cublasInit (void) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(); +cublasStatus CUBLASWINAPI cublasInit(void) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cublasInit"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(); } -cublasStatus CUBLASWINAPI cublasShutdown (void) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(); +cublasStatus CUBLASWINAPI cublasShutdown(void) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cublasShutdown"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(); } -cublasStatus CUBLASWINAPI cublasGetError (void) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(); +cublasStatus CUBLASWINAPI cublasGetError(void) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetError"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(); } cublasStatus CUBLASWINAPI cublasGetVersion(int *version) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int *); + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetVersion"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(version); } -cublasStatus CUBLASWINAPI cublasAlloc (int n, int elemSize, void **devicePtr) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, void **); +cublasStatus CUBLASWINAPI cublasAlloc(int n, int elemSize, void **devicePtr) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, void **); static auto func_ptr = LoadSymbol<FuncPtr>("cublasAlloc"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(n, elemSize, devicePtr); } -cublasStatus CUBLASWINAPI cublasFree (void *devicePtr) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(void *); +cublasStatus CUBLASWINAPI cublasFree(void *devicePtr) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(void *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasFree"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(devicePtr); } -cublasStatus CUBLASWINAPI cublasSetKernelStream (cudaStream_t stream) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cudaStream_t); +cublasStatus CUBLASWINAPI cublasSetKernelStream(cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetKernelStream"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stream); } -float CUBLASWINAPI cublasSnrm2 (int n, const float *x, int incx) { - using FuncPtr = float (CUBLASWINAPI *)(int, const float *, int); +float CUBLASWINAPI cublasSnrm2(int n, const float *x, int incx) { + using FuncPtr = float(CUBLASWINAPI *)(int, const float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSnrm2"); if (!func_ptr) LogFatalSymbolNotFound("cublasSnrm2"); return func_ptr(n, x, incx); } -double CUBLASWINAPI cublasDnrm2 (int n, const double *x, int incx) { - using FuncPtr = double (CUBLASWINAPI *)(int, const double *, int); +double CUBLASWINAPI cublasDnrm2(int n, const double *x, int incx) { + using FuncPtr = double(CUBLASWINAPI *)(int, const double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDnrm2"); if (!func_ptr) LogFatalSymbolNotFound("cublasDnrm2"); return func_ptr(n, x, incx); } -float CUBLASWINAPI cublasScnrm2 (int n, const cuComplex *x, int incx) { - using FuncPtr = float (CUBLASWINAPI *)(int, const cuComplex *, int); +float CUBLASWINAPI cublasScnrm2(int n, const cuComplex *x, int incx) { + using FuncPtr = float(CUBLASWINAPI *)(int, const cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasScnrm2"); if (!func_ptr) LogFatalSymbolNotFound("cublasScnrm2"); return func_ptr(n, x, incx); } -double CUBLASWINAPI cublasDznrm2 (int n, const cuDoubleComplex *x, int incx) { - using FuncPtr = double (CUBLASWINAPI *)(int, const cuDoubleComplex *, int); +double CUBLASWINAPI cublasDznrm2(int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = double(CUBLASWINAPI *)(int, const cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDznrm2"); if (!func_ptr) LogFatalSymbolNotFound("cublasDznrm2"); return func_ptr(n, x, incx); } -float CUBLASWINAPI cublasSdot (int n, const float *x, int incx, const float *y, - int incy) { - using FuncPtr = float (CUBLASWINAPI *)(int, const float *, int, const float *, int); +float CUBLASWINAPI cublasSdot(int n, const float *x, int incx, const float *y, + int incy) { + using FuncPtr = + float(CUBLASWINAPI *)(int, const float *, int, const float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSdot"); if (!func_ptr) LogFatalSymbolNotFound("cublasSdot"); return func_ptr(n, x, incx, y, incy); } -double CUBLASWINAPI cublasDdot (int n, const double *x, int incx, const double *y, - int incy) { - using FuncPtr = double (CUBLASWINAPI *)(int, const double *, int, const double *, int); +double CUBLASWINAPI cublasDdot(int n, const double *x, int incx, + const double *y, int incy) { + using FuncPtr = + double(CUBLASWINAPI *)(int, const double *, int, const double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDdot"); if (!func_ptr) LogFatalSymbolNotFound("cublasDdot"); return func_ptr(n, x, incx, y, incy); } -cuComplex CUBLASWINAPI cublasCdotu (int n, const cuComplex *x, int incx, const cuComplex *y, - int incy) { - using FuncPtr = cuComplex (CUBLASWINAPI *)(int, const cuComplex *, int, const cuComplex *, int); +cuComplex CUBLASWINAPI cublasCdotu(int n, const cuComplex *x, int incx, + const cuComplex *y, int incy) { + using FuncPtr = cuComplex(CUBLASWINAPI *)(int, const cuComplex *, int, + const cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCdotu"); if (!func_ptr) LogFatalSymbolNotFound("cublasCdotu"); return func_ptr(n, x, incx, y, incy); } -cuComplex CUBLASWINAPI cublasCdotc (int n, const cuComplex *x, int incx, const cuComplex *y, - int incy) { - using FuncPtr = cuComplex (CUBLASWINAPI *)(int, const cuComplex *, int, const cuComplex *, int); +cuComplex CUBLASWINAPI cublasCdotc(int n, const cuComplex *x, int incx, + const cuComplex *y, int incy) { + using FuncPtr = cuComplex(CUBLASWINAPI *)(int, const cuComplex *, int, + const cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCdotc"); if (!func_ptr) LogFatalSymbolNotFound("cublasCdotc"); return func_ptr(n, x, incx, y, incy); } -cuDoubleComplex CUBLASWINAPI cublasZdotu (int n, const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, - int incy) { - using FuncPtr = cuDoubleComplex (CUBLASWINAPI *)(int, const cuDoubleComplex *, int, const cuDoubleComplex *, int); +cuDoubleComplex CUBLASWINAPI cublasZdotu(int n, const cuDoubleComplex *x, + int incx, const cuDoubleComplex *y, + int incy) { + using FuncPtr = cuDoubleComplex(CUBLASWINAPI *)( + int, const cuDoubleComplex *, int, const cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdotu"); if (!func_ptr) LogFatalSymbolNotFound("cublasZdotu"); return func_ptr(n, x, incx, y, incy); } -cuDoubleComplex CUBLASWINAPI cublasZdotc (int n, const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, - int incy) { - using FuncPtr = cuDoubleComplex (CUBLASWINAPI *)(int, const cuDoubleComplex *, int, const cuDoubleComplex *, int); +cuDoubleComplex CUBLASWINAPI cublasZdotc(int n, const cuDoubleComplex *x, + int incx, const cuDoubleComplex *y, + int incy) { + using FuncPtr = cuDoubleComplex(CUBLASWINAPI *)( + int, const cuDoubleComplex *, int, const cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdotc"); if (!func_ptr) LogFatalSymbolNotFound("cublasZdotc"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasSscal (int n, float alpha, float *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(int, float, float *, int); +void CUBLASWINAPI cublasSscal(int n, float alpha, float *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSscal"); if (!func_ptr) LogFatalSymbolNotFound("cublasSscal"); return func_ptr(n, alpha, x, incx); } -void CUBLASWINAPI cublasDscal (int n, double alpha, double *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(int, double, double *, int); +void CUBLASWINAPI cublasDscal(int n, double alpha, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDscal"); if (!func_ptr) LogFatalSymbolNotFound("cublasDscal"); return func_ptr(n, alpha, x, incx); } -void CUBLASWINAPI cublasCscal (int n, cuComplex alpha, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasCscal(int n, cuComplex alpha, cuComplex *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCscal"); if (!func_ptr) LogFatalSymbolNotFound("cublasCscal"); return func_ptr(n, alpha, x, incx); } -void CUBLASWINAPI cublasZscal (int n, cuDoubleComplex alpha, cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZscal(int n, cuDoubleComplex alpha, cuDoubleComplex *x, + int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZscal"); if (!func_ptr) LogFatalSymbolNotFound("cublasZscal"); return func_ptr(n, alpha, x, incx); } -void CUBLASWINAPI cublasCsscal (int n, float alpha, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(int, float, cuComplex *, int); +void CUBLASWINAPI cublasCsscal(int n, float alpha, cuComplex *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(int, float, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsscal"); if (!func_ptr) LogFatalSymbolNotFound("cublasCsscal"); return func_ptr(n, alpha, x, incx); } -void CUBLASWINAPI cublasZdscal (int n, double alpha, cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(int, double, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZdscal(int n, double alpha, cuDoubleComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)(int, double, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdscal"); if (!func_ptr) LogFatalSymbolNotFound("cublasZdscal"); return func_ptr(n, alpha, x, incx); } -void CUBLASWINAPI cublasSaxpy (int n, float alpha, const float *x, int incx, - float *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, float, const float *, int, float *, int); +void CUBLASWINAPI cublasSaxpy(int n, float alpha, const float *x, int incx, + float *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(int, float, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSaxpy"); if (!func_ptr) LogFatalSymbolNotFound("cublasSaxpy"); return func_ptr(n, alpha, x, incx, y, incy); } -void CUBLASWINAPI cublasDaxpy (int n, double alpha, const double *x, - int incx, double *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, double, const double *, int, double *, int); +void CUBLASWINAPI cublasDaxpy(int n, double alpha, const double *x, int incx, + double *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(int, double, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDaxpy"); if (!func_ptr) LogFatalSymbolNotFound("cublasDaxpy"); return func_ptr(n, alpha, x, incx, y, incy); } -void CUBLASWINAPI cublasCaxpy (int n, cuComplex alpha, const cuComplex *x, - int incx, cuComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuComplex, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCaxpy(int n, cuComplex alpha, const cuComplex *x, + int incx, cuComplex *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuComplex, const cuComplex *, int, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCaxpy"); if (!func_ptr) LogFatalSymbolNotFound("cublasCaxpy"); return func_ptr(n, alpha, x, incx, y, incy); } -void CUBLASWINAPI cublasZaxpy (int n, cuDoubleComplex alpha, const cuDoubleComplex *x, - int incx, cuDoubleComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuDoubleComplex, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZaxpy(int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(int, cuDoubleComplex, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZaxpy"); if (!func_ptr) LogFatalSymbolNotFound("cublasZaxpy"); return func_ptr(n, alpha, x, incx, y, incy); } -void CUBLASWINAPI cublasScopy (int n, const float *x, int incx, float *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, const float *, int, float *, int); +void CUBLASWINAPI cublasScopy(int n, const float *x, int incx, float *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasScopy"); if (!func_ptr) LogFatalSymbolNotFound("cublasScopy"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasDcopy (int n, const double *x, int incx, double *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, const double *, int, double *, int); +void CUBLASWINAPI cublasDcopy(int n, const double *x, int incx, double *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDcopy"); if (!func_ptr) LogFatalSymbolNotFound("cublasDcopy"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasCcopy (int n, const cuComplex *x, int incx, cuComplex *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCcopy(int n, const cuComplex *x, int incx, cuComplex *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCcopy"); if (!func_ptr) LogFatalSymbolNotFound("cublasCcopy"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasZcopy (int n, const cuDoubleComplex *x, int incx, cuDoubleComplex *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZcopy(int n, const cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZcopy"); if (!func_ptr) LogFatalSymbolNotFound("cublasZcopy"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasSswap (int n, float *x, int incx, float *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, float *, int, float *, int); +void CUBLASWINAPI cublasSswap(int n, float *x, int incx, float *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSswap"); if (!func_ptr) LogFatalSymbolNotFound("cublasSswap"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasDswap (int n, double *x, int incx, double *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, double *, int, double *, int); +void CUBLASWINAPI cublasDswap(int n, double *x, int incx, double *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDswap"); if (!func_ptr) LogFatalSymbolNotFound("cublasDswap"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasCswap (int n, cuComplex *x, int incx, cuComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCswap(int n, cuComplex *x, int incx, cuComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCswap"); if (!func_ptr) LogFatalSymbolNotFound("cublasCswap"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasZswap (int n, cuDoubleComplex *x, int incx, cuDoubleComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZswap(int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(int, cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZswap"); if (!func_ptr) LogFatalSymbolNotFound("cublasZswap"); return func_ptr(n, x, incx, y, incy); } -int CUBLASWINAPI cublasIsamax (int n, const float *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const float *, int); +int CUBLASWINAPI cublasIsamax(int n, const float *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIsamax"); if (!func_ptr) LogFatalSymbolNotFound("cublasIsamax"); return func_ptr(n, x, incx); } -int CUBLASWINAPI cublasIdamax (int n, const double *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const double *, int); +int CUBLASWINAPI cublasIdamax(int n, const double *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIdamax"); if (!func_ptr) LogFatalSymbolNotFound("cublasIdamax"); return func_ptr(n, x, incx); } -int CUBLASWINAPI cublasIcamax (int n, const cuComplex *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const cuComplex *, int); +int CUBLASWINAPI cublasIcamax(int n, const cuComplex *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIcamax"); if (!func_ptr) LogFatalSymbolNotFound("cublasIcamax"); return func_ptr(n, x, incx); } -int CUBLASWINAPI cublasIzamax (int n, const cuDoubleComplex *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const cuDoubleComplex *, int); +int CUBLASWINAPI cublasIzamax(int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIzamax"); if (!func_ptr) LogFatalSymbolNotFound("cublasIzamax"); return func_ptr(n, x, incx); } -int CUBLASWINAPI cublasIsamin (int n, const float *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const float *, int); +int CUBLASWINAPI cublasIsamin(int n, const float *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIsamin"); if (!func_ptr) LogFatalSymbolNotFound("cublasIsamin"); return func_ptr(n, x, incx); } -int CUBLASWINAPI cublasIdamin (int n, const double *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const double *, int); +int CUBLASWINAPI cublasIdamin(int n, const double *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIdamin"); if (!func_ptr) LogFatalSymbolNotFound("cublasIdamin"); return func_ptr(n, x, incx); } -int CUBLASWINAPI cublasIcamin (int n, const cuComplex *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const cuComplex *, int); +int CUBLASWINAPI cublasIcamin(int n, const cuComplex *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIcamin"); if (!func_ptr) LogFatalSymbolNotFound("cublasIcamin"); return func_ptr(n, x, incx); } -int CUBLASWINAPI cublasIzamin (int n, const cuDoubleComplex *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const cuDoubleComplex *, int); +int CUBLASWINAPI cublasIzamin(int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIzamin"); if (!func_ptr) LogFatalSymbolNotFound("cublasIzamin"); return func_ptr(n, x, incx); } -float CUBLASWINAPI cublasSasum (int n, const float *x, int incx) { - using FuncPtr = float (CUBLASWINAPI *)(int, const float *, int); +float CUBLASWINAPI cublasSasum(int n, const float *x, int incx) { + using FuncPtr = float(CUBLASWINAPI *)(int, const float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSasum"); if (!func_ptr) LogFatalSymbolNotFound("cublasSasum"); return func_ptr(n, x, incx); } -double CUBLASWINAPI cublasDasum (int n, const double *x, int incx) { - using FuncPtr = double (CUBLASWINAPI *)(int, const double *, int); +double CUBLASWINAPI cublasDasum(int n, const double *x, int incx) { + using FuncPtr = double(CUBLASWINAPI *)(int, const double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDasum"); if (!func_ptr) LogFatalSymbolNotFound("cublasDasum"); return func_ptr(n, x, incx); } -float CUBLASWINAPI cublasScasum (int n, const cuComplex *x, int incx) { - using FuncPtr = float (CUBLASWINAPI *)(int, const cuComplex *, int); +float CUBLASWINAPI cublasScasum(int n, const cuComplex *x, int incx) { + using FuncPtr = float(CUBLASWINAPI *)(int, const cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasScasum"); if (!func_ptr) LogFatalSymbolNotFound("cublasScasum"); return func_ptr(n, x, incx); } -double CUBLASWINAPI cublasDzasum (int n, const cuDoubleComplex *x, int incx) { - using FuncPtr = double (CUBLASWINAPI *)(int, const cuDoubleComplex *, int); +double CUBLASWINAPI cublasDzasum(int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = double(CUBLASWINAPI *)(int, const cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDzasum"); if (!func_ptr) LogFatalSymbolNotFound("cublasDzasum"); return func_ptr(n, x, incx); } -void CUBLASWINAPI cublasSrot (int n, float *x, int incx, float *y, int incy, - float sc, float ss) { - using FuncPtr = void (CUBLASWINAPI *)(int, float *, int, float *, int, float, float); +void CUBLASWINAPI cublasSrot(int n, float *x, int incx, float *y, int incy, + float sc, float ss) { + using FuncPtr = + void(CUBLASWINAPI *)(int, float *, int, float *, int, float, float); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrot"); if (!func_ptr) LogFatalSymbolNotFound("cublasSrot"); return func_ptr(n, x, incx, y, incy, sc, ss); } -void CUBLASWINAPI cublasDrot (int n, double *x, int incx, double *y, int incy, - double sc, double ss) { - using FuncPtr = void (CUBLASWINAPI *)(int, double *, int, double *, int, double, double); +void CUBLASWINAPI cublasDrot(int n, double *x, int incx, double *y, int incy, + double sc, double ss) { + using FuncPtr = + void(CUBLASWINAPI *)(int, double *, int, double *, int, double, double); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrot"); if (!func_ptr) LogFatalSymbolNotFound("cublasDrot"); return func_ptr(n, x, incx, y, incy, sc, ss); } -void CUBLASWINAPI cublasCrot (int n, cuComplex *x, int incx, cuComplex *y, - int incy, float c, cuComplex s) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int, float, cuComplex); +void CUBLASWINAPI cublasCrot(int n, cuComplex *x, int incx, cuComplex *y, + int incy, float c, cuComplex s) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int, + float, cuComplex); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCrot"); if (!func_ptr) LogFatalSymbolNotFound("cublasCrot"); return func_ptr(n, x, incx, y, incy, c, s); } -void CUBLASWINAPI cublasZrot (int n, cuDoubleComplex *x, int incx, - cuDoubleComplex *y, int incy, double sc, - cuDoubleComplex cs) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuDoubleComplex *, int, cuDoubleComplex *, int, double, cuDoubleComplex); +void CUBLASWINAPI cublasZrot(int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy, double sc, + cuDoubleComplex cs) { + using FuncPtr = + void(CUBLASWINAPI *)(int, cuDoubleComplex *, int, cuDoubleComplex *, int, + double, cuDoubleComplex); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZrot"); if (!func_ptr) LogFatalSymbolNotFound("cublasZrot"); return func_ptr(n, x, incx, y, incy, sc, cs); } -void CUBLASWINAPI cublasCsrot (int n, cuComplex *x, int incx, cuComplex *y, - int incy, float c, float s) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int, float, float); +void CUBLASWINAPI cublasCsrot(int n, cuComplex *x, int incx, cuComplex *y, + int incy, float c, float s) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int, + float, float); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsrot"); if (!func_ptr) LogFatalSymbolNotFound("cublasCsrot"); return func_ptr(n, x, incx, y, incy, c, s); } -void CUBLASWINAPI cublasZdrot (int n, cuDoubleComplex *x, int incx, - cuDoubleComplex *y, int incy, double c, double s) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuDoubleComplex *, int, cuDoubleComplex *, int, double, double); +void CUBLASWINAPI cublasZdrot(int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy, double c, + double s) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuDoubleComplex *, int, + cuDoubleComplex *, int, double, double); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdrot"); if (!func_ptr) LogFatalSymbolNotFound("cublasZdrot"); return func_ptr(n, x, incx, y, incy, c, s); } -void CUBLASWINAPI cublasSrotg (float *sa, float *sb, float *sc, float *ss) { - using FuncPtr = void (CUBLASWINAPI *)(float *, float *, float *, float *); +void CUBLASWINAPI cublasSrotg(float *sa, float *sb, float *sc, float *ss) { + using FuncPtr = void(CUBLASWINAPI *)(float *, float *, float *, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotg"); if (!func_ptr) LogFatalSymbolNotFound("cublasSrotg"); return func_ptr(sa, sb, sc, ss); } -void CUBLASWINAPI cublasDrotg (double *sa, double *sb, double *sc, double *ss) { - using FuncPtr = void (CUBLASWINAPI *)(double *, double *, double *, double *); +void CUBLASWINAPI cublasDrotg(double *sa, double *sb, double *sc, double *ss) { + using FuncPtr = void(CUBLASWINAPI *)(double *, double *, double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotg"); if (!func_ptr) LogFatalSymbolNotFound("cublasDrotg"); return func_ptr(sa, sb, sc, ss); } -void CUBLASWINAPI cublasCrotg (cuComplex *ca, cuComplex cb, float *sc, - cuComplex *cs) { - using FuncPtr = void (CUBLASWINAPI *)(cuComplex *, cuComplex, float *, cuComplex *); +void CUBLASWINAPI cublasCrotg(cuComplex *ca, cuComplex cb, float *sc, + cuComplex *cs) { + using FuncPtr = + void(CUBLASWINAPI *)(cuComplex *, cuComplex, float *, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCrotg"); if (!func_ptr) LogFatalSymbolNotFound("cublasCrotg"); return func_ptr(ca, cb, sc, cs); } -void CUBLASWINAPI cublasZrotg (cuDoubleComplex *ca, cuDoubleComplex cb, double *sc, - cuDoubleComplex *cs) { - using FuncPtr = void (CUBLASWINAPI *)(cuDoubleComplex *, cuDoubleComplex, double *, cuDoubleComplex *); +void CUBLASWINAPI cublasZrotg(cuDoubleComplex *ca, cuDoubleComplex cb, + double *sc, cuDoubleComplex *cs) { + using FuncPtr = void(CUBLASWINAPI *)(cuDoubleComplex *, cuDoubleComplex, + double *, cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZrotg"); if (!func_ptr) LogFatalSymbolNotFound("cublasZrotg"); return func_ptr(ca, cb, sc, cs); } -void CUBLASWINAPI cublasSrotm(int n, float *x, int incx, float *y, int incy, - const float* sparam) { - using FuncPtr = void (CUBLASWINAPI *)(int, float *, int, float *, int, const float *); +void CUBLASWINAPI cublasSrotm(int n, float *x, int incx, float *y, int incy, + const float *sparam) { + using FuncPtr = + void(CUBLASWINAPI *)(int, float *, int, float *, int, const float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotm"); if (!func_ptr) LogFatalSymbolNotFound("cublasSrotm"); return func_ptr(n, x, incx, y, incy, sparam); } -void CUBLASWINAPI cublasDrotm(int n, double *x, int incx, double *y, int incy, - const double* sparam) { - using FuncPtr = void (CUBLASWINAPI *)(int, double *, int, double *, int, const double *); +void CUBLASWINAPI cublasDrotm(int n, double *x, int incx, double *y, int incy, + const double *sparam) { + using FuncPtr = + void(CUBLASWINAPI *)(int, double *, int, double *, int, const double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotm"); if (!func_ptr) LogFatalSymbolNotFound("cublasDrotm"); return func_ptr(n, x, incx, y, incy, sparam); } -void CUBLASWINAPI cublasSrotmg (float *sd1, float *sd2, float *sx1, - const float *sy1, float* sparam) { - using FuncPtr = void (CUBLASWINAPI *)(float *, float *, float *, const float *, float *); +void CUBLASWINAPI cublasSrotmg(float *sd1, float *sd2, float *sx1, + const float *sy1, float *sparam) { + using FuncPtr = + void(CUBLASWINAPI *)(float *, float *, float *, const float *, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotmg"); if (!func_ptr) LogFatalSymbolNotFound("cublasSrotmg"); return func_ptr(sd1, sd2, sx1, sy1, sparam); } -void CUBLASWINAPI cublasDrotmg (double *sd1, double *sd2, double *sx1, - const double *sy1, double* sparam) { - using FuncPtr = void (CUBLASWINAPI *)(double *, double *, double *, const double *, double *); +void CUBLASWINAPI cublasDrotmg(double *sd1, double *sd2, double *sx1, + const double *sy1, double *sparam) { + using FuncPtr = void(CUBLASWINAPI *)(double *, double *, double *, + const double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotmg"); if (!func_ptr) LogFatalSymbolNotFound("cublasDrotmg"); return func_ptr(sd1, sd2, sx1, sy1, sparam); } -void CUBLASWINAPI cublasSgemv (char trans, int m, int n, float alpha, - const float *A, int lda, const float *x, int incx, - float beta, float *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, float, const float *, int, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSgemv(char trans, int m, int n, float alpha, + const float *A, int lda, const float *x, int incx, + float beta, float *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, float, const float *, int, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemv"); if (!func_ptr) LogFatalSymbolNotFound("cublasSgemv"); return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasDgemv (char trans, int m, int n, double alpha, - const double *A, int lda, const double *x, int incx, - double beta, double *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, double, const double *, int, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDgemv(char trans, int m, int n, double alpha, + const double *A, int lda, const double *x, + int incx, double beta, double *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, double, const double *, int, + const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDgemv"); return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasCgemv (char trans, int m, int n, cuComplex alpha, - const cuComplex *A, int lda, const cuComplex *x, int incx, - cuComplex beta, cuComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasCgemv(char trans, int m, int n, cuComplex alpha, + const cuComplex *A, int lda, const cuComplex *x, + int incx, cuComplex beta, cuComplex *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCgemv"); return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasZgemv (char trans, int m, int n, cuDoubleComplex alpha, - const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, - cuDoubleComplex beta, cuDoubleComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZgemv(char trans, int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZgemv"); return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasSgbmv (char trans, int m, int n, int kl, int ku, - float alpha, const float *A, int lda, - const float *x, int incx, float beta, float *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, int, int, float, const float *, int, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSgbmv(char trans, int m, int n, int kl, int ku, + float alpha, const float *A, int lda, + const float *x, int incx, float beta, float *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, int, int, float, const float *, int, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasSgbmv"); return func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasDgbmv (char trans, int m, int n, int kl, int ku, - double alpha, const double *A, int lda, - const double *x, int incx, double beta, double *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, int, int, double, const double *, int, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDgbmv(char trans, int m, int n, int kl, int ku, + double alpha, const double *A, int lda, + const double *x, int incx, double beta, double *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, int, int, double, const double *, + int, const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDgbmv"); return func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasCgbmv (char trans, int m, int n, int kl, int ku, - cuComplex alpha, const cuComplex *A, int lda, - const cuComplex *x, int incx, cuComplex beta, cuComplex *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasCgbmv(char trans, int m, int n, int kl, int ku, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *x, int incx, cuComplex beta, + cuComplex *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, int, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCgbmv"); return func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasZgbmv (char trans, int m, int n, int kl, int ku, - cuDoubleComplex alpha, const cuDoubleComplex *A, int lda, - const cuDoubleComplex *x, int incx, cuDoubleComplex beta, cuDoubleComplex *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZgbmv(char trans, int m, int n, int kl, int ku, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, int, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZgbmv"); return func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasStrmv (char uplo, char trans, char diag, int n, - const float *A, int lda, float *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const float *, int, float *, int); +void CUBLASWINAPI cublasStrmv(char uplo, char trans, char diag, int n, + const float *A, int lda, float *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const float *, + int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasStrmv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasDtrmv (char uplo, char trans, char diag, int n, - const double *A, int lda, double *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const double *, int, double *, int); +void CUBLASWINAPI cublasDtrmv(char uplo, char trans, char diag, int n, + const double *A, int lda, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const double *, + int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtrmv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasCtrmv (char uplo, char trans, char diag, int n, - const cuComplex *A, int lda, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCtrmv(char uplo, char trans, char diag, int n, + const cuComplex *A, int lda, cuComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const cuComplex *, + int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtrmv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasZtrmv (char uplo, char trans, char diag, int n, - const cuDoubleComplex *A, int lda, cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZtrmv(char uplo, char trans, char diag, int n, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtrmv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasStbmv (char uplo, char trans, char diag, int n, int k, - const float *A, int lda, float *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const float *, int, float *, int); +void CUBLASWINAPI cublasStbmv(char uplo, char trans, char diag, int n, int k, + const float *A, int lda, float *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasStbmv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasDtbmv (char uplo, char trans, char diag, int n, int k, - const double *A, int lda, double *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const double *, int, double *, int); +void CUBLASWINAPI cublasDtbmv(char uplo, char trans, char diag, int n, int k, + const double *A, int lda, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, int, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtbmv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasCtbmv (char uplo, char trans, char diag, int n, int k, - const cuComplex *A, int lda, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCtbmv(char uplo, char trans, char diag, int n, int k, + const cuComplex *A, int lda, cuComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, char, int, int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtbmv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasZtbmv (char uplo, char trans, char diag, int n, int k, - const cuDoubleComplex *A, int lda, cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZtbmv(char uplo, char trans, char diag, int n, int k, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, int, const cuDoubleComplex *, + int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtbmv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasStpmv(char uplo, char trans, char diag, int n, const float *AP, float *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const float *, float *, int); +void CUBLASWINAPI cublasStpmv(char uplo, char trans, char diag, int n, + const float *AP, float *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStpmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasStpmv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasDtpmv(char uplo, char trans, char diag, int n, const double *AP, double *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const double *, double *, int); +void CUBLASWINAPI cublasDtpmv(char uplo, char trans, char diag, int n, + const double *AP, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const double *, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtpmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtpmv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasCtpmv(char uplo, char trans, char diag, int n, const cuComplex *AP, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuComplex *, cuComplex *, int); +void CUBLASWINAPI cublasCtpmv(char uplo, char trans, char diag, int n, + const cuComplex *AP, cuComplex *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const cuComplex *, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtpmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtpmv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasZtpmv(char uplo, char trans, char diag, int n, const cuDoubleComplex *AP, cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZtpmv(char uplo, char trans, char diag, int n, + const cuDoubleComplex *AP, cuDoubleComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, char, int, const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtpmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtpmv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasStrsv(char uplo, char trans, char diag, int n, const float *A, int lda, float *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const float *, int, float *, int); +void CUBLASWINAPI cublasStrsv(char uplo, char trans, char diag, int n, + const float *A, int lda, float *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const float *, + int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasStrsv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasDtrsv(char uplo, char trans, char diag, int n, const double *A, int lda, double *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const double *, int, double *, int); +void CUBLASWINAPI cublasDtrsv(char uplo, char trans, char diag, int n, + const double *A, int lda, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const double *, + int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtrsv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasCtrsv(char uplo, char trans, char diag, int n, const cuComplex *A, int lda, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCtrsv(char uplo, char trans, char diag, int n, + const cuComplex *A, int lda, cuComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const cuComplex *, + int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtrsv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasZtrsv(char uplo, char trans, char diag, int n, const cuDoubleComplex *A, int lda, +void CUBLASWINAPI cublasZtrsv(char uplo, char trans, char diag, int n, + const cuDoubleComplex *A, int lda, cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtrsv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasStpsv(char uplo, char trans, char diag, int n, const float *AP, - float *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const float *, float *, int); +void CUBLASWINAPI cublasStpsv(char uplo, char trans, char diag, int n, + const float *AP, float *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStpsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasStpsv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasDtpsv(char uplo, char trans, char diag, int n, const double *AP, double *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const double *, double *, int); +void CUBLASWINAPI cublasDtpsv(char uplo, char trans, char diag, int n, + const double *AP, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const double *, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtpsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtpsv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasCtpsv(char uplo, char trans, char diag, int n, const cuComplex *AP, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuComplex *, cuComplex *, int); +void CUBLASWINAPI cublasCtpsv(char uplo, char trans, char diag, int n, + const cuComplex *AP, cuComplex *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const cuComplex *, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtpsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtpsv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasZtpsv(char uplo, char trans, char diag, int n, const cuDoubleComplex *AP, - cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZtpsv(char uplo, char trans, char diag, int n, + const cuDoubleComplex *AP, cuDoubleComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, char, int, const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtpsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtpsv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasStbsv(char uplo, char trans, - char diag, int n, int k, const float *A, - int lda, float *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const float *, int, float *, int); +void CUBLASWINAPI cublasStbsv(char uplo, char trans, char diag, int n, int k, + const float *A, int lda, float *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStbsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasStbsv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasDtbsv(char uplo, char trans, - char diag, int n, int k, const double *A, - int lda, double *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const double *, int, double *, int); +void CUBLASWINAPI cublasDtbsv(char uplo, char trans, char diag, int n, int k, + const double *A, int lda, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, int, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtbsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtbsv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasCtbsv(char uplo, char trans, - char diag, int n, int k, const cuComplex *A, - int lda, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCtbsv(char uplo, char trans, char diag, int n, int k, + const cuComplex *A, int lda, cuComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, char, int, int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtbsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtbsv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasZtbsv(char uplo, char trans, - char diag, int n, int k, const cuDoubleComplex *A, - int lda, cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZtbsv(char uplo, char trans, char diag, int n, int k, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, int, const cuDoubleComplex *, + int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtbsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtbsv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasSsymv (char uplo, int n, float alpha, const float *A, - int lda, const float *x, int incx, float beta, - float *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, int, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSsymv(char uplo, int n, float alpha, const float *A, + int lda, const float *x, int incx, float beta, + float *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const float *, int, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsymv"); if (!func_ptr) LogFatalSymbolNotFound("cublasSsymv"); return func_ptr(uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasDsymv (char uplo, int n, double alpha, const double *A, - int lda, const double *x, int incx, double beta, - double *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, int, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDsymv(char uplo, int n, double alpha, const double *A, + int lda, const double *x, int incx, double beta, + double *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, double, const double *, int, + const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsymv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDsymv"); return func_ptr(uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasChemv (char uplo, int n, cuComplex alpha, const cuComplex *A, - int lda, const cuComplex *x, int incx, cuComplex beta, - cuComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasChemv(char uplo, int n, cuComplex alpha, + const cuComplex *A, int lda, const cuComplex *x, + int incx, cuComplex beta, cuComplex *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChemv"); if (!func_ptr) LogFatalSymbolNotFound("cublasChemv"); return func_ptr(uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasZhemv (char uplo, int n, cuDoubleComplex alpha, const cuDoubleComplex *A, - int lda, const cuDoubleComplex *x, int incx, cuDoubleComplex beta, - cuDoubleComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZhemv(char uplo, int n, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhemv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZhemv"); return func_ptr(uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasSsbmv (char uplo, int n, int k, float alpha, - const float *A, int lda, const float *x, int incx, - float beta, float *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, float, const float *, int, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSsbmv(char uplo, int n, int k, float alpha, + const float *A, int lda, const float *x, int incx, + float beta, float *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, float, const float *, int, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasSsbmv"); return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasDsbmv (char uplo, int n, int k, double alpha, - const double *A, int lda, const double *x, int incx, - double beta, double *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, double, const double *, int, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDsbmv(char uplo, int n, int k, double alpha, + const double *A, int lda, const double *x, + int incx, double beta, double *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, double, const double *, int, + const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDsbmv"); return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasChbmv (char uplo, int n, int k, cuComplex alpha, - const cuComplex *A, int lda, const cuComplex *x, int incx, - cuComplex beta, cuComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasChbmv(char uplo, int n, int k, cuComplex alpha, + const cuComplex *A, int lda, const cuComplex *x, + int incx, cuComplex beta, cuComplex *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasChbmv"); return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasZhbmv (char uplo, int n, int k, cuDoubleComplex alpha, - const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, - cuDoubleComplex beta, cuDoubleComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZhbmv(char uplo, int n, int k, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZhbmv"); return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasSspmv(char uplo, int n, float alpha, - const float *AP, const float *x, - int incx, float beta, float *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSspmv(char uplo, int n, float alpha, const float *AP, + const float *x, int incx, float beta, float *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const float *, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasSspmv"); return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasDspmv(char uplo, int n, double alpha, - const double *AP, const double *x, - int incx, double beta, double *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDspmv(char uplo, int n, double alpha, const double *AP, + const double *x, int incx, double beta, double *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, double, const double *, const double *, + int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDspmv"); return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); } void CUBLASWINAPI cublasChpmv(char uplo, int n, cuComplex alpha, - const cuComplex *AP, const cuComplex *x, - int incx, cuComplex beta, cuComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, const cuComplex *, int, cuComplex, cuComplex *, int); + const cuComplex *AP, const cuComplex *x, int incx, + cuComplex beta, cuComplex *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasChpmv"); return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); } void CUBLASWINAPI cublasZhpmv(char uplo, int n, cuDoubleComplex alpha, - const cuDoubleComplex *AP, const cuDoubleComplex *x, - int incx, cuDoubleComplex beta, cuDoubleComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuDoubleComplex, const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + const cuDoubleComplex *AP, + const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, cuDoubleComplex, const cuDoubleComplex *, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZhpmv"); return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasSger (int m, int n, float alpha, const float *x, int incx, - const float *y, int incy, float *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(int, int, float, const float *, int, const float *, int, float *, int); +void CUBLASWINAPI cublasSger(int m, int n, float alpha, const float *x, + int incx, const float *y, int incy, float *A, + int lda) { + using FuncPtr = void(CUBLASWINAPI *)(int, int, float, const float *, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSger"); if (!func_ptr) LogFatalSymbolNotFound("cublasSger"); return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasDger (int m, int n, double alpha, const double *x, int incx, - const double *y, int incy, double *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(int, int, double, const double *, int, const double *, int, double *, int); +void CUBLASWINAPI cublasDger(int m, int n, double alpha, const double *x, + int incx, const double *y, int incy, double *A, + int lda) { + using FuncPtr = void(CUBLASWINAPI *)(int, int, double, const double *, int, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDger"); if (!func_ptr) LogFatalSymbolNotFound("cublasDger"); return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasCgeru (int m, int n, cuComplex alpha, const cuComplex *x, - int incx, const cuComplex *y, int incy, - cuComplex *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCgeru(int m, int n, cuComplex alpha, const cuComplex *x, + int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = + void(CUBLASWINAPI *)(int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgeru"); if (!func_ptr) LogFatalSymbolNotFound("cublasCgeru"); return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasCgerc (int m, int n, cuComplex alpha, const cuComplex *x, - int incx, const cuComplex *y, int incy, - cuComplex *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCgerc(int m, int n, cuComplex alpha, const cuComplex *x, + int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = + void(CUBLASWINAPI *)(int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgerc"); if (!func_ptr) LogFatalSymbolNotFound("cublasCgerc"); return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasZgeru (int m, int n, cuDoubleComplex alpha, const cuDoubleComplex *x, - int incx, const cuDoubleComplex *y, int incy, - cuDoubleComplex *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZgeru(int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)( + int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgeru"); if (!func_ptr) LogFatalSymbolNotFound("cublasZgeru"); return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasZgerc (int m, int n, cuDoubleComplex alpha, const cuDoubleComplex *x, - int incx, const cuDoubleComplex *y, int incy, - cuDoubleComplex *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZgerc(int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)( + int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgerc"); if (!func_ptr) LogFatalSymbolNotFound("cublasZgerc"); return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasSsyr (char uplo, int n, float alpha, const float *x, - int incx, float *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, int, float *, int); +void CUBLASWINAPI cublasSsyr(char uplo, int n, float alpha, const float *x, + int incx, float *A, int lda) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, float, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr"); if (!func_ptr) LogFatalSymbolNotFound("cublasSsyr"); return func_ptr(uplo, n, alpha, x, incx, A, lda); } -void CUBLASWINAPI cublasDsyr (char uplo, int n, double alpha, const double *x, - int incx, double *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, int, double *, int); +void CUBLASWINAPI cublasDsyr(char uplo, int n, double alpha, const double *x, + int incx, double *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, double, const double *, int, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr"); if (!func_ptr) LogFatalSymbolNotFound("cublasDsyr"); return func_ptr(uplo, n, alpha, x, incx, A, lda); } -void CUBLASWINAPI cublasCher (char uplo, int n, float alpha, - const cuComplex *x, int incx, cuComplex *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCher(char uplo, int n, float alpha, const cuComplex *x, + int incx, cuComplex *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const cuComplex *, int, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher"); if (!func_ptr) LogFatalSymbolNotFound("cublasCher"); return func_ptr(uplo, n, alpha, x, incx, A, lda); } -void CUBLASWINAPI cublasZher (char uplo, int n, double alpha, - const cuDoubleComplex *x, int incx, cuDoubleComplex *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZher(char uplo, int n, double alpha, + const cuDoubleComplex *x, int incx, + cuDoubleComplex *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, double, const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher"); if (!func_ptr) LogFatalSymbolNotFound("cublasZher"); return func_ptr(uplo, n, alpha, x, incx, A, lda); } -void CUBLASWINAPI cublasSspr (char uplo, int n, float alpha, const float *x, - int incx, float *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, int, float *); +void CUBLASWINAPI cublasSspr(char uplo, int n, float alpha, const float *x, + int incx, float *AP) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, float, const float *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspr"); if (!func_ptr) LogFatalSymbolNotFound("cublasSspr"); return func_ptr(uplo, n, alpha, x, incx, AP); } -void CUBLASWINAPI cublasDspr (char uplo, int n, double alpha, const double *x, - int incx, double *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, int, double *); +void CUBLASWINAPI cublasDspr(char uplo, int n, double alpha, const double *x, + int incx, double *AP) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, double, const double *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspr"); if (!func_ptr) LogFatalSymbolNotFound("cublasDspr"); return func_ptr(uplo, n, alpha, x, incx, AP); } -void CUBLASWINAPI cublasChpr (char uplo, int n, float alpha, const cuComplex *x, - int incx, cuComplex *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const cuComplex *, int, cuComplex *); +void CUBLASWINAPI cublasChpr(char uplo, int n, float alpha, const cuComplex *x, + int incx, cuComplex *AP) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const cuComplex *, int, + cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpr"); if (!func_ptr) LogFatalSymbolNotFound("cublasChpr"); return func_ptr(uplo, n, alpha, x, incx, AP); } -void CUBLASWINAPI cublasZhpr (char uplo, int n, double alpha, const cuDoubleComplex *x, - int incx, cuDoubleComplex *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const cuDoubleComplex *, int, cuDoubleComplex *); +void CUBLASWINAPI cublasZhpr(char uplo, int n, double alpha, + const cuDoubleComplex *x, int incx, + cuDoubleComplex *AP) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, double, const cuDoubleComplex *, int, cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpr"); if (!func_ptr) LogFatalSymbolNotFound("cublasZhpr"); return func_ptr(uplo, n, alpha, x, incx, AP); } -void CUBLASWINAPI cublasSsyr2 (char uplo, int n, float alpha, const float *x, - int incx, const float *y, int incy, float *A, - int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, int, const float *, int, float *, int); +void CUBLASWINAPI cublasSsyr2(char uplo, int n, float alpha, const float *x, + int incx, const float *y, int incy, float *A, + int lda) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const float *, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr2"); if (!func_ptr) LogFatalSymbolNotFound("cublasSsyr2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasDsyr2 (char uplo, int n, double alpha, const double *x, - int incx, const double *y, int incy, double *A, - int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, int, const double *, int, double *, int); +void CUBLASWINAPI cublasDsyr2(char uplo, int n, double alpha, const double *x, + int incx, const double *y, int incy, double *A, + int lda) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, double, const double *, int, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr2"); if (!func_ptr) LogFatalSymbolNotFound("cublasDsyr2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasCher2 (char uplo, int n, cuComplex alpha, const cuComplex *x, - int incx, const cuComplex *y, int incy, cuComplex *A, - int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCher2(char uplo, int n, cuComplex alpha, + const cuComplex *x, int incx, const cuComplex *y, + int incy, cuComplex *A, int lda) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher2"); if (!func_ptr) LogFatalSymbolNotFound("cublasCher2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasZher2 (char uplo, int n, cuDoubleComplex alpha, const cuDoubleComplex *x, - int incx, const cuDoubleComplex *y, int incy, cuDoubleComplex *A, - int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZher2(char uplo, int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher2"); if (!func_ptr) LogFatalSymbolNotFound("cublasZher2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasSspr2 (char uplo, int n, float alpha, const float *x, - int incx, const float *y, int incy, float *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, int, const float *, int, float *); +void CUBLASWINAPI cublasSspr2(char uplo, int n, float alpha, const float *x, + int incx, const float *y, int incy, float *AP) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const float *, int, + const float *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspr2"); if (!func_ptr) LogFatalSymbolNotFound("cublasSspr2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); } -void CUBLASWINAPI cublasDspr2 (char uplo, int n, double alpha, - const double *x, int incx, const double *y, - int incy, double *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, int, const double *, int, double *); +void CUBLASWINAPI cublasDspr2(char uplo, int n, double alpha, const double *x, + int incx, const double *y, int incy, double *AP) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, double, const double *, int, + const double *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspr2"); if (!func_ptr) LogFatalSymbolNotFound("cublasDspr2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); } -void CUBLASWINAPI cublasChpr2 (char uplo, int n, cuComplex alpha, - const cuComplex *x, int incx, const cuComplex *y, - int incy, cuComplex *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex *); +void CUBLASWINAPI cublasChpr2(char uplo, int n, cuComplex alpha, + const cuComplex *x, int incx, const cuComplex *y, + int incy, cuComplex *AP) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpr2"); if (!func_ptr) LogFatalSymbolNotFound("cublasChpr2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); } -void CUBLASWINAPI cublasZhpr2 (char uplo, int n, cuDoubleComplex alpha, - const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, - int incy, cuDoubleComplex *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *); +void CUBLASWINAPI cublasZhpr2(char uplo, int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *AP) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpr2"); if (!func_ptr) LogFatalSymbolNotFound("cublasZhpr2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); } -void CUBLASWINAPI cublasSgemm (char transa, char transb, int m, int n, int k, - float alpha, const float *A, int lda, - const float *B, int ldb, float beta, float *C, - int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, int, float, const float *, int, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSgemm(char transa, char transb, int m, int n, int k, + float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, + int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, int, float, const float *, int, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemm"); if (!func_ptr) LogFatalSymbolNotFound("cublasSgemm"); return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasDgemm (char transa, char transb, int m, int n, int k, - double alpha, const double *A, int lda, - const double *B, int ldb, double beta, double *C, - int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, int, double, const double *, int, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDgemm(char transa, char transb, int m, int n, int k, + double alpha, const double *A, int lda, + const double *B, int ldb, double beta, double *C, + int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, int, double, const double *, + int, const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemm"); if (!func_ptr) LogFatalSymbolNotFound("cublasDgemm"); return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasCgemm (char transa, char transb, int m, int n, int k, - cuComplex alpha, const cuComplex *A, int lda, - const cuComplex *B, int ldb, cuComplex beta, - cuComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasCgemm(char transa, char transb, int m, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, cuComplex beta, + cuComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm"); if (!func_ptr) LogFatalSymbolNotFound("cublasCgemm"); return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasZgemm (char transa, char transb, int m, int n, - int k, cuDoubleComplex alpha, - const cuDoubleComplex *A, int lda, - const cuDoubleComplex *B, int ldb, - cuDoubleComplex beta, cuDoubleComplex *C, - int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZgemm(char transa, char transb, int m, int n, int k, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemm"); if (!func_ptr) LogFatalSymbolNotFound("cublasZgemm"); return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasSsyrk (char uplo, char trans, int n, int k, float alpha, - const float *A, int lda, float beta, float *C, - int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, float, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSsyrk(char uplo, char trans, int n, int k, float alpha, + const float *A, int lda, float beta, float *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, int, int, float, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyrk"); if (!func_ptr) LogFatalSymbolNotFound("cublasSsyrk"); return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -void CUBLASWINAPI cublasDsyrk (char uplo, char trans, int n, int k, - double alpha, const double *A, int lda, - double beta, double *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, double, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDsyrk(char uplo, char trans, int n, int k, double alpha, + const double *A, int lda, double beta, double *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, double, const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyrk"); if (!func_ptr) LogFatalSymbolNotFound("cublasDsyrk"); return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -void CUBLASWINAPI cublasCsyrk (char uplo, char trans, int n, int k, - cuComplex alpha, const cuComplex *A, int lda, - cuComplex beta, cuComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasCsyrk(char uplo, char trans, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + cuComplex beta, cuComplex *C, int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, + int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyrk"); if (!func_ptr) LogFatalSymbolNotFound("cublasCsyrk"); return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -void CUBLASWINAPI cublasZsyrk (char uplo, char trans, int n, int k, - cuDoubleComplex alpha, - const cuDoubleComplex *A, int lda, - cuDoubleComplex beta, - cuDoubleComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZsyrk(char uplo, char trans, int n, int k, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, + const cuDoubleComplex *, int, + cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyrk"); if (!func_ptr) LogFatalSymbolNotFound("cublasZsyrk"); return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -void CUBLASWINAPI cublasCherk (char uplo, char trans, int n, int k, - float alpha, const cuComplex *A, int lda, - float beta, cuComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, float, const cuComplex *, int, float, cuComplex *, int); +void CUBLASWINAPI cublasCherk(char uplo, char trans, int n, int k, float alpha, + const cuComplex *A, int lda, float beta, + cuComplex *C, int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, float, const cuComplex *, int, + float, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCherk"); if (!func_ptr) LogFatalSymbolNotFound("cublasCherk"); return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -void CUBLASWINAPI cublasZherk (char uplo, char trans, int n, int k, - double alpha, - const cuDoubleComplex *A, int lda, - double beta, - cuDoubleComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, double, const cuDoubleComplex *, int, double, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZherk(char uplo, char trans, int n, int k, double alpha, + const cuDoubleComplex *A, int lda, double beta, + cuDoubleComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, int, int, double, + const cuDoubleComplex *, int, double, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZherk"); if (!func_ptr) LogFatalSymbolNotFound("cublasZherk"); return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -void CUBLASWINAPI cublasSsyr2k (char uplo, char trans, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, float, const float *, int, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSsyr2k(char uplo, char trans, int n, int k, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, float, const float *, int, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr2k"); if (!func_ptr) LogFatalSymbolNotFound("cublasSsyr2k"); return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasDsyr2k (char uplo, char trans, int n, int k, - double alpha, const double *A, int lda, - const double *B, int ldb, double beta, - double *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, double, const double *, int, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDsyr2k(char uplo, char trans, int n, int k, + double alpha, const double *A, int lda, + const double *B, int ldb, double beta, double *C, + int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, double, const double *, int, + const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr2k"); if (!func_ptr) LogFatalSymbolNotFound("cublasDsyr2k"); return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasCsyr2k (char uplo, char trans, int n, int k, - cuComplex alpha, const cuComplex *A, int lda, - const cuComplex *B, int ldb, cuComplex beta, - cuComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasCsyr2k(char uplo, char trans, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, cuComplex beta, + cuComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyr2k"); if (!func_ptr) LogFatalSymbolNotFound("cublasCsyr2k"); return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasZsyr2k (char uplo, char trans, int n, int k, - cuDoubleComplex alpha, const cuDoubleComplex *A, int lda, - const cuDoubleComplex *B, int ldb, cuDoubleComplex beta, - cuDoubleComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZsyr2k(char uplo, char trans, int n, int k, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyr2k"); if (!func_ptr) LogFatalSymbolNotFound("cublasZsyr2k"); return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasCher2k (char uplo, char trans, int n, int k, - cuComplex alpha, const cuComplex *A, int lda, - const cuComplex *B, int ldb, float beta, - cuComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, float, cuComplex *, int); +void CUBLASWINAPI cublasCher2k(char uplo, char trans, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, float beta, + cuComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, float, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher2k"); if (!func_ptr) LogFatalSymbolNotFound("cublasCher2k"); return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasZher2k (char uplo, char trans, int n, int k, - cuDoubleComplex alpha, const cuDoubleComplex *A, int lda, - const cuDoubleComplex *B, int ldb, double beta, - cuDoubleComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, double, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZher2k(char uplo, char trans, int n, int k, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *B, int ldb, + double beta, cuDoubleComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, double, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher2k"); if (!func_ptr) LogFatalSymbolNotFound("cublasZher2k"); return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasSsymm (char side, char uplo, int m, int n, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, float, const float *, int, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSsymm(char side, char uplo, int m, int n, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, float, const float *, int, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsymm"); if (!func_ptr) LogFatalSymbolNotFound("cublasSsymm"); return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasDsymm (char side, char uplo, int m, int n, double alpha, - const double *A, int lda, const double *B, int ldb, - double beta, double *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, double, const double *, int, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDsymm(char side, char uplo, int m, int n, double alpha, + const double *A, int lda, const double *B, + int ldb, double beta, double *C, int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, double, const double *, int, + const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsymm"); if (!func_ptr) LogFatalSymbolNotFound("cublasDsymm"); return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasCsymm (char side, char uplo, int m, int n, cuComplex alpha, - const cuComplex *A, int lda, const cuComplex *B, int ldb, - cuComplex beta, cuComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasCsymm(char side, char uplo, int m, int n, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, cuComplex beta, + cuComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsymm"); if (!func_ptr) LogFatalSymbolNotFound("cublasCsymm"); return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasZsymm (char side, char uplo, int m, int n, cuDoubleComplex alpha, - const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, - cuDoubleComplex beta, cuDoubleComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZsymm(char side, char uplo, int m, int n, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsymm"); if (!func_ptr) LogFatalSymbolNotFound("cublasZsymm"); return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasChemm (char side, char uplo, int m, int n, - cuComplex alpha, const cuComplex *A, int lda, - const cuComplex *B, int ldb, cuComplex beta, - cuComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasChemm(char side, char uplo, int m, int n, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, cuComplex beta, + cuComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChemm"); if (!func_ptr) LogFatalSymbolNotFound("cublasChemm"); return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasZhemm (char side, char uplo, int m, int n, - cuDoubleComplex alpha, const cuDoubleComplex *A, int lda, - const cuDoubleComplex *B, int ldb, cuDoubleComplex beta, - cuDoubleComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZhemm(char side, char uplo, int m, int n, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhemm"); if (!func_ptr) LogFatalSymbolNotFound("cublasZhemm"); return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasStrsm (char side, char uplo, char transa, char diag, - int m, int n, float alpha, const float *A, int lda, - float *B, int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, float, const float *, int, float *, int); +void CUBLASWINAPI cublasStrsm(char side, char uplo, char transa, char diag, + int m, int n, float alpha, const float *A, + int lda, float *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, float, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrsm"); if (!func_ptr) LogFatalSymbolNotFound("cublasStrsm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); } -void CUBLASWINAPI cublasDtrsm (char side, char uplo, char transa, - char diag, int m, int n, double alpha, - const double *A, int lda, double *B, - int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, double, const double *, int, double *, int); +void CUBLASWINAPI cublasDtrsm(char side, char uplo, char transa, char diag, + int m, int n, double alpha, const double *A, + int lda, double *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, double, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrsm"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtrsm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); } -void CUBLASWINAPI cublasCtrsm (char side, char uplo, char transa, char diag, - int m, int n, cuComplex alpha, const cuComplex *A, - int lda, cuComplex *B, int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, cuComplex, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCtrsm(char side, char uplo, char transa, char diag, + int m, int n, cuComplex alpha, const cuComplex *A, + int lda, cuComplex *B, int ldb) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, char, int, int, cuComplex, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrsm"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtrsm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); } -void CUBLASWINAPI cublasZtrsm (char side, char uplo, char transa, - char diag, int m, int n, cuDoubleComplex alpha, - const cuDoubleComplex *A, int lda, - cuDoubleComplex *B, int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZtrsm(char side, char uplo, char transa, char diag, + int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, + cuDoubleComplex, const cuDoubleComplex *, + int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrsm"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtrsm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); } -void CUBLASWINAPI cublasStrmm (char side, char uplo, char transa, char diag, - int m, int n, float alpha, const float *A, int lda, - float *B, int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, float, const float *, int, float *, int); +void CUBLASWINAPI cublasStrmm(char side, char uplo, char transa, char diag, + int m, int n, float alpha, const float *A, + int lda, float *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, float, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrmm"); if (!func_ptr) LogFatalSymbolNotFound("cublasStrmm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); } -void CUBLASWINAPI cublasDtrmm (char side, char uplo, char transa, - char diag, int m, int n, double alpha, - const double *A, int lda, double *B, - int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, double, const double *, int, double *, int); +void CUBLASWINAPI cublasDtrmm(char side, char uplo, char transa, char diag, + int m, int n, double alpha, const double *A, + int lda, double *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, double, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrmm"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtrmm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); } -void CUBLASWINAPI cublasCtrmm (char side, char uplo, char transa, char diag, - int m, int n, cuComplex alpha, const cuComplex *A, - int lda, cuComplex *B, int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, cuComplex, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCtrmm(char side, char uplo, char transa, char diag, + int m, int n, cuComplex alpha, const cuComplex *A, + int lda, cuComplex *B, int ldb) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, char, int, int, cuComplex, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrmm"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtrmm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); } -void CUBLASWINAPI cublasZtrmm (char side, char uplo, char transa, - char diag, int m, int n, cuDoubleComplex alpha, - const cuDoubleComplex *A, int lda, cuDoubleComplex *B, - int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZtrmm(char side, char uplo, char transa, char diag, + int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, + cuDoubleComplex, const cuDoubleComplex *, + int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrmm"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtrmm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); diff --git a/tensorflow/stream_executor/cuda/cublas_11_0.inc b/tensorflow/stream_executor/cuda/cublas_11_0.inc new file mode 100644 index 00000000000..c30b2cf8f68 --- /dev/null +++ b/tensorflow/stream_executor/cuda/cublas_11_0.inc @@ -0,0 +1,5023 @@ +// Auto-generated, do not edit. + +extern "C" { + +cublasStatus_t CUBLASWINAPI cublasCreate_v2(cublasHandle_t *handle) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +cublasStatus_t CUBLASWINAPI cublasDestroy_v2(cublasHandle_t handle) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDestroy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +cublasStatus_t CUBLASWINAPI cublasGetVersion_v2(cublasHandle_t handle, + int *version) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetVersion_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, version); +} + +cublasStatus_t CUBLASWINAPI cublasGetProperty(libraryPropertyType type, + int *value) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(libraryPropertyType, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetProperty"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(type, value); +} + +size_t CUBLASWINAPI cublasGetCudartVersion(void) { + using FuncPtr = size_t(CUBLASWINAPI *)(); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetCudartVersion"); + if (!func_ptr) LogFatalSymbolNotFound("cublasGetCudartVersion"); + return func_ptr(); +} + +cublasStatus_t CUBLASWINAPI cublasSetStream_v2(cublasHandle_t handle, + cudaStream_t streamId) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetStream_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, streamId); +} + +cublasStatus_t CUBLASWINAPI cublasGetStream_v2(cublasHandle_t handle, + cudaStream_t *streamId) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cudaStream_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetStream_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, streamId); +} + +cublasStatus_t CUBLASWINAPI cublasGetPointerMode_v2(cublasHandle_t handle, + cublasPointerMode_t *mode) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasPointerMode_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetPointerMode_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cublasStatus_t CUBLASWINAPI cublasSetPointerMode_v2(cublasHandle_t handle, + cublasPointerMode_t mode) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasPointerMode_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetPointerMode_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cublasStatus_t CUBLASWINAPI cublasGetAtomicsMode(cublasHandle_t handle, + cublasAtomicsMode_t *mode) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasAtomicsMode_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetAtomicsMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cublasStatus_t CUBLASWINAPI cublasSetAtomicsMode(cublasHandle_t handle, + cublasAtomicsMode_t mode) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasAtomicsMode_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetAtomicsMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cublasStatus_t CUBLASWINAPI cublasGetMathMode(cublasHandle_t handle, + cublasMath_t *mode) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasMath_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetMathMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cublasStatus_t CUBLASWINAPI cublasSetMathMode(cublasHandle_t handle, + cublasMath_t mode) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasMath_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetMathMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cublasStatus_t CUBLASWINAPI cublasLoggerConfigure(int logIsOn, int logToStdOut, + int logToStdErr, + const char *logFileName) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, int, const char *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasLoggerConfigure"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(logIsOn, logToStdOut, logToStdErr, logFileName); +} + +cublasStatus_t CUBLASWINAPI +cublasSetLoggerCallback(cublasLogCallback userCallback) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLogCallback); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetLoggerCallback"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(userCallback); +} + +cublasStatus_t CUBLASWINAPI +cublasGetLoggerCallback(cublasLogCallback *userCallback) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasLogCallback *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetLoggerCallback"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(userCallback); +} + +cublasStatus_t CUBLASWINAPI cublasSetVector(int n, int elemSize, const void *x, + int incx, void *devicePtr, + int incy) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(int, int, const void *, int, void *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetVector"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(n, elemSize, x, incx, devicePtr, incy); +} + +cublasStatus_t CUBLASWINAPI cublasGetVector(int n, int elemSize, const void *x, + int incx, void *y, int incy) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(int, int, const void *, int, void *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetVector"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(n, elemSize, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasSetMatrix(int rows, int cols, int elemSize, + const void *A, int lda, void *B, + int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, int, const void *, + int, void *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetMatrix"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(rows, cols, elemSize, A, lda, B, ldb); +} + +cublasStatus_t CUBLASWINAPI cublasGetMatrix(int rows, int cols, int elemSize, + const void *A, int lda, void *B, + int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, int, const void *, + int, void *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetMatrix"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(rows, cols, elemSize, A, lda, B, ldb); +} + +cublasStatus_t CUBLASWINAPI cublasSetVectorAsync(int n, int elemSize, + const void *hostPtr, int incx, + void *devicePtr, int incy, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, const void *, int, + void *, int, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetVectorAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(n, elemSize, hostPtr, incx, devicePtr, incy, stream); +} + +cublasStatus_t CUBLASWINAPI cublasGetVectorAsync(int n, int elemSize, + const void *devicePtr, + int incx, void *hostPtr, + int incy, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, const void *, int, + void *, int, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetVectorAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(n, elemSize, devicePtr, incx, hostPtr, incy, stream); +} + +cublasStatus_t CUBLASWINAPI cublasSetMatrixAsync(int rows, int cols, + int elemSize, const void *A, + int lda, void *B, int ldb, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + int, int, int, const void *, int, void *, int, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetMatrixAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(rows, cols, elemSize, A, lda, B, ldb, stream); +} + +cublasStatus_t CUBLASWINAPI cublasGetMatrixAsync(int rows, int cols, + int elemSize, const void *A, + int lda, void *B, int ldb, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + int, int, int, const void *, int, void *, int, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetMatrixAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(rows, cols, elemSize, A, lda, B, ldb, stream); +} + +void CUBLASWINAPI cublasXerbla(const char *srName, int info) { + using FuncPtr = void(CUBLASWINAPI *)(const char *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasXerbla"); + if (!func_ptr) LogFatalSymbolNotFound("cublasXerbla"); + return func_ptr(srName, info); +} + +cublasStatus_t CUBLASWINAPI cublasNrm2Ex(cublasHandle_t handle, int n, + const void *x, cudaDataType xType, + int incx, void *result, + cudaDataType resultType, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, void *, + cudaDataType, cudaDataType); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasNrm2Ex"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, result, resultType, executionType); +} + +cublasStatus_t CUBLASWINAPI cublasSnrm2_v2(cublasHandle_t handle, int n, + const float *x, int incx, + float *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const float *, int, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSnrm2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasDnrm2_v2(cublasHandle_t handle, int n, + const double *x, int incx, + double *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const double *, int, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDnrm2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasScnrm2_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + float *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *, int, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasScnrm2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasDznrm2_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + double *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDznrm2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasDotEx(cublasHandle_t handle, int n, + const void *x, cudaDataType xType, + int incx, const void *y, + cudaDataType yType, int incy, + void *result, cudaDataType resultType, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, const void *, + cudaDataType, int, void *, cudaDataType, cudaDataType); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDotEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, y, yType, incy, result, resultType, + executionType); +} + +cublasStatus_t CUBLASWINAPI cublasDotcEx(cublasHandle_t handle, int n, + const void *x, cudaDataType xType, + int incx, const void *y, + cudaDataType yType, int incy, + void *result, cudaDataType resultType, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, const void *, + cudaDataType, int, void *, cudaDataType, cudaDataType); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDotcEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, y, yType, incy, result, resultType, + executionType); +} + +cublasStatus_t CUBLASWINAPI cublasSdot_v2(cublasHandle_t handle, int n, + const float *x, int incx, + const float *y, int incy, + float *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const float *, int, const float *, int, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSdot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, result); +} + +cublasStatus_t CUBLASWINAPI cublasDdot_v2(cublasHandle_t handle, int n, + const double *x, int incx, + const double *y, int incy, + double *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const double *, int, const double *, int, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDdot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, result); +} + +cublasStatus_t CUBLASWINAPI cublasCdotu_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + const cuComplex *y, int incy, + cuComplex *result) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, + int, const cuComplex *, int, cuComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCdotu_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, result); +} + +cublasStatus_t CUBLASWINAPI cublasCdotc_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + const cuComplex *y, int incy, + cuComplex *result) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, + int, const cuComplex *, int, cuComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCdotc_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, result); +} + +cublasStatus_t CUBLASWINAPI cublasZdotu_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdotu_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, result); +} + +cublasStatus_t CUBLASWINAPI cublasZdotc_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdotc_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, result); +} + +cublasStatus_t CUBLASWINAPI +cublasScalEx(cublasHandle_t handle, int n, + const void *alpha, /* host or device pointer */ + cudaDataType alphaType, void *x, cudaDataType xType, int incx, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, void *, cudaDataType, + int, cudaDataType); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasScalEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, alphaType, x, xType, incx, executionType); +} + +cublasStatus_t CUBLASWINAPI +cublasSscal_v2(cublasHandle_t handle, int n, + const float *alpha, /* host or device pointer */ + float *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const float *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSscal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx); +} + +cublasStatus_t CUBLASWINAPI +cublasDscal_v2(cublasHandle_t handle, int n, + const double *alpha, /* host or device pointer */ + double *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const double *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDscal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx); +} + +cublasStatus_t CUBLASWINAPI +cublasCscal_v2(cublasHandle_t handle, int n, + const cuComplex *alpha, /* host or device pointer */ + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCscal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx); +} + +cublasStatus_t CUBLASWINAPI +cublasCsscal_v2(cublasHandle_t handle, int n, + const float *alpha, /* host or device pointer */ + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const float *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsscal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx); +} + +cublasStatus_t CUBLASWINAPI +cublasZscal_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZscal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx); +} + +cublasStatus_t CUBLASWINAPI +cublasZdscal_v2(cublasHandle_t handle, int n, + const double *alpha, /* host or device pointer */ + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const double *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdscal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasAxpyEx( + cublasHandle_t handle, int n, + const void *alpha, /* host or device pointer */ + cudaDataType alphaType, const void *x, cudaDataType xType, int incx, + void *y, cudaDataType yType, int incy, cudaDataType executiontype) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, const void *, + cudaDataType, int, void *, cudaDataType, int, cudaDataType); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasAxpyEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, alphaType, x, xType, incx, y, yType, incy, + executiontype); +} + +cublasStatus_t CUBLASWINAPI +cublasSaxpy_v2(cublasHandle_t handle, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const float *, const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSaxpy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI +cublasDaxpy_v2(cublasHandle_t handle, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const double *, const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDaxpy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI +cublasCaxpy_v2(cublasHandle_t handle, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, cuComplex *y, int incy) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, + const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCaxpy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZaxpy_v2( + cublasHandle_t handle, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, const cuDoubleComplex *, + int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZaxpy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, alpha, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasCopyEx(cublasHandle_t handle, int n, + const void *x, cudaDataType xType, + int incx, void *y, cudaDataType yType, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, void *, + cudaDataType, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCopyEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, y, yType, incy); +} + +cublasStatus_t CUBLASWINAPI cublasScopy_v2(cublasHandle_t handle, int n, + const float *x, int incx, float *y, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasScopy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasDcopy_v2(cublasHandle_t handle, int n, + const double *x, int incx, double *y, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDcopy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasCcopy_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCcopy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZcopy_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const cuDoubleComplex *, int, + cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZcopy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasSswap_v2(cublasHandle_t handle, int n, + float *x, int incx, float *y, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, float *, + int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSswap_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasDswap_v2(cublasHandle_t handle, int n, + double *x, int incx, double *y, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, double *, + int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDswap_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasCswap_v2(cublasHandle_t handle, int n, + cuComplex *x, int incx, cuComplex *y, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCswap_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZswap_v2(cublasHandle_t handle, int n, + cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZswap_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasSwapEx(cublasHandle_t handle, int n, void *x, + cudaDataType xType, int incx, void *y, + cudaDataType yType, int incy) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, void *, cudaDataType, + int, void *, cudaDataType, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSwapEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, y, yType, incy); +} + +cublasStatus_t CUBLASWINAPI cublasIsamax_v2(cublasHandle_t handle, int n, + const float *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIsamax_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIdamax_v2(cublasHandle_t handle, int n, + const double *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIdamax_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIcamax_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIcamax_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIzamax_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIzamax_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIamaxEx( + cublasHandle_t handle, int n, const void *x, cudaDataType xType, int incx, + int *result /* host or device pointer */ +) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIamaxEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIsamin_v2(cublasHandle_t handle, int n, + const float *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIsamin_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIdamin_v2(cublasHandle_t handle, int n, + const double *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIdamin_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIcamin_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIcamin_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIzamin_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIzamin_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasIaminEx( + cublasHandle_t handle, int n, const void *x, cudaDataType xType, int incx, + int *result /* host or device pointer */ +) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIaminEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasAsumEx( + cublasHandle_t handle, int n, const void *x, cudaDataType xType, int incx, + void *result, cudaDataType resultType, /* host or device pointer */ + cudaDataType executiontype) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, void *, + cudaDataType, cudaDataType); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasAsumEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, result, resultType, executiontype); +} + +cublasStatus_t CUBLASWINAPI cublasSasum_v2(cublasHandle_t handle, int n, + const float *x, int incx, + float *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const float *, int, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSasum_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasDasum_v2(cublasHandle_t handle, int n, + const double *x, int incx, + double *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const double *, int, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDasum_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasScasum_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + float *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *, int, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasScasum_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI cublasDzasum_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + double *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDzasum_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, result); +} + +cublasStatus_t CUBLASWINAPI +cublasSrot_v2(cublasHandle_t handle, int n, float *x, int incx, float *y, + int incy, const float *c, /* host or device pointer */ + const float *s) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, float *, int, float *, + int, const float *, const float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, c, s); +} + +cublasStatus_t CUBLASWINAPI +cublasDrot_v2(cublasHandle_t handle, int n, double *x, int incx, double *y, + int incy, const double *c, /* host or device pointer */ + const double *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, double *, int, double *, int, const double *, + const double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, c, s); +} + +cublasStatus_t CUBLASWINAPI cublasCrot_v2( + cublasHandle_t handle, int n, cuComplex *x, int incx, cuComplex *y, + int incy, const float *c, /* host or device pointer */ + const cuComplex *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuComplex *, int, cuComplex *, int, const float *, + const cuComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCrot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, c, s); +} + +cublasStatus_t CUBLASWINAPI cublasCsrot_v2( + cublasHandle_t handle, int n, cuComplex *x, int incx, cuComplex *y, + int incy, const float *c, /* host or device pointer */ + const float *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuComplex *, int, cuComplex *, int, const float *, + const float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsrot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, c, s); +} + +cublasStatus_t CUBLASWINAPI cublasZrot_v2( + cublasHandle_t handle, int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy, const double *c, /* host or device pointer */ + const cuDoubleComplex *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, + const double *, const cuDoubleComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZrot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, c, s); +} + +cublasStatus_t CUBLASWINAPI cublasZdrot_v2( + cublasHandle_t handle, int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy, const double *c, /* host or device pointer */ + const double *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, + const double *, const double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdrot_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, c, s); +} + +cublasStatus_t CUBLASWINAPI +cublasRotEx(cublasHandle_t handle, int n, void *x, cudaDataType xType, int incx, + void *y, cudaDataType yType, int incy, + const void *c, /* host or device pointer */ + const void *s, cudaDataType csType, cudaDataType executiontype) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, void *, cudaDataType, int, void *, cudaDataType, int, + const void *, const void *, cudaDataType, cudaDataType); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasRotEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, y, yType, incy, c, s, csType, + executiontype); +} + +cublasStatus_t CUBLASWINAPI +cublasSrotg_v2(cublasHandle_t handle, float *a, /* host or device pointer */ + float *b, /* host or device pointer */ + float *c, /* host or device pointer */ + float *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, float *, + float *, float *, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotg_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, a, b, c, s); +} + +cublasStatus_t CUBLASWINAPI +cublasDrotg_v2(cublasHandle_t handle, double *a, /* host or device pointer */ + double *b, /* host or device pointer */ + double *c, /* host or device pointer */ + double *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, double *, + double *, double *, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotg_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, a, b, c, s); +} + +cublasStatus_t CUBLASWINAPI +cublasCrotg_v2(cublasHandle_t handle, cuComplex *a, /* host or device pointer */ + cuComplex *b, /* host or device pointer */ + float *c, /* host or device pointer */ + cuComplex *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cuComplex *, cuComplex *, float *, cuComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCrotg_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, a, b, c, s); +} + +cublasStatus_t CUBLASWINAPI cublasZrotg_v2( + cublasHandle_t handle, cuDoubleComplex *a, /* host or device pointer */ + cuDoubleComplex *b, /* host or device pointer */ + double *c, /* host or device pointer */ + cuDoubleComplex *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cuDoubleComplex *, cuDoubleComplex *, double *, + cuDoubleComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZrotg_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, a, b, c, s); +} + +cublasStatus_t CUBLASWINAPI cublasRotgEx(cublasHandle_t handle, + void *a, /* host or device pointer */ + void *b, /* host or device pointer */ + cudaDataType abType, + void *c, /* host or device pointer */ + void *s, /* host or device pointer */ + cudaDataType csType, + cudaDataType executiontype) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, void *, void *, + cudaDataType, void *, void *, + cudaDataType, cudaDataType); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasRotgEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, a, b, abType, c, s, csType, executiontype); +} + +cublasStatus_t CUBLASWINAPI cublasSrotm_v2(cublasHandle_t handle, int n, + float *x, int incx, float *y, + int incy, const float *param) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, float *, int, float *, int, const float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, param); +} + +cublasStatus_t CUBLASWINAPI cublasDrotm_v2(cublasHandle_t handle, int n, + double *x, int incx, double *y, + int incy, const double *param) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, double *, int, double *, int, const double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, incx, y, incy, param); +} + +cublasStatus_t CUBLASWINAPI +cublasRotmEx(cublasHandle_t handle, int n, void *x, cudaDataType xType, + int incx, void *y, cudaDataType yType, int incy, + const void *param, /* host or device pointer */ + cudaDataType paramType, cudaDataType executiontype) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, void *, cudaDataType, int, void *, cudaDataType, int, + const void *, cudaDataType, cudaDataType); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasRotmEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, x, xType, incx, y, yType, incy, param, paramType, + executiontype); +} + +cublasStatus_t CUBLASWINAPI +cublasSrotmg_v2(cublasHandle_t handle, float *d1, /* host or device pointer */ + float *d2, /* host or device pointer */ + float *x1, /* host or device pointer */ + const float *y1, /* host or device pointer */ + float *param) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, float *, float *, float *, const float *, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotmg_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, d1, d2, x1, y1, param); +} + +cublasStatus_t CUBLASWINAPI +cublasDrotmg_v2(cublasHandle_t handle, double *d1, /* host or device pointer */ + double *d2, /* host or device pointer */ + double *x1, /* host or device pointer */ + const double *y1, /* host or device pointer */ + double *param) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, double *, double *, double *, const double *, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotmg_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, d1, d2, x1, y1, param); +} + +cublasStatus_t CUBLASWINAPI +cublasRotmgEx(cublasHandle_t handle, void *d1, /* host or device pointer */ + cudaDataType d1Type, void *d2, /* host or device pointer */ + cudaDataType d2Type, void *x1, /* host or device pointer */ + cudaDataType x1Type, const void *y1, /* host or device pointer */ + cudaDataType y1Type, void *param, /* host or device pointer */ + cudaDataType paramType, cudaDataType executiontype) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, void *, cudaDataType, void *, cudaDataType, void *, + cudaDataType, const void *, cudaDataType, void *, cudaDataType, + cudaDataType); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasRotmgEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, d1, d1Type, d2, d2Type, x1, x1Type, y1, y1Type, param, + paramType, executiontype); +} + +cublasStatus_t CUBLASWINAPI +cublasSgemv_v2(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *x, int incx, + const float *beta, /* host or device pointer */ + float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const float *, const float *, + int, const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI +cublasDgemv_v2(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *x, int incx, + const double *beta, /* host or device pointer */ + double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const double *, + const double *, int, const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI +cublasCgemv_v2(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, const cuComplex *, + cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZgemv_v2( + cublasHandle_t handle, cublasOperation_t trans, int m, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI +cublasSgbmv_v2(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int kl, int ku, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *x, int incx, + const float *beta, /* host or device pointer */ + float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, int, const float *, + const float *, int, const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, + incy); +} + +cublasStatus_t CUBLASWINAPI +cublasDgbmv_v2(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int kl, int ku, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *x, int incx, + const double *beta, /* host or device pointer */ + double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, int, const double *, + const double *, int, const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, + incy); +} + +cublasStatus_t CUBLASWINAPI cublasCgbmv_v2( + cublasHandle_t handle, cublasOperation_t trans, int m, int n, int kl, + int ku, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, const cuComplex *, + cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, + incy); +} + +cublasStatus_t CUBLASWINAPI cublasZgbmv_v2( + cublasHandle_t handle, cublasOperation_t trans, int m, int n, int kl, + int ku, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, + incy); +} + +cublasStatus_t CUBLASWINAPI cublasStrmv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const float *A, int lda, float *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasDtrmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const double *A, int lda, double *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasCtrmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuComplex *A, int lda, + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasZtrmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasStbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const float *A, int lda, float *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasDtbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const double *A, int lda, double *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasCtbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const cuComplex *A, int lda, + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasZtbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasStpmv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const float *AP, float *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const float *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStpmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasDtpmv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const double *AP, double *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const double *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtpmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasCtpmv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const cuComplex *AP, cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtpmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasZtpmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuDoubleComplex *AP, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtpmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasStrsv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const float *A, int lda, float *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasDtrsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const double *A, int lda, double *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasCtrsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuComplex *A, int lda, + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasZtrsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasStpsv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const float *AP, float *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const float *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStpsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasDtpsv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const double *AP, double *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const double *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtpsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasCtpsv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const cuComplex *AP, cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtpsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasZtpsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuDoubleComplex *AP, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtpsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasStbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const float *A, int lda, float *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStbsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasDtbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const double *A, int lda, double *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtbsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasCtbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const cuComplex *A, int lda, + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtbsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI cublasZtbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtbsv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); +} + +cublasStatus_t CUBLASWINAPI +cublasSsymv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *x, int incx, + const float *beta, /* host or device pointer */ + float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, + const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsymv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI +cublasDsymv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *x, int incx, + const double *beta, /* host or device pointer */ + double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + int, const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsymv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI +cublasCsymv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, const cuComplex *, + cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsymv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZsymv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsymv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI +cublasChemv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, const cuComplex *, + cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasChemv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZhemv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhemv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI +cublasSsbmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, int k, + const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *x, int incx, + const float *beta, /* host or device pointer */ + float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, int, const float *, const float *, + int, const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI +cublasDsbmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, int k, + const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *x, int incx, + const double *beta, /* host or device pointer */ + double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, int, const double *, + const double *, int, const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI +cublasChbmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, int k, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, const cuComplex *, + cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasChbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasZhbmv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhbmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI +cublasSspmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *AP, const float *x, int incx, + const float *beta, /* host or device pointer */ + float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, + const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI +cublasDspmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *AP, const double *x, int incx, + const double *beta, /* host or device pointer */ + double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI +cublasChpmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *AP, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, const cuComplex *, int, const cuComplex *, cuComplex *, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI +cublasZhpmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *AP, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpmv_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +cublasStatus_t CUBLASWINAPI cublasSger_v2( + cublasHandle_t handle, int m, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, const float *y, int incy, float *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const float *, const float *, int, + const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSger_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasDger_v2( + cublasHandle_t handle, int m, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, const double *y, int incy, double *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const double *, const double *, int, + const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDger_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasCgeru_v2(cublasHandle_t handle, int m, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const cuComplex *, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgeru_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasCgerc_v2(cublasHandle_t handle, int m, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const cuComplex *, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgerc_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasZgeru_v2(cublasHandle_t handle, int m, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgeru_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasZgerc_v2(cublasHandle_t handle, int m, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgerc_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasSsyr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, float *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, + float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasDsyr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, double *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasCsyr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyr_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasZsyr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyr_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasCher_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const cuComplex *x, int incx, cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const cuComplex *, + int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasZher_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, + const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasSspr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, float *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, + float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspr_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, AP); +} + +cublasStatus_t CUBLASWINAPI +cublasDspr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, double *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + int, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspr_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, AP); +} + +cublasStatus_t CUBLASWINAPI +cublasChpr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const cuComplex *x, int incx, cuComplex *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const cuComplex *, + int, cuComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpr_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, AP); +} + +cublasStatus_t CUBLASWINAPI +cublasZhpr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, cuDoubleComplex *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, + const cuDoubleComplex *, int, cuDoubleComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpr_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, AP); +} + +cublasStatus_t CUBLASWINAPI cublasSsyr2_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, const float *y, int incy, float *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, + const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasDsyr2_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, const double *y, int incy, double *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasCsyr2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasZsyr2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasCher2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasZher2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); +} + +cublasStatus_t CUBLASWINAPI +cublasSspr2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, const float *y, int incy, float *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, + const float *, int, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); +} + +cublasStatus_t CUBLASWINAPI cublasDspr2_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, const double *y, int incy, double *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + int, const double *, int, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); +} + +cublasStatus_t CUBLASWINAPI cublasChpr2_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, const cuComplex *y, int incy, cuComplex *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, cuComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); +} + +cublasStatus_t CUBLASWINAPI +cublasZhpr2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpr2_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); +} + +cublasStatus_t CUBLASWINAPI cublasSgemm_v2( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *B, int ldb, + const float *beta, /* host or device pointer */ + float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const float *, const float *, int, const float *, int, const float *, + float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDgemm_v2( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *B, int ldb, + const double *beta, /* host or device pointer */ + double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const double *, const double *, int, const double *, int, const double *, + double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCgemm_v2( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCgemm3m( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm3m"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCgemm3mEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, const void *A, + cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, + const cuComplex *beta, void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const void *, cudaDataType, int, const void *, + cudaDataType, int, const cuComplex *, void *, cudaDataType, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm3mEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, + Btype, ldb, beta, C, Ctype, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZgemm_v2( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); +} + +cublasStatus_t CUBLASWINAPI +cublasZgemm3m(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, + int ldb, const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemm3m"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasSgemmEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const float *alpha, /* host or device pointer */ + const void *A, cudaDataType Atype, int lda, const void *B, + cudaDataType Btype, int ldb, const float *beta, /* host or device pointer */ + void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const float *, const void *, cudaDataType, int, const void *, + cudaDataType, int, const float *, void *, cudaDataType, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemmEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, + Btype, ldb, beta, C, Ctype, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasGemmEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const void *alpha, /* host or device pointer */ + const void *A, cudaDataType Atype, int lda, const void *B, + cudaDataType Btype, int ldb, const void *beta, /* host or device pointer */ + void *C, cudaDataType Ctype, int ldc, cublasComputeType_t computeType, + cublasGemmAlgo_t algo) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const void *, const void *, cudaDataType, int, const void *, cudaDataType, + int, const void *, void *, cudaDataType, int, cublasComputeType_t, + cublasGemmAlgo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGemmEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, + Btype, ldb, beta, C, Ctype, ldc, computeType, algo); +} + +cublasStatus_t CUBLASWINAPI cublasCgemmEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, const void *A, + cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, + const cuComplex *beta, void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const void *, cudaDataType, int, const void *, + cudaDataType, int, const cuComplex *, void *, cudaDataType, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemmEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, + Btype, ldb, beta, C, Ctype, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasUint8gemmBias( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + cublasOperation_t transc, int m, int n, int k, const unsigned char *A, + int A_bias, int lda, const unsigned char *B, int B_bias, int ldb, + unsigned char *C, int C_bias, int ldc, int C_mult, int C_shift) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, cublasOperation_t, + int, int, int, const unsigned char *, int, int, const unsigned char *, + int, int, unsigned char *, int, int, int, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasUint8gemmBias"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, transc, m, n, k, A, A_bias, lda, B, + B_bias, ldb, C, C_bias, ldc, C_mult, C_shift); +} + +cublasStatus_t CUBLASWINAPI cublasSsyrk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *beta, /* host or device pointer */ + float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const float *, int, const float *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyrk_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDsyrk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *beta, /* host or device pointer */ + double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const double *, const double *, int, const double *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyrk_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCsyrk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, cuComplex *, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyrk_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZsyrk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyrk_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCsyrkEx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const void *A, cudaDataType Atype, int lda, + const cuComplex *beta, /* host or device pointer */ + void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const void *, cudaDataType, int, const cuComplex *, + void *, cudaDataType, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyrkEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, + Ctype, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCsyrk3mEx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, const void *A, cudaDataType Atype, + int lda, const cuComplex *beta, void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const void *, cudaDataType, int, const cuComplex *, + void *, cudaDataType, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyrk3mEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, + Ctype, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCherk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const float *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const cuComplex *, int, const float *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCherk_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZherk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const double *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, + const double *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const double *, const cuDoubleComplex *, int, const double *, + cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZherk_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCherkEx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, /* host or device pointer */ + const void *A, cudaDataType Atype, int lda, + const float *beta, /* host or device pointer */ + void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const void *, cudaDataType, int, const float *, void *, + cudaDataType, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCherkEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, + Ctype, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCherk3mEx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, const void *A, cudaDataType Atype, + int lda, const float *beta, void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const void *, cudaDataType, int, const float *, void *, + cudaDataType, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCherk3mEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, + Ctype, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasSsyr2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *B, int ldb, + const float *beta, /* host or device pointer */ + float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const float *, int, const float *, int, const float *, + float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr2k_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDsyr2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *B, int ldb, + const double *beta, /* host or device pointer */ + double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const double *, const double *, int, const double *, int, const double *, + double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr2k_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCsyr2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyr2k_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZsyr2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyr2k_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCher2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const float *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const float *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher2k_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZher2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const double *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const double *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher2k_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasSsyrkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *B, int ldb, + const float *beta, /* host or device pointer */ + float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const float *, int, const float *, int, const float *, + float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyrkx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDsyrkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *B, int ldb, + const double *beta, /* host or device pointer */ + double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const double *, const double *, int, const double *, int, const double *, + double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyrkx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCsyrkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyrkx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZsyrkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyrkx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCherkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const float *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const float *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCherkx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZherkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const double *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const double *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZherkx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasSsymm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *B, int ldb, + const float *beta, /* host or device pointer */ + float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const float *, const float *, int, const float *, int, const float *, + float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsymm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDsymm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *B, int ldb, + const double *beta, /* host or device pointer */ + double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const double *, const double *, int, const double *, int, const double *, + double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsymm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCsymm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsymm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZsymm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsymm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasChemm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasChemm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZhemm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhemm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasStrsm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const float *alpha, /* host or device pointer */ + const float *A, int lda, float *B, int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const float *, const float *, int, float *, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrsm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} + +cublasStatus_t CUBLASWINAPI cublasDtrsm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const double *alpha, /* host or device pointer */ + const double *A, int lda, double *B, int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const double *, const double *, int, double *, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrsm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} + +cublasStatus_t CUBLASWINAPI cublasCtrsm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, cuComplex *B, int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuComplex *, const cuComplex *, int, + cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrsm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} + +cublasStatus_t CUBLASWINAPI cublasZtrsm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, cuDoubleComplex *B, int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrsm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); +} + +cublasStatus_t CUBLASWINAPI cublasStrmm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *B, int ldb, float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const float *, const float *, int, + const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrmm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDtrmm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *B, int ldb, double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const double *, const double *, int, + const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrmm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCtrmm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuComplex *, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrmm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZtrmm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrmm_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasSgemmBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const float *alpha, /* host or device pointer */ + const float *const Aarray[], int lda, const float *const Barray[], int ldb, + const float *beta, /* host or device pointer */ + float *const Carray[], int ldc, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const float *, const float *const[], int, const float *const[], int, + const float *, float *const[], int, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, + ldb, beta, Carray, ldc, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasDgemmBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const double *alpha, /* host or device pointer */ + const double *const Aarray[], int lda, const double *const Barray[], + int ldb, const double *beta, /* host or device pointer */ + double *const Carray[], int ldc, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const double *, const double *const[], int, const double *const[], int, + const double *, double *const[], int, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, + ldb, beta, Carray, ldc, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasCgemmBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *const Aarray[], int lda, const cuComplex *const Barray[], + int ldb, const cuComplex *beta, /* host or device pointer */ + cuComplex *const Carray[], int ldc, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *const[], int, + const cuComplex *const[], int, const cuComplex *, cuComplex *const[], int, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, + ldb, beta, Carray, ldc, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasCgemm3mBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *const Aarray[], int lda, const cuComplex *const Barray[], + int ldb, const cuComplex *beta, /* host or device pointer */ + cuComplex *const Carray[], int ldc, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *const[], int, + const cuComplex *const[], int, const cuComplex *, cuComplex *const[], int, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm3mBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, + ldb, beta, Carray, ldc, batchCount); +} + +cublasStatus_t CUBLASWINAPI +cublasZgemmBatched(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *const Aarray[], int lda, + const cuDoubleComplex *const Barray[], int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *const Carray[], int ldc, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, const cuDoubleComplex *const[], int, + const cuDoubleComplex *const[], int, const cuDoubleComplex *, + cuDoubleComplex *const[], int, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, + ldb, beta, Carray, ldc, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasGemmBatchedEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const void *alpha, /* host or device pointer */ + const void *const Aarray[], cudaDataType Atype, int lda, + const void *const Barray[], cudaDataType Btype, int ldb, + const void *beta, /* host or device pointer */ + void *const Carray[], cudaDataType Ctype, int ldc, int batchCount, + cublasComputeType_t computeType, cublasGemmAlgo_t algo) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const void *, const void *const[], cudaDataType, int, const void *const[], + cudaDataType, int, const void *, void *const[], cudaDataType, int, int, + cublasComputeType_t, cublasGemmAlgo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGemmBatchedEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, Atype, lda, + Barray, Btype, ldb, beta, Carray, Ctype, ldc, batchCount, + computeType, algo); +} + +cublasStatus_t CUBLASWINAPI cublasGemmStridedBatchedEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const void *alpha, /* host or device pointer */ + const void *A, cudaDataType Atype, int lda, + long long int strideA, /* purposely signed */ + const void *B, cudaDataType Btype, int ldb, long long int strideB, + const void *beta, /* host or device pointer */ + void *C, cudaDataType Ctype, int ldc, long long int strideC, int batchCount, + cublasComputeType_t computeType, cublasGemmAlgo_t algo) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const void *, const void *, cudaDataType, int, long long, const void *, + cudaDataType, int, long long, const void *, void *, cudaDataType, int, + long long, int, cublasComputeType_t, cublasGemmAlgo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGemmStridedBatchedEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, + strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, + batchCount, computeType, algo); +} + +cublasStatus_t CUBLASWINAPI cublasSgemmStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const float *alpha, /* host or device pointer */ + const float *A, int lda, long long int strideA, /* purposely signed */ + const float *B, int ldb, long long int strideB, + const float *beta, /* host or device pointer */ + float *C, int ldc, long long int strideC, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const float *, const float *, int, long long, const float *, int, + long long, const float *, float *, int, long long, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemmStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, + ldb, strideB, beta, C, ldc, strideC, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasDgemmStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const double *alpha, /* host or device pointer */ + const double *A, int lda, long long int strideA, /* purposely signed */ + const double *B, int ldb, long long int strideB, + const double *beta, /* host or device pointer */ + double *C, int ldc, long long int strideC, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const double *, const double *, int, long long, const double *, int, + long long, const double *, double *, int, long long, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemmStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, + ldb, strideB, beta, C, ldc, strideC, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasCgemmStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, long long int strideA, /* purposely signed */ + const cuComplex *B, int ldb, long long int strideB, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc, long long int strideC, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *, int, long long, const cuComplex *, + int, long long, const cuComplex *, cuComplex *, int, long long, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemmStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, + ldb, strideB, beta, C, ldc, strideC, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasCgemm3mStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, long long int strideA, /* purposely signed */ + const cuComplex *B, int ldb, long long int strideB, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc, long long int strideC, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *, int, long long, const cuComplex *, + int, long long, const cuComplex *, cuComplex *, int, long long, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm3mStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, + ldb, strideB, beta, C, ldc, strideC, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasZgemmStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, + long long int strideA, /* purposely signed */ + const cuDoubleComplex *B, int ldb, long long int strideB, + const cuDoubleComplex *beta, /* host or device poi */ + cuDoubleComplex *C, int ldc, long long int strideC, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, long long, + const cuDoubleComplex *, int, long long, const cuDoubleComplex *, + cuDoubleComplex *, int, long long, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemmStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, + ldb, strideB, beta, C, ldc, strideC, batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasSgeam( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *beta, /* host or device pointer */ + const float *B, int ldb, float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + const float *, const float *, int, const float *, const float *, int, + float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgeam"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDgeam( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *beta, /* host or device pointer */ + const double *B, int ldb, double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + const double *, const double *, int, const double *, const double *, int, + double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgeam"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCgeam( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, + const cuComplex *beta, /* host or device pointer */ + const cuComplex *B, int ldb, cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, + const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgeam"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZgeam( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *beta, /* host or device pointer */ + const cuDoubleComplex *B, int ldb, cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex *, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgeam"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, + ldc); +} + +cublasStatus_t CUBLASWINAPI cublasSgetrfBatched( + cublasHandle_t handle, int n, float *const A[], /*Device pointer*/ + int lda, int *P, /*Device Pointer*/ + int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, float *const[], int, int *, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgetrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasDgetrfBatched( + cublasHandle_t handle, int n, double *const A[], /*Device pointer*/ + int lda, int *P, /*Device Pointer*/ + int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, double *const[], int, int *, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgetrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasCgetrfBatched( + cublasHandle_t handle, int n, cuComplex *const A[], /*Device pointer*/ + int lda, int *P, /*Device Pointer*/ + int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuComplex *const[], int, int *, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgetrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasZgetrfBatched( + cublasHandle_t handle, int n, cuDoubleComplex *const A[], /*Device pointer*/ + int lda, int *P, /*Device Pointer*/ + int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuDoubleComplex *const[], int, int *, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgetrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasSgetriBatched( + cublasHandle_t handle, int n, const float *const A[], /*Device pointer*/ + int lda, const int *P, /*Device pointer*/ + float *const C[], /*Device pointer*/ + int ldc, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const float *const[], int, const int *, + float *const[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgetriBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, C, ldc, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasDgetriBatched( + cublasHandle_t handle, int n, const double *const A[], /*Device pointer*/ + int lda, const int *P, /*Device pointer*/ + double *const C[], /*Device pointer*/ + int ldc, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const double *const[], int, const int *, + double *const[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgetriBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, C, ldc, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasCgetriBatched( + cublasHandle_t handle, int n, const cuComplex *const A[], /*Device pointer*/ + int lda, const int *P, /*Device pointer*/ + cuComplex *const C[], /*Device pointer*/ + int ldc, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *const[], int, const int *, + cuComplex *const[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgetriBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, C, ldc, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasZgetriBatched(cublasHandle_t handle, int n, + const cuDoubleComplex *const A[], /*Device pointer*/ + int lda, const int *P, /*Device pointer*/ + cuDoubleComplex *const C[], /*Device pointer*/ + int ldc, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *const[], int, const int *, + cuDoubleComplex *const[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgetriBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, P, C, ldc, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasSgetrsBatched( + cublasHandle_t handle, cublasOperation_t trans, int n, int nrhs, + const float *const Aarray[], int lda, const int *devIpiv, + float *const Barray[], int ldb, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const float *const[], int, + const int *, float *const[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgetrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, + info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasDgetrsBatched( + cublasHandle_t handle, cublasOperation_t trans, int n, int nrhs, + const double *const Aarray[], int lda, const int *devIpiv, + double *const Barray[], int ldb, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const double *const[], int, + const int *, double *const[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgetrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, + info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasCgetrsBatched( + cublasHandle_t handle, cublasOperation_t trans, int n, int nrhs, + const cuComplex *const Aarray[], int lda, const int *devIpiv, + cuComplex *const Barray[], int ldb, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const cuComplex *const[], + int, const int *, cuComplex *const[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgetrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, + info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasZgetrsBatched( + cublasHandle_t handle, cublasOperation_t trans, int n, int nrhs, + const cuDoubleComplex *const Aarray[], int lda, const int *devIpiv, + cuDoubleComplex *const Barray[], int ldb, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, + const cuDoubleComplex *const[], int, const int *, + cuDoubleComplex *const[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgetrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, + info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasStrsmBatched( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const float *alpha, /*Host or Device Pointer*/ + const float *const A[], int lda, float *const B[], int ldb, + int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const float *, const float *const[], int, + float *const[], int, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrsmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasDtrsmBatched( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const double *alpha, /*Host or Device Pointer*/ + const double *const A[], int lda, double *const B[], int ldb, + int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const double *, const double *const[], int, + double *const[], int, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrsmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasCtrsmBatched( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuComplex *alpha, /*Host or Device Pointer*/ + const cuComplex *const A[], int lda, cuComplex *const B[], int ldb, + int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuComplex *, const cuComplex *const[], + int, cuComplex *const[], int, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrsmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasZtrsmBatched( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuDoubleComplex *alpha, /*Host or Device Pointer*/ + const cuDoubleComplex *const A[], int lda, cuDoubleComplex *const B[], + int ldb, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *const[], int, cuDoubleComplex *const[], int, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrsmBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + batchCount); +} + +cublasStatus_t CUBLASWINAPI cublasSmatinvBatched( + cublasHandle_t handle, int n, const float *const A[], /*Device pointer*/ + int lda, float *const Ainv[], /*Device pointer*/ + int lda_inv, int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const float *const[], + int, float *const[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSmatinvBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, Ainv, lda_inv, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasDmatinvBatched( + cublasHandle_t handle, int n, const double *const A[], /*Device pointer*/ + int lda, double *const Ainv[], /*Device pointer*/ + int lda_inv, int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const double *const[], + int, double *const[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDmatinvBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, Ainv, lda_inv, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasCmatinvBatched( + cublasHandle_t handle, int n, const cuComplex *const A[], /*Device pointer*/ + int lda, cuComplex *const Ainv[], /*Device pointer*/ + int lda_inv, int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *const[], int, cuComplex *const[], + int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCmatinvBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, Ainv, lda_inv, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasZmatinvBatched(cublasHandle_t handle, int n, + const cuDoubleComplex *const A[], /*Device pointer*/ + int lda, cuDoubleComplex *const Ainv[], /*Device pointer*/ + int lda_inv, int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *const[], int, + cuDoubleComplex *const[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZmatinvBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, Ainv, lda_inv, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasSgeqrfBatched(cublasHandle_t handle, int m, int n, + float *const Aarray[], /*Device pointer*/ + int lda, float *const TauArray[], /*Device pointer*/ + int *info, int batchSize) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, int, float *const[], + int, float *const[], int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgeqrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Aarray, lda, TauArray, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasDgeqrfBatched(cublasHandle_t handle, int m, int n, + double *const Aarray[], /*Device pointer*/ + int lda, double *const TauArray[], /*Device pointer*/ + int *info, int batchSize) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, int, double *const[], + int, double *const[], int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgeqrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Aarray, lda, TauArray, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasCgeqrfBatched(cublasHandle_t handle, int m, int n, + cuComplex *const Aarray[], /*Device pointer*/ + int lda, cuComplex *const TauArray[], /*Device pointer*/ + int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, cuComplex *const[], int, cuComplex *const[], + int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgeqrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Aarray, lda, TauArray, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasZgeqrfBatched( + cublasHandle_t handle, int m, int n, + cuDoubleComplex *const Aarray[], /*Device pointer*/ + int lda, cuDoubleComplex *const TauArray[], /*Device pointer*/ + int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, cuDoubleComplex *const[], int, + cuDoubleComplex *const[], int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgeqrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Aarray, lda, TauArray, info, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasSgelsBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int nrhs, float *const Aarray[], /*Device pointer*/ + int lda, float *const Carray[], /*Device pointer*/ + int ldc, int *info, int *devInfoArray, /*Device pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, float *const[], int, + float *const[], int, int *, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgelsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, + devInfoArray, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasDgelsBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int nrhs, double *const Aarray[], /*Device pointer*/ + int lda, double *const Carray[], /*Device pointer*/ + int ldc, int *info, int *devInfoArray, /*Device pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, double *const[], int, + double *const[], int, int *, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgelsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, + devInfoArray, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasCgelsBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int nrhs, cuComplex *const Aarray[], /*Device pointer*/ + int lda, cuComplex *const Carray[], /*Device pointer*/ + int ldc, int *info, int *devInfoArray, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, cuComplex *const[], int, + cuComplex *const[], int, int *, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgelsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, + devInfoArray, batchSize); +} + +cublasStatus_t CUBLASWINAPI +cublasZgelsBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int nrhs, cuDoubleComplex *const Aarray[], /*Device pointer*/ + int lda, cuDoubleComplex *const Carray[], /*Device pointer*/ + int ldc, int *info, int *devInfoArray, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, + cuDoubleComplex *const[], int, cuDoubleComplex *const[], int, int *, + int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgelsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, + devInfoArray, batchSize); +} + +cublasStatus_t CUBLASWINAPI cublasSdgmm(cublasHandle_t handle, + cublasSideMode_t mode, int m, int n, + const float *A, int lda, const float *x, + int incx, float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, int, int, const float *, int, + const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSdgmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasDdgmm(cublasHandle_t handle, + cublasSideMode_t mode, int m, int n, + const double *A, int lda, + const double *x, int incx, double *C, + int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, int, int, const double *, int, + const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDdgmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasCdgmm(cublasHandle_t handle, + cublasSideMode_t mode, int m, int n, + const cuComplex *A, int lda, + const cuComplex *x, int incx, + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, int, int, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCdgmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasZdgmm(cublasHandle_t handle, + cublasSideMode_t mode, int m, int n, + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *x, int incx, + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, int, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdgmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); +} + +cublasStatus_t CUBLASWINAPI cublasStpttr(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const float *AP, float *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStpttr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, AP, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasDtpttr(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const double *AP, double *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtpttr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, AP, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasCtpttr(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const cuComplex *AP, cuComplex *A, + int lda) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, + const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtpttr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, AP, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasZtpttr(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const cuDoubleComplex *AP, + cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtpttr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, AP, A, lda); +} + +cublasStatus_t CUBLASWINAPI cublasStrttp(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const float *A, int lda, float *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, int, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrttp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, AP); +} + +cublasStatus_t CUBLASWINAPI cublasDtrttp(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const double *A, int lda, double *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, int, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrttp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, AP); +} + +cublasStatus_t CUBLASWINAPI cublasCtrttp(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const cuComplex *A, int lda, + cuComplex *AP) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, + const cuComplex *, int, cuComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrttp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, AP); +} + +cublasStatus_t CUBLASWINAPI cublasZtrttp(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, int, + cuDoubleComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrttp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, AP); +} + +cublasStatus CUBLASWINAPI cublasInit(void) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasInit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +cublasStatus CUBLASWINAPI cublasShutdown(void) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasShutdown"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +cublasStatus CUBLASWINAPI cublasGetError(void) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetError"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +cublasStatus CUBLASWINAPI cublasGetVersion(int *version) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(version); +} + +cublasStatus CUBLASWINAPI cublasAlloc(int n, int elemSize, void **devicePtr) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, void **); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasAlloc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(n, elemSize, devicePtr); +} + +cublasStatus CUBLASWINAPI cublasFree(void *devicePtr) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasFree"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devicePtr); +} + +cublasStatus CUBLASWINAPI cublasSetKernelStream(cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetKernelStream"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream); +} + +float CUBLASWINAPI cublasSnrm2(int n, const float *x, int incx) { + using FuncPtr = float(CUBLASWINAPI *)(int, const float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSnrm2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSnrm2"); + return func_ptr(n, x, incx); +} + +double CUBLASWINAPI cublasDnrm2(int n, const double *x, int incx) { + using FuncPtr = double(CUBLASWINAPI *)(int, const double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDnrm2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDnrm2"); + return func_ptr(n, x, incx); +} + +float CUBLASWINAPI cublasScnrm2(int n, const cuComplex *x, int incx) { + using FuncPtr = float(CUBLASWINAPI *)(int, const cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasScnrm2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasScnrm2"); + return func_ptr(n, x, incx); +} + +double CUBLASWINAPI cublasDznrm2(int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = double(CUBLASWINAPI *)(int, const cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDznrm2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDznrm2"); + return func_ptr(n, x, incx); +} + +float CUBLASWINAPI cublasSdot(int n, const float *x, int incx, const float *y, + int incy) { + using FuncPtr = + float(CUBLASWINAPI *)(int, const float *, int, const float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSdot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSdot"); + return func_ptr(n, x, incx, y, incy); +} + +double CUBLASWINAPI cublasDdot(int n, const double *x, int incx, + const double *y, int incy) { + using FuncPtr = + double(CUBLASWINAPI *)(int, const double *, int, const double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDdot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDdot"); + return func_ptr(n, x, incx, y, incy); +} + +cuComplex CUBLASWINAPI cublasCdotu(int n, const cuComplex *x, int incx, + const cuComplex *y, int incy) { + using FuncPtr = cuComplex(CUBLASWINAPI *)(int, const cuComplex *, int, + const cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCdotu"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCdotu"); + return func_ptr(n, x, incx, y, incy); +} + +cuComplex CUBLASWINAPI cublasCdotc(int n, const cuComplex *x, int incx, + const cuComplex *y, int incy) { + using FuncPtr = cuComplex(CUBLASWINAPI *)(int, const cuComplex *, int, + const cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCdotc"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCdotc"); + return func_ptr(n, x, incx, y, incy); +} + +cuDoubleComplex CUBLASWINAPI cublasZdotu(int n, const cuDoubleComplex *x, + int incx, const cuDoubleComplex *y, + int incy) { + using FuncPtr = cuDoubleComplex(CUBLASWINAPI *)( + int, const cuDoubleComplex *, int, const cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdotu"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZdotu"); + return func_ptr(n, x, incx, y, incy); +} + +cuDoubleComplex CUBLASWINAPI cublasZdotc(int n, const cuDoubleComplex *x, + int incx, const cuDoubleComplex *y, + int incy) { + using FuncPtr = cuDoubleComplex(CUBLASWINAPI *)( + int, const cuDoubleComplex *, int, const cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdotc"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZdotc"); + return func_ptr(n, x, incx, y, incy); +} + +void CUBLASWINAPI cublasSscal(int n, float alpha, float *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(int, float, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSscal"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSscal"); + return func_ptr(n, alpha, x, incx); +} + +void CUBLASWINAPI cublasDscal(int n, double alpha, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(int, double, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDscal"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDscal"); + return func_ptr(n, alpha, x, incx); +} + +void CUBLASWINAPI cublasCscal(int n, cuComplex alpha, cuComplex *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCscal"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCscal"); + return func_ptr(n, alpha, x, incx); +} + +void CUBLASWINAPI cublasZscal(int n, cuDoubleComplex alpha, cuDoubleComplex *x, + int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZscal"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZscal"); + return func_ptr(n, alpha, x, incx); +} + +void CUBLASWINAPI cublasCsscal(int n, float alpha, cuComplex *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(int, float, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsscal"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCsscal"); + return func_ptr(n, alpha, x, incx); +} + +void CUBLASWINAPI cublasZdscal(int n, double alpha, cuDoubleComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)(int, double, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdscal"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZdscal"); + return func_ptr(n, alpha, x, incx); +} + +void CUBLASWINAPI cublasSaxpy(int n, float alpha, const float *x, int incx, + float *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(int, float, const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSaxpy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSaxpy"); + return func_ptr(n, alpha, x, incx, y, incy); +} + +void CUBLASWINAPI cublasDaxpy(int n, double alpha, const double *x, int incx, + double *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(int, double, const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDaxpy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDaxpy"); + return func_ptr(n, alpha, x, incx, y, incy); +} + +void CUBLASWINAPI cublasCaxpy(int n, cuComplex alpha, const cuComplex *x, + int incx, cuComplex *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuComplex, const cuComplex *, int, + cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCaxpy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCaxpy"); + return func_ptr(n, alpha, x, incx, y, incy); +} + +void CUBLASWINAPI cublasZaxpy(int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(int, cuDoubleComplex, const cuDoubleComplex *, int, + cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZaxpy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZaxpy"); + return func_ptr(n, alpha, x, incx, y, incy); +} + +void CUBLASWINAPI cublasScopy(int n, const float *x, int incx, float *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasScopy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasScopy"); + return func_ptr(n, x, incx, y, incy); +} + +void CUBLASWINAPI cublasDcopy(int n, const double *x, int incx, double *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDcopy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDcopy"); + return func_ptr(n, x, incx, y, incy); +} + +void CUBLASWINAPI cublasCcopy(int n, const cuComplex *x, int incx, cuComplex *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCcopy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCcopy"); + return func_ptr(n, x, incx, y, incy); +} + +void CUBLASWINAPI cublasZcopy(int n, const cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZcopy"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZcopy"); + return func_ptr(n, x, incx, y, incy); +} + +void CUBLASWINAPI cublasSswap(int n, float *x, int incx, float *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSswap"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSswap"); + return func_ptr(n, x, incx, y, incy); +} + +void CUBLASWINAPI cublasDswap(int n, double *x, int incx, double *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDswap"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDswap"); + return func_ptr(n, x, incx, y, incy); +} + +void CUBLASWINAPI cublasCswap(int n, cuComplex *x, int incx, cuComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCswap"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCswap"); + return func_ptr(n, x, incx, y, incy); +} + +void CUBLASWINAPI cublasZswap(int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(int, cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZswap"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZswap"); + return func_ptr(n, x, incx, y, incy); +} + +int CUBLASWINAPI cublasIsamax(int n, const float *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIsamax"); + if (!func_ptr) LogFatalSymbolNotFound("cublasIsamax"); + return func_ptr(n, x, incx); +} + +int CUBLASWINAPI cublasIdamax(int n, const double *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIdamax"); + if (!func_ptr) LogFatalSymbolNotFound("cublasIdamax"); + return func_ptr(n, x, incx); +} + +int CUBLASWINAPI cublasIcamax(int n, const cuComplex *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIcamax"); + if (!func_ptr) LogFatalSymbolNotFound("cublasIcamax"); + return func_ptr(n, x, incx); +} + +int CUBLASWINAPI cublasIzamax(int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIzamax"); + if (!func_ptr) LogFatalSymbolNotFound("cublasIzamax"); + return func_ptr(n, x, incx); +} + +int CUBLASWINAPI cublasIsamin(int n, const float *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIsamin"); + if (!func_ptr) LogFatalSymbolNotFound("cublasIsamin"); + return func_ptr(n, x, incx); +} + +int CUBLASWINAPI cublasIdamin(int n, const double *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIdamin"); + if (!func_ptr) LogFatalSymbolNotFound("cublasIdamin"); + return func_ptr(n, x, incx); +} + +int CUBLASWINAPI cublasIcamin(int n, const cuComplex *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIcamin"); + if (!func_ptr) LogFatalSymbolNotFound("cublasIcamin"); + return func_ptr(n, x, incx); +} + +int CUBLASWINAPI cublasIzamin(int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasIzamin"); + if (!func_ptr) LogFatalSymbolNotFound("cublasIzamin"); + return func_ptr(n, x, incx); +} + +float CUBLASWINAPI cublasSasum(int n, const float *x, int incx) { + using FuncPtr = float(CUBLASWINAPI *)(int, const float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSasum"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSasum"); + return func_ptr(n, x, incx); +} + +double CUBLASWINAPI cublasDasum(int n, const double *x, int incx) { + using FuncPtr = double(CUBLASWINAPI *)(int, const double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDasum"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDasum"); + return func_ptr(n, x, incx); +} + +float CUBLASWINAPI cublasScasum(int n, const cuComplex *x, int incx) { + using FuncPtr = float(CUBLASWINAPI *)(int, const cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasScasum"); + if (!func_ptr) LogFatalSymbolNotFound("cublasScasum"); + return func_ptr(n, x, incx); +} + +double CUBLASWINAPI cublasDzasum(int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = double(CUBLASWINAPI *)(int, const cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDzasum"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDzasum"); + return func_ptr(n, x, incx); +} + +void CUBLASWINAPI cublasSrot(int n, float *x, int incx, float *y, int incy, + float sc, float ss) { + using FuncPtr = + void(CUBLASWINAPI *)(int, float *, int, float *, int, float, float); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSrot"); + return func_ptr(n, x, incx, y, incy, sc, ss); +} + +void CUBLASWINAPI cublasDrot(int n, double *x, int incx, double *y, int incy, + double sc, double ss) { + using FuncPtr = + void(CUBLASWINAPI *)(int, double *, int, double *, int, double, double); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDrot"); + return func_ptr(n, x, incx, y, incy, sc, ss); +} + +void CUBLASWINAPI cublasCrot(int n, cuComplex *x, int incx, cuComplex *y, + int incy, float c, cuComplex s) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int, + float, cuComplex); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCrot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCrot"); + return func_ptr(n, x, incx, y, incy, c, s); +} + +void CUBLASWINAPI cublasZrot(int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy, double sc, + cuDoubleComplex cs) { + using FuncPtr = + void(CUBLASWINAPI *)(int, cuDoubleComplex *, int, cuDoubleComplex *, int, + double, cuDoubleComplex); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZrot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZrot"); + return func_ptr(n, x, incx, y, incy, sc, cs); +} + +void CUBLASWINAPI cublasCsrot(int n, cuComplex *x, int incx, cuComplex *y, + int incy, float c, float s) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int, + float, float); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsrot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCsrot"); + return func_ptr(n, x, incx, y, incy, c, s); +} + +void CUBLASWINAPI cublasZdrot(int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy, double c, + double s) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuDoubleComplex *, int, + cuDoubleComplex *, int, double, double); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdrot"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZdrot"); + return func_ptr(n, x, incx, y, incy, c, s); +} + +void CUBLASWINAPI cublasSrotg(float *sa, float *sb, float *sc, float *ss) { + using FuncPtr = void(CUBLASWINAPI *)(float *, float *, float *, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotg"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSrotg"); + return func_ptr(sa, sb, sc, ss); +} + +void CUBLASWINAPI cublasDrotg(double *sa, double *sb, double *sc, double *ss) { + using FuncPtr = void(CUBLASWINAPI *)(double *, double *, double *, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotg"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDrotg"); + return func_ptr(sa, sb, sc, ss); +} + +void CUBLASWINAPI cublasCrotg(cuComplex *ca, cuComplex cb, float *sc, + cuComplex *cs) { + using FuncPtr = + void(CUBLASWINAPI *)(cuComplex *, cuComplex, float *, cuComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCrotg"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCrotg"); + return func_ptr(ca, cb, sc, cs); +} + +void CUBLASWINAPI cublasZrotg(cuDoubleComplex *ca, cuDoubleComplex cb, + double *sc, cuDoubleComplex *cs) { + using FuncPtr = void(CUBLASWINAPI *)(cuDoubleComplex *, cuDoubleComplex, + double *, cuDoubleComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZrotg"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZrotg"); + return func_ptr(ca, cb, sc, cs); +} + +void CUBLASWINAPI cublasSrotm(int n, float *x, int incx, float *y, int incy, + const float *sparam) { + using FuncPtr = + void(CUBLASWINAPI *)(int, float *, int, float *, int, const float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSrotm"); + return func_ptr(n, x, incx, y, incy, sparam); +} + +void CUBLASWINAPI cublasDrotm(int n, double *x, int incx, double *y, int incy, + const double *sparam) { + using FuncPtr = + void(CUBLASWINAPI *)(int, double *, int, double *, int, const double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDrotm"); + return func_ptr(n, x, incx, y, incy, sparam); +} + +void CUBLASWINAPI cublasSrotmg(float *sd1, float *sd2, float *sx1, + const float *sy1, float *sparam) { + using FuncPtr = + void(CUBLASWINAPI *)(float *, float *, float *, const float *, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotmg"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSrotmg"); + return func_ptr(sd1, sd2, sx1, sy1, sparam); +} + +void CUBLASWINAPI cublasDrotmg(double *sd1, double *sd2, double *sx1, + const double *sy1, double *sparam) { + using FuncPtr = void(CUBLASWINAPI *)(double *, double *, double *, + const double *, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotmg"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDrotmg"); + return func_ptr(sd1, sd2, sx1, sy1, sparam); +} + +void CUBLASWINAPI cublasSgemv(char trans, int m, int n, float alpha, + const float *A, int lda, const float *x, int incx, + float beta, float *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, float, const float *, int, + const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSgemv"); + return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasDgemv(char trans, int m, int n, double alpha, + const double *A, int lda, const double *x, + int incx, double beta, double *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, double, const double *, int, + const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDgemv"); + return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasCgemv(char trans, int m, int n, cuComplex alpha, + const cuComplex *A, int lda, const cuComplex *x, + int incx, cuComplex beta, cuComplex *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCgemv"); + return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasZgemv(char trans, int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZgemv"); + return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasSgbmv(char trans, int m, int n, int kl, int ku, + float alpha, const float *A, int lda, + const float *x, int incx, float beta, float *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, int, int, float, const float *, int, + const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSgbmv"); + return func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasDgbmv(char trans, int m, int n, int kl, int ku, + double alpha, const double *A, int lda, + const double *x, int incx, double beta, double *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, int, int, double, const double *, + int, const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDgbmv"); + return func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasCgbmv(char trans, int m, int n, int kl, int ku, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *x, int incx, cuComplex beta, + cuComplex *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, int, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCgbmv"); + return func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasZgbmv(char trans, int m, int n, int kl, int ku, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, int, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZgbmv"); + return func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasStrmv(char uplo, char trans, char diag, int n, + const float *A, int lda, float *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const float *, + int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStrmv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasDtrmv(char uplo, char trans, char diag, int n, + const double *A, int lda, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const double *, + int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDtrmv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasCtrmv(char uplo, char trans, char diag, int n, + const cuComplex *A, int lda, cuComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const cuComplex *, + int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCtrmv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasZtrmv(char uplo, char trans, char diag, int n, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZtrmv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasStbmv(char uplo, char trans, char diag, int n, int k, + const float *A, int lda, float *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, int, + const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStbmv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasDtbmv(char uplo, char trans, char diag, int n, int k, + const double *A, int lda, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, int, + const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDtbmv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasCtbmv(char uplo, char trans, char diag, int n, int k, + const cuComplex *A, int lda, cuComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, char, int, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCtbmv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasZtbmv(char uplo, char trans, char diag, int n, int k, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, int, const cuDoubleComplex *, + int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZtbmv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasStpmv(char uplo, char trans, char diag, int n, + const float *AP, float *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, const float *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStpmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStpmv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasDtpmv(char uplo, char trans, char diag, int n, + const double *AP, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const double *, + double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtpmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDtpmv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasCtpmv(char uplo, char trans, char diag, int n, + const cuComplex *AP, cuComplex *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const cuComplex *, + cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtpmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCtpmv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasZtpmv(char uplo, char trans, char diag, int n, + const cuDoubleComplex *AP, cuDoubleComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, char, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtpmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZtpmv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasStrsv(char uplo, char trans, char diag, int n, + const float *A, int lda, float *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const float *, + int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStrsv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasDtrsv(char uplo, char trans, char diag, int n, + const double *A, int lda, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const double *, + int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDtrsv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasCtrsv(char uplo, char trans, char diag, int n, + const cuComplex *A, int lda, cuComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const cuComplex *, + int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCtrsv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasZtrsv(char uplo, char trans, char diag, int n, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZtrsv"); + return func_ptr(uplo, trans, diag, n, A, lda, x, incx); +} + +void CUBLASWINAPI cublasStpsv(char uplo, char trans, char diag, int n, + const float *AP, float *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, const float *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStpsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStpsv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasDtpsv(char uplo, char trans, char diag, int n, + const double *AP, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const double *, + double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtpsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDtpsv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasCtpsv(char uplo, char trans, char diag, int n, + const cuComplex *AP, cuComplex *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const cuComplex *, + cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtpsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCtpsv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasZtpsv(char uplo, char trans, char diag, int n, + const cuDoubleComplex *AP, cuDoubleComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, char, int, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtpsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZtpsv"); + return func_ptr(uplo, trans, diag, n, AP, x, incx); +} + +void CUBLASWINAPI cublasStbsv(char uplo, char trans, char diag, int n, int k, + const float *A, int lda, float *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, int, + const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStbsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStbsv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasDtbsv(char uplo, char trans, char diag, int n, int k, + const double *A, int lda, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, int, + const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtbsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDtbsv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasCtbsv(char uplo, char trans, char diag, int n, int k, + const cuComplex *A, int lda, cuComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, char, int, int, const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtbsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCtbsv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasZtbsv(char uplo, char trans, char diag, int n, int k, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, int, const cuDoubleComplex *, + int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtbsv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZtbsv"); + return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); +} + +void CUBLASWINAPI cublasSsymv(char uplo, int n, float alpha, const float *A, + int lda, const float *x, int incx, float beta, + float *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const float *, int, + const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsymv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSsymv"); + return func_ptr(uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasDsymv(char uplo, int n, double alpha, const double *A, + int lda, const double *x, int incx, double beta, + double *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, double, const double *, int, + const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsymv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDsymv"); + return func_ptr(uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasChemv(char uplo, int n, cuComplex alpha, + const cuComplex *A, int lda, const cuComplex *x, + int incx, cuComplex beta, cuComplex *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasChemv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasChemv"); + return func_ptr(uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasZhemv(char uplo, int n, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhemv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZhemv"); + return func_ptr(uplo, n, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasSsbmv(char uplo, int n, int k, float alpha, + const float *A, int lda, const float *x, int incx, + float beta, float *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, float, const float *, int, + const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSsbmv"); + return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasDsbmv(char uplo, int n, int k, double alpha, + const double *A, int lda, const double *x, + int incx, double beta, double *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, double, const double *, int, + const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDsbmv"); + return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasChbmv(char uplo, int n, int k, cuComplex alpha, + const cuComplex *A, int lda, const cuComplex *x, + int incx, cuComplex beta, cuComplex *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasChbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasChbmv"); + return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasZhbmv(char uplo, int n, int k, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhbmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZhbmv"); + return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasSspmv(char uplo, int n, float alpha, const float *AP, + const float *x, int incx, float beta, float *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const float *, + const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSspmv"); + return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasDspmv(char uplo, int n, double alpha, const double *AP, + const double *x, int incx, double beta, double *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, double, const double *, const double *, + int, double, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDspmv"); + return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasChpmv(char uplo, int n, cuComplex alpha, + const cuComplex *AP, const cuComplex *x, int incx, + cuComplex beta, cuComplex *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, + const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasChpmv"); + return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasZhpmv(char uplo, int n, cuDoubleComplex alpha, + const cuDoubleComplex *AP, + const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, cuDoubleComplex, const cuDoubleComplex *, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpmv"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZhpmv"); + return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); +} + +void CUBLASWINAPI cublasSger(int m, int n, float alpha, const float *x, + int incx, const float *y, int incy, float *A, + int lda) { + using FuncPtr = void(CUBLASWINAPI *)(int, int, float, const float *, int, + const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSger"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSger"); + return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasDger(int m, int n, double alpha, const double *x, + int incx, const double *y, int incy, double *A, + int lda) { + using FuncPtr = void(CUBLASWINAPI *)(int, int, double, const double *, int, + const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDger"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDger"); + return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasCgeru(int m, int n, cuComplex alpha, const cuComplex *x, + int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = + void(CUBLASWINAPI *)(int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgeru"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCgeru"); + return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasCgerc(int m, int n, cuComplex alpha, const cuComplex *x, + int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = + void(CUBLASWINAPI *)(int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgerc"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCgerc"); + return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasZgeru(int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)( + int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgeru"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZgeru"); + return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasZgerc(int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)( + int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgerc"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZgerc"); + return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasSsyr(char uplo, int n, float alpha, const float *x, + int incx, float *A, int lda) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, float, const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSsyr"); + return func_ptr(uplo, n, alpha, x, incx, A, lda); +} + +void CUBLASWINAPI cublasDsyr(char uplo, int n, double alpha, const double *x, + int incx, double *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, double, const double *, int, + double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDsyr"); + return func_ptr(uplo, n, alpha, x, incx, A, lda); +} + +void CUBLASWINAPI cublasCher(char uplo, int n, float alpha, const cuComplex *x, + int incx, cuComplex *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const cuComplex *, int, + cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCher"); + return func_ptr(uplo, n, alpha, x, incx, A, lda); +} + +void CUBLASWINAPI cublasZher(char uplo, int n, double alpha, + const cuDoubleComplex *x, int incx, + cuDoubleComplex *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, double, const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZher"); + return func_ptr(uplo, n, alpha, x, incx, A, lda); +} + +void CUBLASWINAPI cublasSspr(char uplo, int n, float alpha, const float *x, + int incx, float *AP) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, float, const float *, int, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspr"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSspr"); + return func_ptr(uplo, n, alpha, x, incx, AP); +} + +void CUBLASWINAPI cublasDspr(char uplo, int n, double alpha, const double *x, + int incx, double *AP) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, double, const double *, int, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspr"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDspr"); + return func_ptr(uplo, n, alpha, x, incx, AP); +} + +void CUBLASWINAPI cublasChpr(char uplo, int n, float alpha, const cuComplex *x, + int incx, cuComplex *AP) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const cuComplex *, int, + cuComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpr"); + if (!func_ptr) LogFatalSymbolNotFound("cublasChpr"); + return func_ptr(uplo, n, alpha, x, incx, AP); +} + +void CUBLASWINAPI cublasZhpr(char uplo, int n, double alpha, + const cuDoubleComplex *x, int incx, + cuDoubleComplex *AP) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, double, const cuDoubleComplex *, int, cuDoubleComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpr"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZhpr"); + return func_ptr(uplo, n, alpha, x, incx, AP); +} + +void CUBLASWINAPI cublasSsyr2(char uplo, int n, float alpha, const float *x, + int incx, const float *y, int incy, float *A, + int lda) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const float *, int, + const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSsyr2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasDsyr2(char uplo, int n, double alpha, const double *x, + int incx, const double *y, int incy, double *A, + int lda) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, double, const double *, int, + const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDsyr2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasCher2(char uplo, int n, cuComplex alpha, + const cuComplex *x, int incx, const cuComplex *y, + int incy, cuComplex *A, int lda) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCher2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasZher2(char uplo, int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZher2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); +} + +void CUBLASWINAPI cublasSspr2(char uplo, int n, float alpha, const float *x, + int incx, const float *y, int incy, float *AP) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const float *, int, + const float *, int, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspr2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSspr2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); +} + +void CUBLASWINAPI cublasDspr2(char uplo, int n, double alpha, const double *x, + int incx, const double *y, int incy, double *AP) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, double, const double *, int, + const double *, int, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspr2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDspr2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); +} + +void CUBLASWINAPI cublasChpr2(char uplo, int n, cuComplex alpha, + const cuComplex *x, int incx, const cuComplex *y, + int incy, cuComplex *AP) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpr2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasChpr2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); +} + +void CUBLASWINAPI cublasZhpr2(char uplo, int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *AP) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpr2"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZhpr2"); + return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); +} + +void CUBLASWINAPI cublasSgemm(char transa, char transb, int m, int n, int k, + float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, + int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, int, float, const float *, int, + const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSgemm"); + return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasDgemm(char transa, char transb, int m, int n, int k, + double alpha, const double *A, int lda, + const double *B, int ldb, double beta, double *C, + int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, int, double, const double *, + int, const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDgemm"); + return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasCgemm(char transa, char transb, int m, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, cuComplex beta, + cuComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCgemm"); + return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasZgemm(char transa, char transb, int m, int n, int k, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZgemm"); + return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasSsyrk(char uplo, char trans, int n, int k, float alpha, + const float *A, int lda, float beta, float *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, int, int, float, + const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyrk"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSsyrk"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +void CUBLASWINAPI cublasDsyrk(char uplo, char trans, int n, int k, double alpha, + const double *A, int lda, double beta, double *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, double, const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyrk"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDsyrk"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +void CUBLASWINAPI cublasCsyrk(char uplo, char trans, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + cuComplex beta, cuComplex *C, int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, + int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyrk"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCsyrk"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +void CUBLASWINAPI cublasZsyrk(char uplo, char trans, int n, int k, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, + const cuDoubleComplex *, int, + cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyrk"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZsyrk"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +void CUBLASWINAPI cublasCherk(char uplo, char trans, int n, int k, float alpha, + const cuComplex *A, int lda, float beta, + cuComplex *C, int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, float, const cuComplex *, int, + float, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCherk"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCherk"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +void CUBLASWINAPI cublasZherk(char uplo, char trans, int n, int k, double alpha, + const cuDoubleComplex *A, int lda, double beta, + cuDoubleComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, int, int, double, + const cuDoubleComplex *, int, double, + cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZherk"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZherk"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); +} + +void CUBLASWINAPI cublasSsyr2k(char uplo, char trans, int n, int k, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, float, const float *, int, + const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr2k"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSsyr2k"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasDsyr2k(char uplo, char trans, int n, int k, + double alpha, const double *A, int lda, + const double *B, int ldb, double beta, double *C, + int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, double, const double *, int, + const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr2k"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDsyr2k"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasCsyr2k(char uplo, char trans, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, cuComplex beta, + cuComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyr2k"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCsyr2k"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasZsyr2k(char uplo, char trans, int n, int k, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyr2k"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZsyr2k"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasCher2k(char uplo, char trans, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, float beta, + cuComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, float, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher2k"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCher2k"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasZher2k(char uplo, char trans, int n, int k, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *B, int ldb, + double beta, cuDoubleComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, double, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher2k"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZher2k"); + return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasSsymm(char side, char uplo, int m, int n, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, float, const float *, int, + const float *, int, float, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsymm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasSsymm"); + return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasDsymm(char side, char uplo, int m, int n, double alpha, + const double *A, int lda, const double *B, + int ldb, double beta, double *C, int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, double, const double *, int, + const double *, int, double, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsymm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDsymm"); + return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasCsymm(char side, char uplo, int m, int n, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, cuComplex beta, + cuComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsymm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCsymm"); + return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasZsymm(char side, char uplo, int m, int n, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsymm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZsymm"); + return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasChemm(char side, char uplo, int m, int n, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, cuComplex beta, + cuComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasChemm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasChemm"); + return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasZhemm(char side, char uplo, int m, int n, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhemm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZhemm"); + return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); +} + +void CUBLASWINAPI cublasStrsm(char side, char uplo, char transa, char diag, + int m, int n, float alpha, const float *A, + int lda, float *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, float, + const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrsm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStrsm"); + return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); +} + +void CUBLASWINAPI cublasDtrsm(char side, char uplo, char transa, char diag, + int m, int n, double alpha, const double *A, + int lda, double *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, double, + const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrsm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDtrsm"); + return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); +} + +void CUBLASWINAPI cublasCtrsm(char side, char uplo, char transa, char diag, + int m, int n, cuComplex alpha, const cuComplex *A, + int lda, cuComplex *B, int ldb) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, char, int, int, cuComplex, + const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrsm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCtrsm"); + return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); +} + +void CUBLASWINAPI cublasZtrsm(char side, char uplo, char transa, char diag, + int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, + cuDoubleComplex, const cuDoubleComplex *, + int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrsm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZtrsm"); + return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); +} + +void CUBLASWINAPI cublasStrmm(char side, char uplo, char transa, char diag, + int m, int n, float alpha, const float *A, + int lda, float *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, float, + const float *, int, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrmm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasStrmm"); + return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); +} + +void CUBLASWINAPI cublasDtrmm(char side, char uplo, char transa, char diag, + int m, int n, double alpha, const double *A, + int lda, double *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, double, + const double *, int, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrmm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasDtrmm"); + return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); +} + +void CUBLASWINAPI cublasCtrmm(char side, char uplo, char transa, char diag, + int m, int n, cuComplex alpha, const cuComplex *A, + int lda, cuComplex *B, int ldb) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, char, int, int, cuComplex, + const cuComplex *, int, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrmm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasCtrmm"); + return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); +} + +void CUBLASWINAPI cublasZtrmm(char side, char uplo, char transa, char diag, + int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, + cuDoubleComplex, const cuDoubleComplex *, + int, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrmm"); + if (!func_ptr) LogFatalSymbolNotFound("cublasZtrmm"); + return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); +} + +} // extern "C" diff --git a/tensorflow/stream_executor/cuda/cublas_9_0.inc b/tensorflow/stream_executor/cuda/cublas_9_0.inc index ba46426878f..5e716114b23 100644 --- a/tensorflow/stream_executor/cuda/cublas_9_0.inc +++ b/tensorflow/stream_executor/cuda/cublas_9_0.inc @@ -2,5120 +2,4814 @@ extern "C" { -cublasStatus_t CUBLASWINAPI cublasCreate_v2 (cublasHandle_t *handle) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t *); +cublasStatus_t CUBLASWINAPI cublasCreate_v2(cublasHandle_t *handle) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCreate_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cublasStatus_t CUBLASWINAPI cublasDestroy_v2 (cublasHandle_t handle) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t); +cublasStatus_t CUBLASWINAPI cublasDestroy_v2(cublasHandle_t handle) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDestroy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cublasStatus_t CUBLASWINAPI cublasGetVersion_v2(cublasHandle_t handle, int *version) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int *); +cublasStatus_t CUBLASWINAPI cublasGetVersion_v2(cublasHandle_t handle, + int *version) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetVersion_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, version); } -cublasStatus_t CUBLASWINAPI cublasGetProperty(libraryPropertyType type, int *value) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(libraryPropertyType, int *); +cublasStatus_t CUBLASWINAPI cublasGetProperty(libraryPropertyType type, + int *value) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(libraryPropertyType, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetProperty"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(type, value); } -cublasStatus_t CUBLASWINAPI cublasSetStream_v2 (cublasHandle_t handle, cudaStream_t streamId) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cudaStream_t); +cublasStatus_t CUBLASWINAPI cublasSetStream_v2(cublasHandle_t handle, + cudaStream_t streamId) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetStream_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); } -cublasStatus_t CUBLASWINAPI cublasGetStream_v2 (cublasHandle_t handle, cudaStream_t *streamId) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cudaStream_t *); +cublasStatus_t CUBLASWINAPI cublasGetStream_v2(cublasHandle_t handle, + cudaStream_t *streamId) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cudaStream_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetStream_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); } -cublasStatus_t CUBLASWINAPI cublasGetPointerMode_v2 (cublasHandle_t handle, cublasPointerMode_t *mode) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasPointerMode_t *); +cublasStatus_t CUBLASWINAPI cublasGetPointerMode_v2(cublasHandle_t handle, + cublasPointerMode_t *mode) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasPointerMode_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetPointerMode_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode); } -cublasStatus_t CUBLASWINAPI cublasSetPointerMode_v2 (cublasHandle_t handle, cublasPointerMode_t mode) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasPointerMode_t); +cublasStatus_t CUBLASWINAPI cublasSetPointerMode_v2(cublasHandle_t handle, + cublasPointerMode_t mode) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasPointerMode_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetPointerMode_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode); } -cublasStatus_t CUBLASWINAPI cublasGetAtomicsMode(cublasHandle_t handle, cublasAtomicsMode_t *mode) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasAtomicsMode_t *); +cublasStatus_t CUBLASWINAPI cublasGetAtomicsMode(cublasHandle_t handle, + cublasAtomicsMode_t *mode) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasAtomicsMode_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetAtomicsMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode); } -cublasStatus_t CUBLASWINAPI cublasSetAtomicsMode(cublasHandle_t handle, cublasAtomicsMode_t mode) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasAtomicsMode_t); +cublasStatus_t CUBLASWINAPI cublasSetAtomicsMode(cublasHandle_t handle, + cublasAtomicsMode_t mode) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasAtomicsMode_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetAtomicsMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode); } -cublasStatus_t CUBLASWINAPI cublasGetMathMode(cublasHandle_t handle, cublasMath_t *mode) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasMath_t *); +cublasStatus_t CUBLASWINAPI cublasGetMathMode(cublasHandle_t handle, + cublasMath_t *mode) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasMath_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetMathMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode); } -cublasStatus_t CUBLASWINAPI cublasSetMathMode(cublasHandle_t handle, cublasMath_t mode) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasMath_t); +cublasStatus_t CUBLASWINAPI cublasSetMathMode(cublasHandle_t handle, + cublasMath_t mode) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasMath_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetMathMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode); } -cublasStatus_t CUBLASWINAPI cublasSetVector (int n, int elemSize, const void *x, - int incx, void *devicePtr, int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, const void *, int, void *, int); +cublasStatus_t CUBLASWINAPI cublasSetVector(int n, int elemSize, const void *x, + int incx, void *devicePtr, + int incy) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(int, int, const void *, int, void *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetVector"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(n, elemSize, x, incx, devicePtr, incy); } -cublasStatus_t CUBLASWINAPI cublasGetVector (int n, int elemSize, const void *x, - int incx, void *y, int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, const void *, int, void *, int); +cublasStatus_t CUBLASWINAPI cublasGetVector(int n, int elemSize, const void *x, + int incx, void *y, int incy) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(int, int, const void *, int, void *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetVector"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(n, elemSize, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasSetMatrix (int rows, int cols, int elemSize, - const void *A, int lda, void *B, - int ldb) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, int, const void *, int, void *, int); +cublasStatus_t CUBLASWINAPI cublasSetMatrix(int rows, int cols, int elemSize, + const void *A, int lda, void *B, + int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, int, const void *, + int, void *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetMatrix"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rows, cols, elemSize, A, lda, B, ldb); } -cublasStatus_t CUBLASWINAPI cublasGetMatrix (int rows, int cols, int elemSize, - const void *A, int lda, void *B, - int ldb) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, int, const void *, int, void *, int); +cublasStatus_t CUBLASWINAPI cublasGetMatrix(int rows, int cols, int elemSize, + const void *A, int lda, void *B, + int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, int, const void *, + int, void *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetMatrix"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rows, cols, elemSize, A, lda, B, ldb); } -cublasStatus_t CUBLASWINAPI cublasSetVectorAsync (int n, int elemSize, - const void *hostPtr, int incx, - void *devicePtr, int incy, - cudaStream_t stream) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, const void *, int, void *, int, cudaStream_t); +cublasStatus_t CUBLASWINAPI cublasSetVectorAsync(int n, int elemSize, + const void *hostPtr, int incx, + void *devicePtr, int incy, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, const void *, int, + void *, int, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetVectorAsync"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(n, elemSize, hostPtr, incx, devicePtr, incy, stream); } -cublasStatus_t CUBLASWINAPI cublasGetVectorAsync (int n, int elemSize, - const void *devicePtr, int incx, - void *hostPtr, int incy, - cudaStream_t stream) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, const void *, int, void *, int, cudaStream_t); +cublasStatus_t CUBLASWINAPI cublasGetVectorAsync(int n, int elemSize, + const void *devicePtr, + int incx, void *hostPtr, + int incy, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, const void *, int, + void *, int, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetVectorAsync"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(n, elemSize, devicePtr, incx, hostPtr, incy, stream); } -cublasStatus_t CUBLASWINAPI cublasSetMatrixAsync (int rows, int cols, int elemSize, - const void *A, int lda, void *B, - int ldb, cudaStream_t stream) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, int, const void *, int, void *, int, cudaStream_t); +cublasStatus_t CUBLASWINAPI cublasSetMatrixAsync(int rows, int cols, + int elemSize, const void *A, + int lda, void *B, int ldb, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + int, int, int, const void *, int, void *, int, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetMatrixAsync"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rows, cols, elemSize, A, lda, B, ldb, stream); } -cublasStatus_t CUBLASWINAPI cublasGetMatrixAsync (int rows, int cols, int elemSize, - const void *A, int lda, void *B, - int ldb, cudaStream_t stream) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, int, const void *, int, void *, int, cudaStream_t); +cublasStatus_t CUBLASWINAPI cublasGetMatrixAsync(int rows, int cols, + int elemSize, const void *A, + int lda, void *B, int ldb, + cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + int, int, int, const void *, int, void *, int, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetMatrixAsync"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rows, cols, elemSize, A, lda, B, ldb, stream); } -void CUBLASWINAPI cublasXerbla (const char *srName, int info) { - using FuncPtr = void (CUBLASWINAPI *)(const char *, int); +void CUBLASWINAPI cublasXerbla(const char *srName, int info) { + using FuncPtr = void(CUBLASWINAPI *)(const char *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasXerbla"); if (!func_ptr) LogFatalSymbolNotFound("cublasXerbla"); return func_ptr(srName, info); } -cublasStatus_t CUBLASWINAPI cublasNrm2Ex(cublasHandle_t handle, - int n, - const void *x, - cudaDataType xType, - int incx, - void *result, - cudaDataType resultType, - cudaDataType executionType) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const void *, cudaDataType, int, void *, cudaDataType, cudaDataType); +cublasStatus_t CUBLASWINAPI cublasNrm2Ex(cublasHandle_t handle, int n, + const void *x, cudaDataType xType, + int incx, void *result, + cudaDataType resultType, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, void *, + cudaDataType, cudaDataType); static auto func_ptr = LoadSymbol<FuncPtr>("cublasNrm2Ex"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, xType, incx, result, resultType, executionType); } -cublasStatus_t CUBLASWINAPI cublasSnrm2_v2(cublasHandle_t handle, - int n, - const float *x, - int incx, - float *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, float *); +cublasStatus_t CUBLASWINAPI cublasSnrm2_v2(cublasHandle_t handle, int n, + const float *x, int incx, + float *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const float *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSnrm2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasDnrm2_v2(cublasHandle_t handle, - int n, - const double *x, - int incx, - double *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, double *); +cublasStatus_t CUBLASWINAPI cublasDnrm2_v2(cublasHandle_t handle, int n, + const double *x, int incx, + double *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const double *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDnrm2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasScnrm2_v2(cublasHandle_t handle, - int n, - const cuComplex *x, - int incx, - float *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, float *); +cublasStatus_t CUBLASWINAPI cublasScnrm2_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + float *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasScnrm2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasDznrm2_v2(cublasHandle_t handle, - int n, - const cuDoubleComplex *x, - int incx, - double *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, double *); +cublasStatus_t CUBLASWINAPI cublasDznrm2_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + double *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDznrm2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasDotEx (cublasHandle_t handle, - int n, - const void *x, - cudaDataType xType, - int incx, - const void *y, - cudaDataType yType, - int incy, - void *result, - cudaDataType resultType, - cudaDataType executionType) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const void *, cudaDataType, int, const void *, cudaDataType, int, void *, cudaDataType, cudaDataType); +cublasStatus_t CUBLASWINAPI cublasDotEx(cublasHandle_t handle, int n, + const void *x, cudaDataType xType, + int incx, const void *y, + cudaDataType yType, int incy, + void *result, cudaDataType resultType, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, const void *, + cudaDataType, int, void *, cudaDataType, cudaDataType); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDotEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, n, x, xType, incx, y, yType, incy, result, resultType, executionType); + return func_ptr(handle, n, x, xType, incx, y, yType, incy, result, resultType, + executionType); } -cublasStatus_t CUBLASWINAPI cublasDotcEx (cublasHandle_t handle, - int n, - const void *x, - cudaDataType xType, - int incx, - const void *y, - cudaDataType yType, - int incy, - void *result, - cudaDataType resultType, - cudaDataType executionType) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const void *, cudaDataType, int, const void *, cudaDataType, int, void *, cudaDataType, cudaDataType); +cublasStatus_t CUBLASWINAPI cublasDotcEx(cublasHandle_t handle, int n, + const void *x, cudaDataType xType, + int incx, const void *y, + cudaDataType yType, int incy, + void *result, cudaDataType resultType, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, int, const void *, + cudaDataType, int, void *, cudaDataType, cudaDataType); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDotcEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, n, x, xType, incx, y, yType, incy, result, resultType, executionType); + return func_ptr(handle, n, x, xType, incx, y, yType, incy, result, resultType, + executionType); } -cublasStatus_t CUBLASWINAPI cublasSdot_v2 (cublasHandle_t handle, - int n, - const float *x, - int incx, - const float *y, - int incy, - float *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, const float *, int, float *); +cublasStatus_t CUBLASWINAPI cublasSdot_v2(cublasHandle_t handle, int n, + const float *x, int incx, + const float *y, int incy, + float *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const float *, int, const float *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSdot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, result); } -cublasStatus_t CUBLASWINAPI cublasDdot_v2 (cublasHandle_t handle, - int n, - const double *x, - int incx, - const double *y, - int incy, - double *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, const double *, int, double *); +cublasStatus_t CUBLASWINAPI cublasDdot_v2(cublasHandle_t handle, int n, + const double *x, int incx, + const double *y, int incy, + double *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const double *, int, const double *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDdot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, result); } -cublasStatus_t CUBLASWINAPI cublasCdotu_v2 (cublasHandle_t handle, - int n, - const cuComplex *x, - int incx, - const cuComplex *y, - int incy, - cuComplex *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, const cuComplex *, int, cuComplex *); +cublasStatus_t CUBLASWINAPI cublasCdotu_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + const cuComplex *y, int incy, + cuComplex *result) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, + int, const cuComplex *, int, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCdotu_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, result); } -cublasStatus_t CUBLASWINAPI cublasCdotc_v2 (cublasHandle_t handle, - int n, - const cuComplex *x, - int incx, - const cuComplex *y, - int incy, - cuComplex *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, const cuComplex *, int, cuComplex *); +cublasStatus_t CUBLASWINAPI cublasCdotc_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + const cuComplex *y, int incy, + cuComplex *result) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, + int, const cuComplex *, int, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCdotc_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, result); } -cublasStatus_t CUBLASWINAPI cublasZdotu_v2 (cublasHandle_t handle, - int n, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *y, - int incy, - cuDoubleComplex *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *); +cublasStatus_t CUBLASWINAPI cublasZdotu_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdotu_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, result); } -cublasStatus_t CUBLASWINAPI cublasZdotc_v2 (cublasHandle_t handle, - int n, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *y, - int incy, - cuDoubleComplex *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *); +cublasStatus_t CUBLASWINAPI cublasZdotc_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdotc_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, result); } -cublasStatus_t CUBLASWINAPI cublasScalEx(cublasHandle_t handle, - int n, - const void *alpha, /* host or device pointer */ - cudaDataType alphaType, - void *x, - cudaDataType xType, - int incx, - cudaDataType executionType) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const void *, cudaDataType, void *, cudaDataType, int, cudaDataType); +cublasStatus_t CUBLASWINAPI +cublasScalEx(cublasHandle_t handle, int n, + const void *alpha, /* host or device pointer */ + cudaDataType alphaType, void *x, cudaDataType xType, int incx, + cudaDataType executionType) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, void *, cudaDataType, + int, cudaDataType); static auto func_ptr = LoadSymbol<FuncPtr>("cublasScalEx"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, alphaType, x, xType, incx, executionType); } -cublasStatus_t CUBLASWINAPI cublasSscal_v2(cublasHandle_t handle, - int n, - const float *alpha, /* host or device pointer */ - float *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSscal_v2(cublasHandle_t handle, int n, + const float *alpha, /* host or device pointer */ + float *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSscal_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx); } -cublasStatus_t CUBLASWINAPI cublasDscal_v2(cublasHandle_t handle, - int n, - const double *alpha, /* host or device pointer */ - double *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDscal_v2(cublasHandle_t handle, int n, + const double *alpha, /* host or device pointer */ + double *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDscal_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCscal_v2(cublasHandle_t handle, - int n, - const cuComplex *alpha, /* host or device pointer */ - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCscal_v2(cublasHandle_t handle, int n, + const cuComplex *alpha, /* host or device pointer */ + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCscal_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCsscal_v2(cublasHandle_t handle, - int n, - const float *alpha, /* host or device pointer */ - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCsscal_v2(cublasHandle_t handle, int n, + const float *alpha, /* host or device pointer */ + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const float *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsscal_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZscal_v2(cublasHandle_t handle, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZscal_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZscal_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZdscal_v2(cublasHandle_t handle, - int n, - const double *alpha, /* host or device pointer */ - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZdscal_v2(cublasHandle_t handle, int n, + const double *alpha, /* host or device pointer */ + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const double *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdscal_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx); } -cublasStatus_t CUBLASWINAPI cublasAxpyEx (cublasHandle_t handle, - int n, - const void *alpha, /* host or device pointer */ - cudaDataType alphaType, - const void *x, - cudaDataType xType, - int incx, - void *y, - cudaDataType yType, - int incy, - cudaDataType executiontype) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const void *, cudaDataType, const void *, cudaDataType, int, void *, cudaDataType, int, cudaDataType); +cublasStatus_t CUBLASWINAPI cublasAxpyEx( + cublasHandle_t handle, int n, + const void *alpha, /* host or device pointer */ + cudaDataType alphaType, const void *x, cudaDataType xType, int incx, + void *y, cudaDataType yType, int incy, cudaDataType executiontype) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const void *, cudaDataType, const void *, + cudaDataType, int, void *, cudaDataType, int, cudaDataType); static auto func_ptr = LoadSymbol<FuncPtr>("cublasAxpyEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, n, alpha, alphaType, x, xType, incx, y, yType, incy, executiontype); + return func_ptr(handle, n, alpha, alphaType, x, xType, incx, y, yType, incy, + executiontype); } -cublasStatus_t CUBLASWINAPI cublasSaxpy_v2 (cublasHandle_t handle, - int n, - const float *alpha, /* host or device pointer */ - const float *x, - int incx, - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSaxpy_v2(cublasHandle_t handle, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const float *, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSaxpy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasDaxpy_v2 (cublasHandle_t handle, - int n, - const double *alpha, /* host or device pointer */ - const double *x, - int incx, - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDaxpy_v2(cublasHandle_t handle, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const double *, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDaxpy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasCaxpy_v2 (cublasHandle_t handle, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCaxpy_v2(cublasHandle_t handle, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, cuComplex *y, int incy) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCaxpy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZaxpy_v2 (cublasHandle_t handle, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZaxpy_v2( + cublasHandle_t handle, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, const cuDoubleComplex *, + int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZaxpy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, alpha, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasScopy_v2 (cublasHandle_t handle, - int n, - const float *x, - int incx, - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasScopy_v2(cublasHandle_t handle, int n, + const float *x, int incx, float *y, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasScopy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasDcopy_v2 (cublasHandle_t handle, - int n, - const double *x, - int incx, - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDcopy_v2(cublasHandle_t handle, int n, + const double *x, int incx, double *y, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDcopy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasCcopy_v2 (cublasHandle_t handle, - int n, - const cuComplex *x, - int incx, - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCcopy_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCcopy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZcopy_v2 (cublasHandle_t handle, - int n, - const cuDoubleComplex *x, - int incx, - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZcopy_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZcopy_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasSswap_v2 (cublasHandle_t handle, - int n, - float *x, - int incx, - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasSswap_v2(cublasHandle_t handle, int n, + float *x, int incx, float *y, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, float *, + int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSswap_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasDswap_v2 (cublasHandle_t handle, - int n, - double *x, - int incx, - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDswap_v2(cublasHandle_t handle, int n, + double *x, int incx, double *y, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, double *, + int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDswap_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasCswap_v2 (cublasHandle_t handle, - int n, - cuComplex *x, - int incx, - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCswap_v2(cublasHandle_t handle, int n, + cuComplex *x, int incx, cuComplex *y, + int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCswap_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZswap_v2 (cublasHandle_t handle, - int n, - cuDoubleComplex *x, - int incx, - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZswap_v2(cublasHandle_t handle, int n, + cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZswap_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy); } -cublasStatus_t CUBLASWINAPI cublasIsamax_v2(cublasHandle_t handle, - int n, - const float *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIsamax_v2(cublasHandle_t handle, int n, + const float *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const float *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIsamax_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasIdamax_v2(cublasHandle_t handle, - int n, - const double *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIdamax_v2(cublasHandle_t handle, int n, + const double *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const double *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIdamax_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasIcamax_v2(cublasHandle_t handle, - int n, - const cuComplex *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIcamax_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const cuComplex *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIcamax_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasIzamax_v2(cublasHandle_t handle, - int n, - const cuDoubleComplex *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIzamax_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIzamax_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasIsamin_v2(cublasHandle_t handle, - int n, - const float *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIsamin_v2(cublasHandle_t handle, int n, + const float *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const float *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIsamin_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasIdamin_v2(cublasHandle_t handle, - int n, - const double *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIdamin_v2(cublasHandle_t handle, int n, + const double *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const double *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIdamin_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasIcamin_v2(cublasHandle_t handle, - int n, - const cuComplex *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIcamin_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const cuComplex *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIcamin_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasIzamin_v2(cublasHandle_t handle, - int n, - const cuDoubleComplex *x, - int incx, - int *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, int *); +cublasStatus_t CUBLASWINAPI cublasIzamin_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + int *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIzamin_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasSasum_v2(cublasHandle_t handle, - int n, - const float *x, - int incx, - float *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *, int, float *); +cublasStatus_t CUBLASWINAPI cublasSasum_v2(cublasHandle_t handle, int n, + const float *x, int incx, + float *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const float *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSasum_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasDasum_v2(cublasHandle_t handle, - int n, - const double *x, - int incx, - double *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *, int, double *); +cublasStatus_t CUBLASWINAPI cublasDasum_v2(cublasHandle_t handle, int n, + const double *x, int incx, + double *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, + const double *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDasum_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasScasum_v2(cublasHandle_t handle, - int n, - const cuComplex *x, - int incx, - float *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *, int, float *); +cublasStatus_t CUBLASWINAPI cublasScasum_v2(cublasHandle_t handle, int n, + const cuComplex *x, int incx, + float *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasScasum_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasDzasum_v2(cublasHandle_t handle, - int n, - const cuDoubleComplex *x, - int incx, - double *result) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *, int, double *); +cublasStatus_t CUBLASWINAPI cublasDzasum_v2(cublasHandle_t handle, int n, + const cuDoubleComplex *x, int incx, + double *result) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDzasum_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, result); } -cublasStatus_t CUBLASWINAPI cublasSrot_v2 (cublasHandle_t handle, - int n, - float *x, - int incx, - float *y, - int incy, - const float *c, /* host or device pointer */ - const float *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, float *, int, float *, int, const float *, const float *); +cublasStatus_t CUBLASWINAPI +cublasSrot_v2(cublasHandle_t handle, int n, float *x, int incx, float *y, + int incy, const float *c, /* host or device pointer */ + const float *s) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, float *, int, float *, + int, const float *, const float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, c, s); } -cublasStatus_t CUBLASWINAPI cublasDrot_v2 (cublasHandle_t handle, - int n, - double *x, - int incx, - double *y, - int incy, - const double *c, /* host or device pointer */ - const double *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, double *, int, double *, int, const double *, const double *); +cublasStatus_t CUBLASWINAPI +cublasDrot_v2(cublasHandle_t handle, int n, double *x, int incx, double *y, + int incy, const double *c, /* host or device pointer */ + const double *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, double *, int, double *, int, const double *, + const double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, c, s); } -cublasStatus_t CUBLASWINAPI cublasCrot_v2 (cublasHandle_t handle, - int n, - cuComplex *x, - int incx, - cuComplex *y, - int incy, - const float *c, /* host or device pointer */ - const cuComplex *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuComplex *, int, cuComplex *, int, const float *, const cuComplex *); +cublasStatus_t CUBLASWINAPI cublasCrot_v2( + cublasHandle_t handle, int n, cuComplex *x, int incx, cuComplex *y, + int incy, const float *c, /* host or device pointer */ + const cuComplex *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuComplex *, int, cuComplex *, int, const float *, + const cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCrot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, c, s); } -cublasStatus_t CUBLASWINAPI cublasCsrot_v2(cublasHandle_t handle, - int n, - cuComplex *x, - int incx, - cuComplex *y, - int incy, - const float *c, /* host or device pointer */ - const float *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuComplex *, int, cuComplex *, int, const float *, const float *); +cublasStatus_t CUBLASWINAPI cublasCsrot_v2( + cublasHandle_t handle, int n, cuComplex *x, int incx, cuComplex *y, + int incy, const float *c, /* host or device pointer */ + const float *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuComplex *, int, cuComplex *, int, const float *, + const float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsrot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, c, s); } -cublasStatus_t CUBLASWINAPI cublasZrot_v2 (cublasHandle_t handle, - int n, - cuDoubleComplex *x, - int incx, - cuDoubleComplex *y, - int incy, - const double *c, /* host or device pointer */ - const cuDoubleComplex *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, const double *, const cuDoubleComplex *); +cublasStatus_t CUBLASWINAPI cublasZrot_v2( + cublasHandle_t handle, int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy, const double *c, /* host or device pointer */ + const cuDoubleComplex *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, + const double *, const cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZrot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, c, s); } -cublasStatus_t CUBLASWINAPI cublasZdrot_v2(cublasHandle_t handle, - int n, - cuDoubleComplex *x, - int incx, - cuDoubleComplex *y, - int incy, - const double *c, /* host or device pointer */ - const double *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, const double *, const double *); +cublasStatus_t CUBLASWINAPI cublasZdrot_v2( + cublasHandle_t handle, int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy, const double *c, /* host or device pointer */ + const double *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, + const double *, const double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdrot_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, c, s); } -cublasStatus_t CUBLASWINAPI cublasSrotg_v2(cublasHandle_t handle, - float *a, /* host or device pointer */ - float *b, /* host or device pointer */ - float *c, /* host or device pointer */ - float *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, float *, float *, float *, float *); +cublasStatus_t CUBLASWINAPI +cublasSrotg_v2(cublasHandle_t handle, float *a, /* host or device pointer */ + float *b, /* host or device pointer */ + float *c, /* host or device pointer */ + float *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, float *, + float *, float *, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotg_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, a, b, c, s); } -cublasStatus_t CUBLASWINAPI cublasDrotg_v2(cublasHandle_t handle, - double *a, /* host or device pointer */ - double *b, /* host or device pointer */ - double *c, /* host or device pointer */ - double *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, double *, double *, double *, double *); +cublasStatus_t CUBLASWINAPI +cublasDrotg_v2(cublasHandle_t handle, double *a, /* host or device pointer */ + double *b, /* host or device pointer */ + double *c, /* host or device pointer */ + double *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, double *, + double *, double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotg_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, a, b, c, s); } -cublasStatus_t CUBLASWINAPI cublasCrotg_v2(cublasHandle_t handle, - cuComplex *a, /* host or device pointer */ - cuComplex *b, /* host or device pointer */ - float *c, /* host or device pointer */ - cuComplex *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cuComplex *, cuComplex *, float *, cuComplex *); +cublasStatus_t CUBLASWINAPI +cublasCrotg_v2(cublasHandle_t handle, cuComplex *a, /* host or device pointer */ + cuComplex *b, /* host or device pointer */ + float *c, /* host or device pointer */ + cuComplex *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cuComplex *, cuComplex *, float *, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCrotg_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, a, b, c, s); } -cublasStatus_t CUBLASWINAPI cublasZrotg_v2(cublasHandle_t handle, - cuDoubleComplex *a, /* host or device pointer */ - cuDoubleComplex *b, /* host or device pointer */ - double *c, /* host or device pointer */ - cuDoubleComplex *s) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cuDoubleComplex *, cuDoubleComplex *, double *, cuDoubleComplex *); +cublasStatus_t CUBLASWINAPI cublasZrotg_v2( + cublasHandle_t handle, cuDoubleComplex *a, /* host or device pointer */ + cuDoubleComplex *b, /* host or device pointer */ + double *c, /* host or device pointer */ + cuDoubleComplex *s) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cuDoubleComplex *, cuDoubleComplex *, double *, + cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZrotg_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, a, b, c, s); } -cublasStatus_t CUBLASWINAPI cublasSrotm_v2(cublasHandle_t handle, - int n, - float *x, - int incx, - float *y, - int incy, - const float* param) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, float *, int, float *, int, const float *); +cublasStatus_t CUBLASWINAPI cublasSrotm_v2(cublasHandle_t handle, int n, + float *x, int incx, float *y, + int incy, const float *param) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, float *, int, float *, int, const float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, param); } -cublasStatus_t CUBLASWINAPI cublasDrotm_v2(cublasHandle_t handle, - int n, - double *x, - int incx, - double *y, - int incy, - const double* param) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, double *, int, double *, int, const double *); +cublasStatus_t CUBLASWINAPI cublasDrotm_v2(cublasHandle_t handle, int n, + double *x, int incx, double *y, + int incy, const double *param) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, double *, int, double *, int, const double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, x, incx, y, incy, param); } -cublasStatus_t CUBLASWINAPI cublasSrotmg_v2(cublasHandle_t handle, - float *d1, /* host or device pointer */ - float *d2, /* host or device pointer */ - float *x1, /* host or device pointer */ - const float *y1, /* host or device pointer */ - float *param) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, float *, float *, float *, const float *, float *); +cublasStatus_t CUBLASWINAPI +cublasSrotmg_v2(cublasHandle_t handle, float *d1, /* host or device pointer */ + float *d2, /* host or device pointer */ + float *x1, /* host or device pointer */ + const float *y1, /* host or device pointer */ + float *param) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, float *, float *, float *, const float *, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotmg_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, d1, d2, x1, y1, param); } -cublasStatus_t CUBLASWINAPI cublasDrotmg_v2(cublasHandle_t handle, - double *d1, /* host or device pointer */ - double *d2, /* host or device pointer */ - double *x1, /* host or device pointer */ - const double *y1, /* host or device pointer */ - double *param) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, double *, double *, double *, const double *, double *); +cublasStatus_t CUBLASWINAPI +cublasDrotmg_v2(cublasHandle_t handle, double *d1, /* host or device pointer */ + double *d2, /* host or device pointer */ + double *x1, /* host or device pointer */ + const double *y1, /* host or device pointer */ + double *param) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, double *, double *, double *, const double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotmg_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, d1, d2, x1, y1, param); } -cublasStatus_t CUBLASWINAPI cublasSgemv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *x, - int incx, - const float *beta, /* host or device pointer */ - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSgemv_v2(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *x, int incx, + const float *beta, /* host or device pointer */ + float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const float *, const float *, + int, const float *, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasDgemv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *x, - int incx, - const double *beta, /* host or device pointer */ - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDgemv_v2(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *x, int incx, + const double *beta, /* host or device pointer */ + double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const double *, + const double *, int, const double *, int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasCgemv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *x, - int incx, - const cuComplex *beta, /* host or device pointer */ - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCgemv_v2(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, const cuComplex *, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZgemv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZgemv_v2( + cublasHandle_t handle, cublasOperation_t trans, int m, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasSgbmv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - int kl, - int ku, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *x, - int incx, - const float *beta, /* host or device pointer */ - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSgbmv_v2(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int kl, int ku, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *x, int incx, + const float *beta, /* host or device pointer */ + float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, int, const float *, + const float *, int, const float *, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, + incy); } -cublasStatus_t CUBLASWINAPI cublasDgbmv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - int kl, - int ku, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *x, - int incx, - const double *beta, /* host or device pointer */ - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDgbmv_v2(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int kl, int ku, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *x, int incx, + const double *beta, /* host or device pointer */ + double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, int, const double *, + const double *, int, const double *, int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, + incy); } -cublasStatus_t CUBLASWINAPI cublasCgbmv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - int kl, - int ku, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *x, - int incx, - const cuComplex *beta, /* host or device pointer */ - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCgbmv_v2( + cublasHandle_t handle, cublasOperation_t trans, int m, int n, int kl, + int ku, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, const cuComplex *, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, + incy); } -cublasStatus_t CUBLASWINAPI cublasZgbmv_v2 (cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - int kl, - int ku, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZgbmv_v2( + cublasHandle_t handle, cublasOperation_t trans, int m, int n, int kl, + int ku, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); + return func_ptr(handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, + incy); } -cublasStatus_t CUBLASWINAPI cublasStrmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const float *A, - int lda, - float *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasStrmv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const float *A, int lda, float *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasDtrmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const double *A, - int lda, - double *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtrmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const double *A, int lda, double *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCtrmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuComplex *A, - int lda, - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtrmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuComplex *A, int lda, + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZtrmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuDoubleComplex *A, - int lda, - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtrmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasStbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const float *A, - int lda, - float *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasStbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const float *A, int lda, float *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasDtbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const double *A, - int lda, - double *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const double *A, int lda, double *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCtbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const cuComplex *A, - int lda, - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const cuComplex *A, int lda, + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZtbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const cuDoubleComplex *A, - int lda, - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtbmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasStpmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const float *AP, - float *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasStpmv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const float *AP, float *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStpmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasDtpmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const double *AP, - double *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtpmv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const double *AP, double *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtpmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCtpmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuComplex *AP, - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtpmv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const cuComplex *AP, cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtpmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZtpmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuDoubleComplex *AP, - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtpmv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuDoubleComplex *AP, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtpmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasStrsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const float *A, - int lda, - float *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasStrsv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const float *A, int lda, float *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasDtrsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const double *A, - int lda, - double *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtrsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const double *A, int lda, double *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCtrsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuComplex *A, - int lda, - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtrsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuComplex *A, int lda, + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZtrsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuDoubleComplex *A, - int lda, - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtrsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasStpsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const float *AP, - float *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasStpsv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const float *AP, float *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStpsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasDtpsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const double *AP, - double *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtpsv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const double *AP, double *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtpsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCtpsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuComplex *AP, - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtpsv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + cublasDiagType_t diag, int n, const cuComplex *AP, cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtpsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZtpsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - const cuDoubleComplex *AP, - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtpsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, + const cuDoubleComplex *AP, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtpsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, AP, x, incx); } -cublasStatus_t CUBLASWINAPI cublasStbsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const float *A, - int lda, - float *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasStbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const float *A, int lda, float *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStbsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasDtbsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const double *A, - int lda, - double *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const double *A, int lda, double *x, + int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtbsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasCtbsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const cuComplex *A, - int lda, - cuComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const cuComplex *A, int lda, + cuComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtbsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasZtbsv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int n, - int k, - const cuDoubleComplex *A, - int lda, - cuDoubleComplex *x, - int incx) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtbsv_v2(cublasHandle_t handle, + cublasFillMode_t uplo, + cublasOperation_t trans, + cublasDiagType_t diag, int n, int k, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, + int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtbsv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, diag, n, k, A, lda, x, incx); } -cublasStatus_t CUBLASWINAPI cublasSsymv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *x, - int incx, - const float *beta, /* host or device pointer */ - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSsymv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *x, int incx, + const float *beta, /* host or device pointer */ + float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, + const float *, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsymv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasDsymv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *x, - int incx, - const double *beta, /* host or device pointer */ - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDsymv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *x, int incx, + const double *beta, /* host or device pointer */ + double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + int, const double *, int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsymv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasCsymv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *x, - int incx, - const cuComplex *beta, /* host or device pointer */ - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCsymv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, const cuComplex *, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsymv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZsymv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZsymv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsymv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasChemv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *x, - int incx, - const cuComplex *beta, /* host or device pointer */ - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasChemv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, const cuComplex *, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChemv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZhemv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZhemv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhemv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasSsbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - int k, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *x, - int incx, - const float *beta, /* host or device pointer */ - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSsbmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, int k, + const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *x, int incx, + const float *beta, /* host or device pointer */ + float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, int, const float *, const float *, + int, const float *, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasDsbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - int k, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *x, - int incx, - const double *beta, /* host or device pointer */ - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDsbmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, int k, + const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *x, int incx, + const double *beta, /* host or device pointer */ + double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, int, const double *, + const double *, int, const double *, int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasChbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *x, - int incx, - const cuComplex *beta, /* host or device pointer */ - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasChbmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, int k, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, const cuComplex *, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZhbmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZhbmv_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhbmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasSspmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const float *AP, - const float *x, - int incx, - const float *beta, /* host or device pointer */ - float *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSspmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *AP, const float *x, int incx, + const float *beta, /* host or device pointer */ + float *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, + const float *, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasDspmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const double *AP, - const double *x, - int incx, - const double *beta, /* host or device pointer */ - double *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDspmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *AP, const double *x, int incx, + const double *beta, /* host or device pointer */ + double *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + const double *, int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasChpmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *AP, - const cuComplex *x, - int incx, - const cuComplex *beta, /* host or device pointer */ - cuComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasChpmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *AP, const cuComplex *x, int incx, + const cuComplex *beta, /* host or device pointer */ + cuComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, const cuComplex *, int, const cuComplex *, cuComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasZhpmv_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *AP, - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *y, - int incy) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZhpmv_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *AP, const cuDoubleComplex *x, int incx, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *y, int incy) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpmv_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, AP, x, incx, beta, y, incy); } -cublasStatus_t CUBLASWINAPI cublasSger_v2 (cublasHandle_t handle, - int m, - int n, - const float *alpha, /* host or device pointer */ - const float *x, - int incx, - const float *y, - int incy, - float *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const float *, const float *, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasSger_v2( + cublasHandle_t handle, int m, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, const float *y, int incy, float *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const float *, const float *, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSger_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasDger_v2 (cublasHandle_t handle, - int m, - int n, - const double *alpha, /* host or device pointer */ - const double *x, - int incx, - const double *y, - int incy, - double *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const double *, const double *, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDger_v2( + cublasHandle_t handle, int m, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, const double *y, int incy, double *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const double *, const double *, int, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDger_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasCgeru_v2 (cublasHandle_t handle, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - const cuComplex *y, - int incy, - cuComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCgeru_v2(cublasHandle_t handle, int m, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const cuComplex *, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgeru_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasCgerc_v2 (cublasHandle_t handle, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - const cuComplex *y, - int incy, - cuComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCgerc_v2(cublasHandle_t handle, int m, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const cuComplex *, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgerc_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasZgeru_v2 (cublasHandle_t handle, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *y, - int incy, - cuDoubleComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZgeru_v2(cublasHandle_t handle, int m, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgeru_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasZgerc_v2 (cublasHandle_t handle, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *y, - int incy, - cuDoubleComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZgerc_v2(cublasHandle_t handle, int m, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgerc_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasSsyr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const float *x, - int incx, - float *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI +cublasSsyr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, float *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, + float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); } -cublasStatus_t CUBLASWINAPI cublasDsyr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const double *x, - int incx, - double *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI +cublasDsyr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, double *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); } -cublasStatus_t CUBLASWINAPI cublasCsyr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - cuComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCsyr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); } -cublasStatus_t CUBLASWINAPI cublasZsyr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - cuDoubleComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZsyr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); } -cublasStatus_t CUBLASWINAPI cublasCher_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - cuComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCher_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const cuComplex *x, int incx, cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const cuComplex *, + int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); } -cublasStatus_t CUBLASWINAPI cublasZher_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - cuDoubleComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZher_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, + const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, A, lda); } -cublasStatus_t CUBLASWINAPI cublasSspr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const float *x, - int incx, - float *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, float *); +cublasStatus_t CUBLASWINAPI +cublasSspr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, float *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, + float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, AP); } -cublasStatus_t CUBLASWINAPI cublasDspr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const double *x, - int incx, - double *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, int, double *); +cublasStatus_t CUBLASWINAPI +cublasDspr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, double *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, AP); } -cublasStatus_t CUBLASWINAPI cublasChpr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - cuComplex *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const cuComplex *, int, cuComplex *); +cublasStatus_t CUBLASWINAPI +cublasChpr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const cuComplex *x, int incx, cuComplex *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const cuComplex *, + int, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, AP); } -cublasStatus_t CUBLASWINAPI cublasZhpr_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - cuDoubleComplex *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const cuDoubleComplex *, int, cuDoubleComplex *); +cublasStatus_t CUBLASWINAPI +cublasZhpr_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, cuDoubleComplex *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, + const cuDoubleComplex *, int, cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpr_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, AP); } -cublasStatus_t CUBLASWINAPI cublasSsyr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const float *x, - int incx, - const float *y, - int incy, - float *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasSsyr2_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, const float *y, int incy, float *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasDsyr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const double *x, - int incx, - const double *y, - int incy, - double *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDsyr2_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, const double *y, int incy, double *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + int, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasCsyr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - const cuComplex *y, - int incy, - cuComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCsyr2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasZsyr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *y, - int incy, - cuDoubleComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZsyr2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasCher2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - const cuComplex *y, - int incy, - cuComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasCher2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasZher2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *y, - int incy, - cuDoubleComplex *A, - int lda) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZher2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, A, lda); } -cublasStatus_t CUBLASWINAPI cublasSspr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *alpha, /* host or device pointer */ - const float *x, - int incx, - const float *y, - int incy, - float *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, const float *, int, float *); +cublasStatus_t CUBLASWINAPI +cublasSspr2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const float *alpha, /* host or device pointer */ + const float *x, int incx, const float *y, int incy, float *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, const float *, int, + const float *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); } -cublasStatus_t CUBLASWINAPI cublasDspr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *alpha, /* host or device pointer */ - const double *x, - int incx, - const double *y, - int incy, - double *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, const double *, int, const double *, int, double *); +cublasStatus_t CUBLASWINAPI cublasDspr2_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const double *alpha, /* host or device pointer */ + const double *x, int incx, const double *y, int incy, double *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, const double *, + int, const double *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); } -cublasStatus_t CUBLASWINAPI cublasChpr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *x, - int incx, - const cuComplex *y, - int incy, - cuComplex *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *); +cublasStatus_t CUBLASWINAPI cublasChpr2_v2( + cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *x, int incx, const cuComplex *y, int incy, cuComplex *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, int, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); } -cublasStatus_t CUBLASWINAPI cublasZhpr2_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *x, - int incx, - const cuDoubleComplex *y, - int incy, - cuDoubleComplex *AP) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *); +cublasStatus_t CUBLASWINAPI +cublasZhpr2_v2(cublasHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, + int incy, cuDoubleComplex *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpr2_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, alpha, x, incx, y, incy, AP); } -cublasStatus_t CUBLASWINAPI cublasSgemm_v2 (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *B, - int ldb, - const float *beta, /* host or device pointer */ - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasSgemm_v2( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *B, int ldb, + const float *beta, /* host or device pointer */ + float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const float *, const float *, int, const float *, int, const float *, + float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasDgemm_v2 (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *B, - int ldb, - const double *beta, /* host or device pointer */ - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDgemm_v2( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *B, int ldb, + const double *beta, /* host or device pointer */ + double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const double *, const double *, int, const double *, int, const double *, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasCgemm_v2 (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCgemm_v2( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasCgemm3m (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCgemm3m( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm3m"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasCgemm3mEx (cublasHandle_t handle, - cublasOperation_t transa, cublasOperation_t transb, - int m, int n, int k, - const cuComplex *alpha, - const void *A, - cudaDataType Atype, - int lda, - const void *B, - cudaDataType Btype, - int ldb, - const cuComplex *beta, - void *C, - cudaDataType Ctype, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const void *, cudaDataType, int, const void *, cudaDataType, int, const cuComplex *, void *, cudaDataType, int); +cublasStatus_t CUBLASWINAPI cublasCgemm3mEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, const void *A, + cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, + const cuComplex *beta, void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const void *, cudaDataType, int, const void *, + cudaDataType, int, const cuComplex *, void *, cudaDataType, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm3mEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, + Btype, ldb, beta, C, Ctype, ldc); } -cublasStatus_t CUBLASWINAPI cublasZgemm_v2 (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZgemm_v2( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasZgemm3m (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI +cublasZgemm3m(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, + int ldb, const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemm3m"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasSgemmEx (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host or device pointer */ - const void *A, - cudaDataType Atype, - int lda, - const void *B, - cudaDataType Btype, - int ldb, - const float *beta, /* host or device pointer */ - void *C, - cudaDataType Ctype, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const float *, const void *, cudaDataType, int, const void *, cudaDataType, int, const float *, void *, cudaDataType, int); +cublasStatus_t CUBLASWINAPI cublasSgemmEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const float *alpha, /* host or device pointer */ + const void *A, cudaDataType Atype, int lda, const void *B, + cudaDataType Btype, int ldb, const float *beta, /* host or device pointer */ + void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const float *, const void *, cudaDataType, int, const void *, + cudaDataType, int, const float *, void *, cudaDataType, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, + Btype, ldb, beta, C, Ctype, ldc); } -cublasStatus_t CUBLASWINAPI cublasGemmEx (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const void *alpha, /* host or device pointer */ - const void *A, - cudaDataType Atype, - int lda, - const void *B, - cudaDataType Btype, - int ldb, - const void *beta, /* host or device pointer */ - void *C, - cudaDataType Ctype, - int ldc, - cudaDataType computeType, - cublasGemmAlgo_t algo) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const void *, const void *, cudaDataType, int, const void *, cudaDataType, int, const void *, void *, cudaDataType, int, cudaDataType, cublasGemmAlgo_t); +cublasStatus_t CUBLASWINAPI cublasGemmEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const void *alpha, /* host or device pointer */ + const void *A, cudaDataType Atype, int lda, const void *B, + cudaDataType Btype, int ldb, const void *beta, /* host or device pointer */ + void *C, cudaDataType Ctype, int ldc, cudaDataType computeType, + cublasGemmAlgo_t algo) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const void *, const void *, cudaDataType, int, const void *, cudaDataType, + int, const void *, void *, cudaDataType, int, cudaDataType, + cublasGemmAlgo_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGemmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, + Btype, ldb, beta, C, Ctype, ldc, computeType, algo); } -cublasStatus_t CUBLASWINAPI cublasCgemmEx (cublasHandle_t handle, - cublasOperation_t transa, cublasOperation_t transb, - int m, int n, int k, - const cuComplex *alpha, - const void *A, - cudaDataType Atype, - int lda, - const void *B, - cudaDataType Btype, - int ldb, - const cuComplex *beta, - void *C, - cudaDataType Ctype, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const void *, cudaDataType, int, const void *, cudaDataType, int, const cuComplex *, void *, cudaDataType, int); +cublasStatus_t CUBLASWINAPI cublasCgemmEx( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, const void *A, + cudaDataType Atype, int lda, const void *B, cudaDataType Btype, int ldb, + const cuComplex *beta, void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const void *, cudaDataType, int, const void *, + cudaDataType, int, const cuComplex *, void *, cudaDataType, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, + Btype, ldb, beta, C, Ctype, ldc); } -cublasStatus_t CUBLASWINAPI cublasUint8gemmBias (cublasHandle_t handle, - cublasOperation_t transa, cublasOperation_t transb, cublasOperation_t transc, - int m, int n, int k, - const unsigned char *A, int A_bias, int lda, - const unsigned char *B, int B_bias, int ldb, - unsigned char *C, int C_bias, int ldc, - int C_mult, int C_shift) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, cublasOperation_t, int, int, int, const unsigned char *, int, int, const unsigned char *, int, int, unsigned char *, int, int, int, int); +cublasStatus_t CUBLASWINAPI cublasUint8gemmBias( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + cublasOperation_t transc, int m, int n, int k, const unsigned char *A, + int A_bias, int lda, const unsigned char *B, int B_bias, int ldb, + unsigned char *C, int C_bias, int ldc, int C_mult, int C_shift) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, cublasOperation_t, + int, int, int, const unsigned char *, int, int, const unsigned char *, + int, int, unsigned char *, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasUint8gemmBias"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, transc, m, n, k, A, A_bias, lda, B, B_bias, ldb, C, C_bias, ldc, C_mult, C_shift); + return func_ptr(handle, transa, transb, transc, m, n, k, A, A_bias, lda, B, + B_bias, ldb, C, C_bias, ldc, C_mult, C_shift); } -cublasStatus_t CUBLASWINAPI cublasSsyrk_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *beta, /* host or device pointer */ - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasSsyrk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *beta, /* host or device pointer */ + float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const float *, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyrk_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -cublasStatus_t CUBLASWINAPI cublasDsyrk_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *beta, /* host or device pointer */ - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const double *, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDsyrk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *beta, /* host or device pointer */ + double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const double *, const double *, int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyrk_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -cublasStatus_t CUBLASWINAPI cublasCsyrk_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCsyrk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, cuComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyrk_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -cublasStatus_t CUBLASWINAPI cublasZsyrk_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZsyrk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyrk_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -cublasStatus_t CUBLASWINAPI cublasCsyrkEx ( cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const void *A, - cudaDataType Atype, - int lda, - const cuComplex *beta, /* host or device pointer */ - void *C, - cudaDataType Ctype, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const void *, cudaDataType, int, const cuComplex *, void *, cudaDataType, int); +cublasStatus_t CUBLASWINAPI cublasCsyrkEx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const void *A, cudaDataType Atype, int lda, + const cuComplex *beta, /* host or device pointer */ + void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const void *, cudaDataType, int, const cuComplex *, + void *, cudaDataType, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyrkEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, + Ctype, ldc); } -cublasStatus_t CUBLASWINAPI cublasCsyrk3mEx(cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuComplex *alpha, - const void *A, - cudaDataType Atype, - int lda, - const cuComplex *beta, - void *C, - cudaDataType Ctype, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const void *, cudaDataType, int, const cuComplex *, void *, cudaDataType, int); +cublasStatus_t CUBLASWINAPI cublasCsyrk3mEx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, const void *A, cudaDataType Atype, + int lda, const cuComplex *beta, void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const void *, cudaDataType, int, const cuComplex *, + void *, cudaDataType, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyrk3mEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, + Ctype, ldc); } -cublasStatus_t CUBLASWINAPI cublasCherk_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const float *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const float *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const cuComplex *, int, const float *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCherk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const float *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const cuComplex *, int, const float *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCherk_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -cublasStatus_t CUBLASWINAPI cublasZherk_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const double *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const double *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const double *, const cuDoubleComplex *, int, const double *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZherk_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const double *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, + const double *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const double *, const cuDoubleComplex *, int, const double *, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZherk_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -cublasStatus_t CUBLASWINAPI cublasCherkEx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const float *alpha, /* host or device pointer */ - const void *A, - cudaDataType Atype, - int lda, - const float *beta, /* host or device pointer */ - void *C, - cudaDataType Ctype, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const void *, cudaDataType, int, const float *, void *, cudaDataType, int); +cublasStatus_t CUBLASWINAPI cublasCherkEx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, /* host or device pointer */ + const void *A, cudaDataType Atype, int lda, + const float *beta, /* host or device pointer */ + void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const void *, cudaDataType, int, const float *, void *, + cudaDataType, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCherkEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, + Ctype, ldc); } -cublasStatus_t CUBLASWINAPI cublasCherk3mEx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const float *alpha, - const void *A, cudaDataType Atype, - int lda, - const float *beta, - void *C, - cudaDataType Ctype, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const void *, cudaDataType, int, const float *, void *, cudaDataType, int); +cublasStatus_t CUBLASWINAPI cublasCherk3mEx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, const void *A, cudaDataType Atype, + int lda, const float *beta, void *C, cudaDataType Ctype, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const void *, cudaDataType, int, const float *, void *, + cudaDataType, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCherk3mEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, + Ctype, ldc); } -cublasStatus_t CUBLASWINAPI cublasSsyr2k_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *B, - int ldb, - const float *beta, /* host or device pointer */ - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasSsyr2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *B, int ldb, + const float *beta, /* host or device pointer */ + float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const float *, int, const float *, int, const float *, + float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr2k_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasDsyr2k_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *B, - int ldb, - const double *beta, /* host or device pointer */ - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDsyr2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *B, int ldb, + const double *beta, /* host or device pointer */ + double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const double *, const double *, int, const double *, int, const double *, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr2k_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasCsyr2k_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCsyr2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyr2k_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasZsyr2k_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZsyr2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyr2k_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasCher2k_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const float *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const float *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCher2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const float *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const float *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher2k_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasZher2k_v2 (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const double *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const double *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZher2k_v2( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const double *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const double *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher2k_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasSsyrkx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *B, - int ldb, - const float *beta, /* host or device pointer */ - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasSsyrkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *B, int ldb, + const float *beta, /* host or device pointer */ + float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const float *, const float *, int, const float *, int, const float *, + float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyrkx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasDsyrkx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *B, - int ldb, - const double *beta, /* host or device pointer */ - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDsyrkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *B, int ldb, + const double *beta, /* host or device pointer */ + double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const double *, const double *, int, const double *, int, const double *, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyrkx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasCsyrkx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCsyrkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyrkx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasZsyrkx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZsyrkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyrkx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasCherkx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const float *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const float *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCherkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const float *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const float *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCherkx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasZherkx (cublasHandle_t handle, - cublasFillMode_t uplo, - cublasOperation_t trans, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const double *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const double *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZherkx( + cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, + int n, int k, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const double *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const double *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZherkx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasSsymm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - int m, - int n, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *B, - int ldb, - const float *beta, /* host or device pointer */ - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const float *, const float *, int, const float *, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasSsymm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *B, int ldb, + const float *beta, /* host or device pointer */ + float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const float *, const float *, int, const float *, int, const float *, + float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsymm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasDsymm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - int m, - int n, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *B, - int ldb, - const double *beta, /* host or device pointer */ - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const double *, const double *, int, const double *, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDsymm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *B, int ldb, + const double *beta, /* host or device pointer */ + double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const double *, const double *, int, const double *, int, const double *, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsymm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasCsymm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCsymm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsymm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasZsymm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZsymm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsymm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasChemm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasChemm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, int, + const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChemm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasZhemm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZhemm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, int m, + int n, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhemm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); + return func_ptr(handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasStrsm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - float *B, - int ldb) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const float *, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasStrsm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const float *alpha, /* host or device pointer */ + const float *A, int lda, float *B, int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const float *, const float *, int, float *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrsm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); } -cublasStatus_t CUBLASWINAPI cublasDtrsm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - double *B, - int ldb) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const double *, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtrsm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const double *alpha, /* host or device pointer */ + const double *A, int lda, double *B, int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const double *, const double *, int, double *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrsm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); } -cublasStatus_t CUBLASWINAPI cublasCtrsm_v2(cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - cuComplex *B, - int ldb) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuComplex *, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtrsm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, cuComplex *B, int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuComplex *, const cuComplex *, int, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrsm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); } -cublasStatus_t CUBLASWINAPI cublasZtrsm_v2(cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - cuDoubleComplex *B, - int ldb) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtrsm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, cuDoubleComplex *B, int ldb) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrsm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); } -cublasStatus_t CUBLASWINAPI cublasStrmm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *B, - int ldb, - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const float *, const float *, int, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasStrmm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *B, int ldb, float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const float *, const float *, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrmm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, C, ldc); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasDtrmm_v2 (cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *B, - int ldb, - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const double *, const double *, int, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtrmm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *B, int ldb, double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const double *, const double *, int, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrmm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, C, ldc); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasCtrmm_v2(cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *B, - int ldb, - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtrmm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, const cuComplex *B, int ldb, cuComplex *C, + int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuComplex *, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrmm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, C, ldc); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasZtrmm_v2(cublasHandle_t handle, cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *B, - int ldb, - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtrmm_v2( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrmm_v2"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, C, ldc); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + C, ldc); } -cublasStatus_t CUBLASWINAPI cublasSgemmBatched (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host or device pointer */ - const float *Aarray[], - int lda, - const float *Barray[], - int ldb, - const float *beta, /* host or device pointer */ - float *Carray[], - int ldc, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const float *, const float *[], int, const float *[], int, const float *, float *[], int, int); +cublasStatus_t CUBLASWINAPI cublasSgemmBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const float *alpha, /* host or device pointer */ + const float *Aarray[], int lda, const float *Barray[], int ldb, + const float *beta, /* host or device pointer */ + float *Carray[], int ldc, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const float *, const float *[], int, const float *[], int, const float *, + float *[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, + ldb, beta, Carray, ldc, batchCount); } -cublasStatus_t CUBLASWINAPI cublasDgemmBatched (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const double *alpha, /* host or device pointer */ - const double *Aarray[], - int lda, - const double *Barray[], - int ldb, - const double *beta, /* host or device pointer */ - double *Carray[], - int ldc, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const double *, const double *[], int, const double *[], int, const double *, double *[], int, int); +cublasStatus_t CUBLASWINAPI cublasDgemmBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const double *alpha, /* host or device pointer */ + const double *Aarray[], int lda, const double *Barray[], int ldb, + const double *beta, /* host or device pointer */ + double *Carray[], int ldc, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const double *, const double *[], int, const double *[], int, + const double *, double *[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, + ldb, beta, Carray, ldc, batchCount); } -cublasStatus_t CUBLASWINAPI cublasCgemmBatched (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *Aarray[], - int lda, - const cuComplex *Barray[], - int ldb, - const cuComplex *beta, /* host or device pointer */ - cuComplex *Carray[], - int ldc, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const cuComplex *[], int, const cuComplex *[], int, const cuComplex *, cuComplex *[], int, int); +cublasStatus_t CUBLASWINAPI cublasCgemmBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *Aarray[], int lda, const cuComplex *Barray[], int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *Carray[], int ldc, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *[], int, const cuComplex *[], int, + const cuComplex *, cuComplex *[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, + ldb, beta, Carray, ldc, batchCount); } -cublasStatus_t CUBLASWINAPI cublasCgemm3mBatched (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *Aarray[], - int lda, - const cuComplex *Barray[], - int ldb, - const cuComplex *beta, /* host or device pointer */ - cuComplex *Carray[], - int ldc, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const cuComplex *[], int, const cuComplex *[], int, const cuComplex *, cuComplex *[], int, int); +cublasStatus_t CUBLASWINAPI cublasCgemm3mBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *Aarray[], int lda, const cuComplex *Barray[], int ldb, + const cuComplex *beta, /* host or device pointer */ + cuComplex *Carray[], int ldc, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *[], int, const cuComplex *[], int, + const cuComplex *, cuComplex *[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm3mBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, + ldb, beta, Carray, ldc, batchCount); } -cublasStatus_t CUBLASWINAPI cublasZgemmBatched (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *Aarray[], - int lda, - const cuDoubleComplex *Barray[], - int ldb, - const cuDoubleComplex *beta, /* host or device pointer */ - cuDoubleComplex *Carray[], - int ldc, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuDoubleComplex *, const cuDoubleComplex *[], int, const cuDoubleComplex *[], int, const cuDoubleComplex *, cuDoubleComplex *[], int, int); +cublasStatus_t CUBLASWINAPI cublasZgemmBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *Aarray[], int lda, const cuDoubleComplex *Barray[], + int ldb, const cuDoubleComplex *beta, /* host or device pointer */ + cuDoubleComplex *Carray[], int ldc, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, const cuDoubleComplex *[], int, + const cuDoubleComplex *[], int, const cuDoubleComplex *, + cuDoubleComplex *[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, + ldb, beta, Carray, ldc, batchCount); } -cublasStatus_t CUBLASWINAPI cublasSgemmStridedBatched (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - long long int strideA, /* purposely signed */ - const float *B, - int ldb, - long long int strideB, - const float *beta, /* host or device pointer */ - float *C, - int ldc, - long long int strideC, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const float *, const float *, int, long long, const float *, int, long long, const float *, float *, int, long long, int); +cublasStatus_t CUBLASWINAPI cublasSgemmStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const float *alpha, /* host or device pointer */ + const float *A, int lda, long long int strideA, /* purposely signed */ + const float *B, int ldb, long long int strideB, + const float *beta, /* host or device pointer */ + float *C, int ldc, long long int strideC, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const float *, const float *, int, long long, const float *, int, + long long, const float *, float *, int, long long, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemmStridedBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, + ldb, strideB, beta, C, ldc, strideC, batchCount); } -cublasStatus_t CUBLASWINAPI cublasDgemmStridedBatched (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - long long int strideA, /* purposely signed */ - const double *B, - int ldb, - long long int strideB, - const double *beta, /* host or device pointer */ - double *C, - int ldc, - long long int strideC, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const double *, const double *, int, long long, const double *, int, long long, const double *, double *, int, long long, int); +cublasStatus_t CUBLASWINAPI cublasDgemmStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const double *alpha, /* host or device pointer */ + const double *A, int lda, long long int strideA, /* purposely signed */ + const double *B, int ldb, long long int strideB, + const double *beta, /* host or device pointer */ + double *C, int ldc, long long int strideC, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const double *, const double *, int, long long, const double *, int, + long long, const double *, double *, int, long long, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemmStridedBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, + ldb, strideB, beta, C, ldc, strideC, batchCount); } -cublasStatus_t CUBLASWINAPI cublasCgemmStridedBatched (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - long long int strideA, /* purposely signed */ - const cuComplex *B, - int ldb, - long long int strideB, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc, - long long int strideC, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const cuComplex *, int, long long, const cuComplex *, int, long long, const cuComplex *, cuComplex *, int, long long, int); +cublasStatus_t CUBLASWINAPI cublasCgemmStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, long long int strideA, /* purposely signed */ + const cuComplex *B, int ldb, long long int strideB, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc, long long int strideC, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *, int, long long, const cuComplex *, + int, long long, const cuComplex *, cuComplex *, int, long long, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemmStridedBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, + ldb, strideB, beta, C, ldc, strideC, batchCount); } -cublasStatus_t CUBLASWINAPI cublasCgemm3mStridedBatched (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - long long int strideA, /* purposely signed */ - const cuComplex *B, - int ldb, - long long int strideB, - const cuComplex *beta, /* host or device pointer */ - cuComplex *C, - int ldc, - long long int strideC, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuComplex *, const cuComplex *, int, long long, const cuComplex *, int, long long, const cuComplex *, cuComplex *, int, long long, int); +cublasStatus_t CUBLASWINAPI cublasCgemm3mStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, long long int strideA, /* purposely signed */ + const cuComplex *B, int ldb, long long int strideB, + const cuComplex *beta, /* host or device pointer */ + cuComplex *C, int ldc, long long int strideC, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuComplex *, const cuComplex *, int, long long, const cuComplex *, + int, long long, const cuComplex *, cuComplex *, int, long long, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm3mStridedBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, + ldb, strideB, beta, C, ldc, strideC, batchCount); } -cublasStatus_t CUBLASWINAPI cublasZgemmStridedBatched (cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - long long int strideA, /* purposely signed */ - const cuDoubleComplex *B, - int ldb, - long long int strideB, - const cuDoubleComplex *beta, /* host or device poi */ - cuDoubleComplex *C, - int ldc, - long long int strideC, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, long long, const cuDoubleComplex *, int, long long, const cuDoubleComplex *, cuDoubleComplex *, int, long long, int); +cublasStatus_t CUBLASWINAPI cublasZgemmStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, + const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, + long long int strideA, /* purposely signed */ + const cuDoubleComplex *B, int ldb, long long int strideB, + const cuDoubleComplex *beta, /* host or device poi */ + cuDoubleComplex *C, int ldc, long long int strideC, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, long long, + const cuDoubleComplex *, int, long long, const cuDoubleComplex *, + cuDoubleComplex *, int, long long, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemmStridedBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); + return func_ptr(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, + ldb, strideB, beta, C, ldc, strideC, batchCount); } -cublasStatus_t CUBLASWINAPI cublasSgeam(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - const float *alpha, /* host or device pointer */ - const float *A, - int lda, - const float *beta , /* host or device pointer */ - const float *B, - int ldb, - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, const float *, const float *, int, const float *, const float *, int, float *, int); +cublasStatus_t CUBLASWINAPI cublasSgeam( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const float *alpha, /* host or device pointer */ + const float *A, int lda, const float *beta, /* host or device pointer */ + const float *B, int ldb, float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + const float *, const float *, int, const float *, const float *, int, + float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgeam"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasDgeam(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - const double *alpha, /* host or device pointer */ - const double *A, - int lda, - const double *beta, /* host or device pointer */ - const double *B, - int ldb, - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, const double *, const double *, int, const double *, const double *, int, double *, int); +cublasStatus_t CUBLASWINAPI cublasDgeam( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const double *alpha, /* host or device pointer */ + const double *A, int lda, const double *beta, /* host or device pointer */ + const double *B, int ldb, double *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + const double *, const double *, int, const double *, const double *, int, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgeam"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasCgeam(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - const cuComplex *alpha, /* host or device pointer */ - const cuComplex *A, - int lda, - const cuComplex *beta, /* host or device pointer */ - const cuComplex *B, - int ldb, - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, const cuComplex *, const cuComplex *, int, const cuComplex *, const cuComplex *, int, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCgeam( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const cuComplex *alpha, /* host or device pointer */ + const cuComplex *A, int lda, + const cuComplex *beta, /* host or device pointer */ + const cuComplex *B, int ldb, cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + const cuComplex *, const cuComplex *, int, const cuComplex *, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgeam"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasZgeam(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - const cuDoubleComplex *alpha, /* host or device pointer */ - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *beta, /* host or device pointer */ - const cuDoubleComplex *B, - int ldb, - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *, int, const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZgeam( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const cuDoubleComplex *alpha, /* host or device pointer */ + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *beta, /* host or device pointer */ + const cuDoubleComplex *B, int ldb, cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, + const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex *, + int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgeam"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); + return func_ptr(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, + ldc); } -cublasStatus_t CUBLASWINAPI cublasSgetrfBatched(cublasHandle_t handle, - int n, - float *A[], /*Device pointer*/ - int lda, - int *P, /*Device Pointer*/ - int *info, /*Device Pointer*/ - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, float *[], int, int *, int *, int); +cublasStatus_t CUBLASWINAPI cublasSgetrfBatched(cublasHandle_t handle, int n, + float *A[], /*Device pointer*/ + int lda, + int *P, /*Device Pointer*/ + int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, float *[], + int, int *, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgetrfBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, A, lda, P, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasDgetrfBatched(cublasHandle_t handle, - int n, - double *A[], /*Device pointer*/ - int lda, - int *P, /*Device Pointer*/ - int *info, /*Device Pointer*/ - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, double *[], int, int *, int *, int); +cublasStatus_t CUBLASWINAPI cublasDgetrfBatched(cublasHandle_t handle, int n, + double *A[], /*Device pointer*/ + int lda, + int *P, /*Device Pointer*/ + int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, double *[], int, int *, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgetrfBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, A, lda, P, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasCgetrfBatched(cublasHandle_t handle, - int n, - cuComplex *A[], /*Device pointer*/ - int lda, - int *P, /*Device Pointer*/ - int *info, /*Device Pointer*/ - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuComplex *[], int, int *, int *, int); +cublasStatus_t CUBLASWINAPI cublasCgetrfBatched( + cublasHandle_t handle, int n, cuComplex *A[], /*Device pointer*/ + int lda, int *P, /*Device Pointer*/ + int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuComplex *[], int, int *, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgetrfBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, A, lda, P, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasZgetrfBatched(cublasHandle_t handle, - int n, - cuDoubleComplex *A[], /*Device pointer*/ - int lda, - int *P, /*Device Pointer*/ - int *info, /*Device Pointer*/ - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, cuDoubleComplex *[], int, int *, int *, int); +cublasStatus_t CUBLASWINAPI cublasZgetrfBatched( + cublasHandle_t handle, int n, cuDoubleComplex *A[], /*Device pointer*/ + int lda, int *P, /*Device Pointer*/ + int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, cuDoubleComplex *[], int, int *, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgetrfBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, A, lda, P, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasSgetriBatched(cublasHandle_t handle, - int n, - const float *A[], /*Device pointer*/ - int lda, - const int *P, /*Device pointer*/ - float *C[], /*Device pointer*/ - int ldc, - int *info, - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *[], int, const int *, float *[], int, int *, int); +cublasStatus_t CUBLASWINAPI cublasSgetriBatched( + cublasHandle_t handle, int n, const float *A[], /*Device pointer*/ + int lda, const int *P, /*Device pointer*/ + float *C[], /*Device pointer*/ + int ldc, int *info, int batchSize) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const float *[], int, + const int *, float *[], int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgetriBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, A, lda, P, C, ldc, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasDgetriBatched(cublasHandle_t handle, - int n, - const double *A[], /*Device pointer*/ - int lda, - const int *P, /*Device pointer*/ - double *C[], /*Device pointer*/ - int ldc, - int *info, - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *[], int, const int *, double *[], int, int *, int); +cublasStatus_t CUBLASWINAPI cublasDgetriBatched( + cublasHandle_t handle, int n, const double *A[], /*Device pointer*/ + int lda, const int *P, /*Device pointer*/ + double *C[], /*Device pointer*/ + int ldc, int *info, int batchSize) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const double *[], int, + const int *, double *[], int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgetriBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, A, lda, P, C, ldc, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasCgetriBatched(cublasHandle_t handle, - int n, - const cuComplex *A[], /*Device pointer*/ - int lda, - const int *P, /*Device pointer*/ - cuComplex *C[], /*Device pointer*/ - int ldc, - int *info, - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *[], int, const int *, cuComplex *[], int, int *, int); +cublasStatus_t CUBLASWINAPI cublasCgetriBatched( + cublasHandle_t handle, int n, const cuComplex *A[], /*Device pointer*/ + int lda, const int *P, /*Device pointer*/ + cuComplex *C[], /*Device pointer*/ + int ldc, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuComplex *[], int, const int *, cuComplex *[], + int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgetriBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, A, lda, P, C, ldc, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasZgetriBatched(cublasHandle_t handle, - int n, - const cuDoubleComplex *A[], /*Device pointer*/ - int lda, - const int *P, /*Device pointer*/ - cuDoubleComplex *C[], /*Device pointer*/ - int ldc, - int *info, - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *[], int, const int *, cuDoubleComplex *[], int, int *, int); +cublasStatus_t CUBLASWINAPI cublasZgetriBatched( + cublasHandle_t handle, int n, const cuDoubleComplex *A[], /*Device pointer*/ + int lda, const int *P, /*Device pointer*/ + cuDoubleComplex *C[], /*Device pointer*/ + int ldc, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *[], int, const int *, + cuDoubleComplex *[], int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgetriBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, A, lda, P, C, ldc, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasSgetrsBatched( cublasHandle_t handle, - cublasOperation_t trans, - int n, - int nrhs, - const float *Aarray[], - int lda, - const int *devIpiv, - float *Barray[], - int ldb, - int *info, - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const float *[], int, const int *, float *[], int, int *, int); +cublasStatus_t CUBLASWINAPI cublasSgetrsBatched(cublasHandle_t handle, + cublasOperation_t trans, int n, + int nrhs, const float *Aarray[], + int lda, const int *devIpiv, + float *Barray[], int ldb, + int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const float *[], int, + const int *, float *[], int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgetrsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, info, batchSize); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, + info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasDgetrsBatched( cublasHandle_t handle, - cublasOperation_t trans, - int n, - int nrhs, - const double *Aarray[], - int lda, - const int *devIpiv, - double *Barray[], - int ldb, - int *info, - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const double *[], int, const int *, double *[], int, int *, int); +cublasStatus_t CUBLASWINAPI cublasDgetrsBatched( + cublasHandle_t handle, cublasOperation_t trans, int n, int nrhs, + const double *Aarray[], int lda, const int *devIpiv, double *Barray[], + int ldb, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const double *[], int, + const int *, double *[], int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgetrsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, info, batchSize); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, + info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasCgetrsBatched( cublasHandle_t handle, - cublasOperation_t trans, - int n, - int nrhs, - const cuComplex *Aarray[], - int lda, - const int *devIpiv, - cuComplex *Barray[], - int ldb, - int *info, - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const cuComplex *[], int, const int *, cuComplex *[], int, int *, int); +cublasStatus_t CUBLASWINAPI cublasCgetrsBatched( + cublasHandle_t handle, cublasOperation_t trans, int n, int nrhs, + const cuComplex *Aarray[], int lda, const int *devIpiv, cuComplex *Barray[], + int ldb, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const cuComplex *[], int, + const int *, cuComplex *[], int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgetrsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, info, batchSize); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, + info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasZgetrsBatched( cublasHandle_t handle, - cublasOperation_t trans, - int n, - int nrhs, - const cuDoubleComplex *Aarray[], - int lda, - const int *devIpiv, - cuDoubleComplex *Barray[], - int ldb, - int *info, - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, const cuDoubleComplex *[], int, const int *, cuDoubleComplex *[], int, int *, int); +cublasStatus_t CUBLASWINAPI cublasZgetrsBatched( + cublasHandle_t handle, cublasOperation_t trans, int n, int nrhs, + const cuDoubleComplex *Aarray[], int lda, const int *devIpiv, + cuDoubleComplex *Barray[], int ldb, int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, const cuDoubleComplex *[], + int, const int *, cuDoubleComplex *[], int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgetrsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, info, batchSize); + return func_ptr(handle, trans, n, nrhs, Aarray, lda, devIpiv, Barray, ldb, + info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasStrsmBatched( cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const float *alpha, /*Host or Device Pointer*/ - const float *A[], - int lda, - float *B[], - int ldb, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const float *, const float *[], int, float *[], int, int); +cublasStatus_t CUBLASWINAPI cublasStrsmBatched( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const float *alpha, /*Host or Device Pointer*/ + const float *A[], int lda, float *B[], int ldb, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const float *, const float *[], int, + float *[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrsmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + batchCount); } -cublasStatus_t CUBLASWINAPI cublasDtrsmBatched( cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const double *alpha, /*Host or Device Pointer*/ - const double *A[], - int lda, - double *B[], - int ldb, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const double *, const double *[], int, double *[], int, int); +cublasStatus_t CUBLASWINAPI cublasDtrsmBatched( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const double *alpha, /*Host or Device Pointer*/ + const double *A[], int lda, double *B[], int ldb, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const double *, const double *[], int, + double *[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrsmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + batchCount); } -cublasStatus_t CUBLASWINAPI cublasCtrsmBatched( cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const cuComplex *alpha, /*Host or Device Pointer*/ - const cuComplex *A[], - int lda, - cuComplex *B[], - int ldb, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuComplex *, const cuComplex *[], int, cuComplex *[], int, int); +cublasStatus_t CUBLASWINAPI cublasCtrsmBatched( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuComplex *alpha, /*Host or Device Pointer*/ + const cuComplex *A[], int lda, cuComplex *B[], int ldb, int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuComplex *, const cuComplex *[], int, + cuComplex *[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrsmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + batchCount); } -cublasStatus_t CUBLASWINAPI cublasZtrsmBatched( cublasHandle_t handle, - cublasSideMode_t side, - cublasFillMode_t uplo, - cublasOperation_t trans, - cublasDiagType_t diag, - int m, - int n, - const cuDoubleComplex *alpha, /*Host or Device Pointer*/ - const cuDoubleComplex *A[], - int lda, - cuDoubleComplex *B[], - int ldb, - int batchCount) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, cublasDiagType_t, int, int, const cuDoubleComplex *, const cuDoubleComplex *[], int, cuDoubleComplex *[], int, int); +cublasStatus_t CUBLASWINAPI cublasZtrsmBatched( + cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, cublasDiagType_t diag, int m, int n, + const cuDoubleComplex *alpha, /*Host or Device Pointer*/ + const cuDoubleComplex *A[], int lda, cuDoubleComplex *B[], int ldb, + int batchCount) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + cublasDiagType_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *[], int, cuDoubleComplex *[], int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrsmBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, batchCount); + return func_ptr(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb, + batchCount); } -cublasStatus_t CUBLASWINAPI cublasSmatinvBatched(cublasHandle_t handle, - int n, - const float *A[], /*Device pointer*/ - int lda, - float *Ainv[], /*Device pointer*/ - int lda_inv, - int *info, /*Device Pointer*/ - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const float *[], int, float *[], int, int *, int); +cublasStatus_t CUBLASWINAPI cublasSmatinvBatched( + cublasHandle_t handle, int n, const float *A[], /*Device pointer*/ + int lda, float *Ainv[], /*Device pointer*/ + int lda_inv, int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const float *[], int, float *[], int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSmatinvBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, A, lda, Ainv, lda_inv, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasDmatinvBatched(cublasHandle_t handle, - int n, - const double *A[], /*Device pointer*/ - int lda, - double *Ainv[], /*Device pointer*/ - int lda_inv, - int *info, /*Device Pointer*/ - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const double *[], int, double *[], int, int *, int); +cublasStatus_t CUBLASWINAPI cublasDmatinvBatched( + cublasHandle_t handle, int n, const double *A[], /*Device pointer*/ + int lda, double *Ainv[], /*Device pointer*/ + int lda_inv, int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const double *[], int, double *[], int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDmatinvBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, A, lda, Ainv, lda_inv, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasCmatinvBatched(cublasHandle_t handle, - int n, - const cuComplex *A[], /*Device pointer*/ - int lda, - cuComplex *Ainv[], /*Device pointer*/ - int lda_inv, - int *info, /*Device Pointer*/ - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *[], int, cuComplex *[], int, int *, int); +cublasStatus_t CUBLASWINAPI cublasCmatinvBatched( + cublasHandle_t handle, int n, const cuComplex *A[], /*Device pointer*/ + int lda, cuComplex *Ainv[], /*Device pointer*/ + int lda_inv, int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, int, const cuComplex *[], + int, cuComplex *[], int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCmatinvBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, A, lda, Ainv, lda_inv, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasZmatinvBatched(cublasHandle_t handle, - int n, - const cuDoubleComplex *A[], /*Device pointer*/ - int lda, - cuDoubleComplex *Ainv[], /*Device pointer*/ - int lda_inv, - int *info, /*Device Pointer*/ - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, const cuDoubleComplex *[], int, cuDoubleComplex *[], int, int *, int); +cublasStatus_t CUBLASWINAPI cublasZmatinvBatched( + cublasHandle_t handle, int n, const cuDoubleComplex *A[], /*Device pointer*/ + int lda, cuDoubleComplex *Ainv[], /*Device pointer*/ + int lda_inv, int *info, /*Device Pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, const cuDoubleComplex *[], int, cuDoubleComplex *[], + int, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZmatinvBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, n, A, lda, Ainv, lda_inv, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasSgeqrfBatched( cublasHandle_t handle, - int m, - int n, - float *Aarray[], /*Device pointer*/ - int lda, - float *TauArray[], /* Device pointer*/ - int *info, - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, float *[], int, float *[], int *, int); +cublasStatus_t CUBLASWINAPI cublasSgeqrfBatched( + cublasHandle_t handle, int m, int n, float *Aarray[], /*Device pointer*/ + int lda, float *TauArray[], /* Device pointer*/ + int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, float *[], int, float *[], int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgeqrfBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, Aarray, lda, TauArray, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasDgeqrfBatched( cublasHandle_t handle, - int m, - int n, - double *Aarray[], /*Device pointer*/ - int lda, - double *TauArray[], /* Device pointer*/ - int *info, - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, double *[], int, double *[], int *, int); +cublasStatus_t CUBLASWINAPI cublasDgeqrfBatched( + cublasHandle_t handle, int m, int n, double *Aarray[], /*Device pointer*/ + int lda, double *TauArray[], /* Device pointer*/ + int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, double *[], int, double *[], int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgeqrfBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, Aarray, lda, TauArray, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasCgeqrfBatched( cublasHandle_t handle, - int m, - int n, - cuComplex *Aarray[], /*Device pointer*/ - int lda, - cuComplex *TauArray[], /* Device pointer*/ - int *info, - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, cuComplex *[], int, cuComplex *[], int *, int); +cublasStatus_t CUBLASWINAPI cublasCgeqrfBatched( + cublasHandle_t handle, int m, int n, cuComplex *Aarray[], /*Device pointer*/ + int lda, cuComplex *TauArray[], /* Device pointer*/ + int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, cuComplex *[], int, cuComplex *[], int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgeqrfBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, Aarray, lda, TauArray, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasZgeqrfBatched( cublasHandle_t handle, - int m, - int n, - cuDoubleComplex *Aarray[], /*Device pointer*/ - int lda, - cuDoubleComplex *TauArray[], /* Device pointer*/ - int *info, - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, int, int, cuDoubleComplex *[], int, cuDoubleComplex *[], int *, int); +cublasStatus_t CUBLASWINAPI +cublasZgeqrfBatched(cublasHandle_t handle, int m, int n, + cuDoubleComplex *Aarray[], /*Device pointer*/ + int lda, cuDoubleComplex *TauArray[], /* Device pointer*/ + int *info, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, int, int, cuDoubleComplex *[], int, cuDoubleComplex *[], + int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgeqrfBatched"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, m, n, Aarray, lda, TauArray, info, batchSize); } -cublasStatus_t CUBLASWINAPI cublasSgelsBatched( cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - int nrhs, - float *Aarray[], /*Device pointer*/ - int lda, - float *Carray[], /* Device pointer*/ - int ldc, - int *info, - int *devInfoArray, /* Device pointer*/ - int batchSize ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, float *[], int, float *[], int, int *, int *, int); +cublasStatus_t CUBLASWINAPI +cublasSgelsBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int nrhs, float *Aarray[], /*Device pointer*/ + int lda, float *Carray[], /* Device pointer*/ + int ldc, int *info, int *devInfoArray, /* Device pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, float *[], int, + float *[], int, int *, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgelsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, devInfoArray, batchSize); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, + devInfoArray, batchSize); } -cublasStatus_t CUBLASWINAPI cublasDgelsBatched( cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - int nrhs, - double *Aarray[], /*Device pointer*/ - int lda, - double *Carray[], /* Device pointer*/ - int ldc, - int *info, - int *devInfoArray, /* Device pointer*/ - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, double *[], int, double *[], int, int *, int *, int); +cublasStatus_t CUBLASWINAPI +cublasDgelsBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int nrhs, double *Aarray[], /*Device pointer*/ + int lda, double *Carray[], /* Device pointer*/ + int ldc, int *info, int *devInfoArray, /* Device pointer*/ + int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, double *[], int, + double *[], int, int *, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgelsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, devInfoArray, batchSize); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, + devInfoArray, batchSize); } -cublasStatus_t CUBLASWINAPI cublasCgelsBatched( cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - int nrhs, - cuComplex *Aarray[], /*Device pointer*/ - int lda, - cuComplex *Carray[], /* Device pointer*/ - int ldc, - int *info, - int *devInfoArray, - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, cuComplex *[], int, cuComplex *[], int, int *, int *, int); +cublasStatus_t CUBLASWINAPI +cublasCgelsBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int nrhs, cuComplex *Aarray[], /*Device pointer*/ + int lda, cuComplex *Carray[], /* Device pointer*/ + int ldc, int *info, int *devInfoArray, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, cuComplex *[], int, + cuComplex *[], int, int *, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgelsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, devInfoArray, batchSize); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, + devInfoArray, batchSize); } -cublasStatus_t CUBLASWINAPI cublasZgelsBatched( cublasHandle_t handle, - cublasOperation_t trans, - int m, - int n, - int nrhs, - cuDoubleComplex *Aarray[], /*Device pointer*/ - int lda, - cuDoubleComplex *Carray[], /* Device pointer*/ - int ldc, - int *info, - int *devInfoArray, - int batchSize) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasOperation_t, int, int, int, cuDoubleComplex *[], int, cuDoubleComplex *[], int, int *, int *, int); +cublasStatus_t CUBLASWINAPI +cublasZgelsBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, + int nrhs, cuDoubleComplex *Aarray[], /*Device pointer*/ + int lda, cuDoubleComplex *Carray[], /* Device pointer*/ + int ldc, int *info, int *devInfoArray, int batchSize) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasOperation_t, int, int, int, cuDoubleComplex *[], + int, cuDoubleComplex *[], int, int *, int *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgelsBatched"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, devInfoArray, batchSize); + return func_ptr(handle, trans, m, n, nrhs, Aarray, lda, Carray, ldc, info, + devInfoArray, batchSize); } cublasStatus_t CUBLASWINAPI cublasSdgmm(cublasHandle_t handle, - cublasSideMode_t mode, - int m, - int n, - const float *A, - int lda, - const float *x, - int incx, - float *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, int, int, const float *, int, const float *, int, float *, int); + cublasSideMode_t mode, int m, int n, + const float *A, int lda, const float *x, + int incx, float *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, int, int, const float *, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSdgmm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); } cublasStatus_t CUBLASWINAPI cublasDdgmm(cublasHandle_t handle, - cublasSideMode_t mode, - int m, - int n, - const double *A, - int lda, - const double *x, - int incx, - double *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, int, int, const double *, int, const double *, int, double *, int); + cublasSideMode_t mode, int m, int n, + const double *A, int lda, + const double *x, int incx, double *C, + int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, int, int, const double *, int, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDdgmm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); } cublasStatus_t CUBLASWINAPI cublasCdgmm(cublasHandle_t handle, - cublasSideMode_t mode, - int m, - int n, - const cuComplex *A, - int lda, - const cuComplex *x, - int incx, - cuComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, int, int, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); + cublasSideMode_t mode, int m, int n, + const cuComplex *A, int lda, + const cuComplex *x, int incx, + cuComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, int, int, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCdgmm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); } cublasStatus_t CUBLASWINAPI cublasZdgmm(cublasHandle_t handle, - cublasSideMode_t mode, - int m, - int n, - const cuDoubleComplex *A, - int lda, - const cuDoubleComplex *x, - int incx, - cuDoubleComplex *C, - int ldc) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasSideMode_t, int, int, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + cublasSideMode_t mode, int m, int n, + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *x, int incx, + cuDoubleComplex *C, int ldc) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasSideMode_t, int, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdgmm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode, m, n, A, lda, x, incx, C, ldc); } -cublasStatus_t CUBLASWINAPI cublasStpttr ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *AP, - float *A, - int lda ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, float *, int); +cublasStatus_t CUBLASWINAPI cublasStpttr(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const float *AP, float *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStpttr"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, AP, A, lda); } -cublasStatus_t CUBLASWINAPI cublasDtpttr ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *AP, - double *A, - int lda ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, double *, int); +cublasStatus_t CUBLASWINAPI cublasDtpttr(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const double *AP, double *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtpttr"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, AP, A, lda); } -cublasStatus_t CUBLASWINAPI cublasCtpttr ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuComplex *AP, - cuComplex *A, - int lda ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, cuComplex *, int); +cublasStatus_t CUBLASWINAPI cublasCtpttr(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const cuComplex *AP, cuComplex *A, + int lda) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, + const cuComplex *, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtpttr"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, AP, A, lda); } -cublasStatus_t CUBLASWINAPI cublasZtpttr ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *AP, - cuDoubleComplex *A, - int lda ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, cuDoubleComplex *, int); +cublasStatus_t CUBLASWINAPI cublasZtpttr(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const cuDoubleComplex *AP, + cuDoubleComplex *A, int lda) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtpttr"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, AP, A, lda); } -cublasStatus_t CUBLASWINAPI cublasStrttp ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const float *A, - int lda, - float *AP ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const float *, int, float *); +cublasStatus_t CUBLASWINAPI cublasStrttp(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const float *A, int lda, float *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const float *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrttp"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, A, lda, AP); } -cublasStatus_t CUBLASWINAPI cublasDtrttp ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const double *A, - int lda, - double *AP ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const double *, int, double *); +cublasStatus_t CUBLASWINAPI cublasDtrttp(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const double *A, int lda, double *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const double *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrttp"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, A, lda, AP); } -cublasStatus_t CUBLASWINAPI cublasCtrttp ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuComplex *A, - int lda, - cuComplex *AP ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuComplex *, int, cuComplex *); +cublasStatus_t CUBLASWINAPI cublasCtrttp(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const cuComplex *A, int lda, + cuComplex *AP) { + using FuncPtr = + cublasStatus_t(CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, + const cuComplex *, int, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrttp"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, A, lda, AP); } -cublasStatus_t CUBLASWINAPI cublasZtrttp ( cublasHandle_t handle, - cublasFillMode_t uplo, - int n, - const cuDoubleComplex *A, - int lda, - cuDoubleComplex *AP ) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, int, cuDoubleComplex *); +cublasStatus_t CUBLASWINAPI cublasZtrttp(cublasHandle_t handle, + cublasFillMode_t uplo, int n, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *AP) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)( + cublasHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, int, + cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrttp"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, uplo, n, A, lda, AP); } -cublasStatus CUBLASWINAPI cublasInit (void) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(); +cublasStatus CUBLASWINAPI cublasInit(void) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cublasInit"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(); } -cublasStatus CUBLASWINAPI cublasShutdown (void) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(); +cublasStatus CUBLASWINAPI cublasShutdown(void) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cublasShutdown"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(); } -cublasStatus CUBLASWINAPI cublasGetError (void) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(); +cublasStatus CUBLASWINAPI cublasGetError(void) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetError"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(); } cublasStatus CUBLASWINAPI cublasGetVersion(int *version) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int *); + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasGetVersion"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(version); } -cublasStatus CUBLASWINAPI cublasAlloc (int n, int elemSize, void **devicePtr) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(int, int, void **); +cublasStatus CUBLASWINAPI cublasAlloc(int n, int elemSize, void **devicePtr) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(int, int, void **); static auto func_ptr = LoadSymbol<FuncPtr>("cublasAlloc"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(n, elemSize, devicePtr); } -cublasStatus CUBLASWINAPI cublasFree (void *devicePtr) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(void *); +cublasStatus CUBLASWINAPI cublasFree(void *devicePtr) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(void *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasFree"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(devicePtr); } -cublasStatus CUBLASWINAPI cublasSetKernelStream (cudaStream_t stream) { - using FuncPtr = cublasStatus_t (CUBLASWINAPI *)(cudaStream_t); +cublasStatus CUBLASWINAPI cublasSetKernelStream(cudaStream_t stream) { + using FuncPtr = cublasStatus_t(CUBLASWINAPI *)(cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSetKernelStream"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stream); } -float CUBLASWINAPI cublasSnrm2 (int n, const float *x, int incx) { - using FuncPtr = float (CUBLASWINAPI *)(int, const float *, int); +float CUBLASWINAPI cublasSnrm2(int n, const float *x, int incx) { + using FuncPtr = float(CUBLASWINAPI *)(int, const float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSnrm2"); if (!func_ptr) LogFatalSymbolNotFound("cublasSnrm2"); return func_ptr(n, x, incx); } -double CUBLASWINAPI cublasDnrm2 (int n, const double *x, int incx) { - using FuncPtr = double (CUBLASWINAPI *)(int, const double *, int); +double CUBLASWINAPI cublasDnrm2(int n, const double *x, int incx) { + using FuncPtr = double(CUBLASWINAPI *)(int, const double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDnrm2"); if (!func_ptr) LogFatalSymbolNotFound("cublasDnrm2"); return func_ptr(n, x, incx); } -float CUBLASWINAPI cublasScnrm2 (int n, const cuComplex *x, int incx) { - using FuncPtr = float (CUBLASWINAPI *)(int, const cuComplex *, int); +float CUBLASWINAPI cublasScnrm2(int n, const cuComplex *x, int incx) { + using FuncPtr = float(CUBLASWINAPI *)(int, const cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasScnrm2"); if (!func_ptr) LogFatalSymbolNotFound("cublasScnrm2"); return func_ptr(n, x, incx); } -double CUBLASWINAPI cublasDznrm2 (int n, const cuDoubleComplex *x, int incx) { - using FuncPtr = double (CUBLASWINAPI *)(int, const cuDoubleComplex *, int); +double CUBLASWINAPI cublasDznrm2(int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = double(CUBLASWINAPI *)(int, const cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDznrm2"); if (!func_ptr) LogFatalSymbolNotFound("cublasDznrm2"); return func_ptr(n, x, incx); } -float CUBLASWINAPI cublasSdot (int n, const float *x, int incx, const float *y, - int incy) { - using FuncPtr = float (CUBLASWINAPI *)(int, const float *, int, const float *, int); +float CUBLASWINAPI cublasSdot(int n, const float *x, int incx, const float *y, + int incy) { + using FuncPtr = + float(CUBLASWINAPI *)(int, const float *, int, const float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSdot"); if (!func_ptr) LogFatalSymbolNotFound("cublasSdot"); return func_ptr(n, x, incx, y, incy); } -double CUBLASWINAPI cublasDdot (int n, const double *x, int incx, const double *y, - int incy) { - using FuncPtr = double (CUBLASWINAPI *)(int, const double *, int, const double *, int); +double CUBLASWINAPI cublasDdot(int n, const double *x, int incx, + const double *y, int incy) { + using FuncPtr = + double(CUBLASWINAPI *)(int, const double *, int, const double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDdot"); if (!func_ptr) LogFatalSymbolNotFound("cublasDdot"); return func_ptr(n, x, incx, y, incy); } -cuComplex CUBLASWINAPI cublasCdotu (int n, const cuComplex *x, int incx, const cuComplex *y, - int incy) { - using FuncPtr = cuComplex (CUBLASWINAPI *)(int, const cuComplex *, int, const cuComplex *, int); +cuComplex CUBLASWINAPI cublasCdotu(int n, const cuComplex *x, int incx, + const cuComplex *y, int incy) { + using FuncPtr = cuComplex(CUBLASWINAPI *)(int, const cuComplex *, int, + const cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCdotu"); if (!func_ptr) LogFatalSymbolNotFound("cublasCdotu"); return func_ptr(n, x, incx, y, incy); } -cuComplex CUBLASWINAPI cublasCdotc (int n, const cuComplex *x, int incx, const cuComplex *y, - int incy) { - using FuncPtr = cuComplex (CUBLASWINAPI *)(int, const cuComplex *, int, const cuComplex *, int); +cuComplex CUBLASWINAPI cublasCdotc(int n, const cuComplex *x, int incx, + const cuComplex *y, int incy) { + using FuncPtr = cuComplex(CUBLASWINAPI *)(int, const cuComplex *, int, + const cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCdotc"); if (!func_ptr) LogFatalSymbolNotFound("cublasCdotc"); return func_ptr(n, x, incx, y, incy); } -cuDoubleComplex CUBLASWINAPI cublasZdotu (int n, const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, - int incy) { - using FuncPtr = cuDoubleComplex (CUBLASWINAPI *)(int, const cuDoubleComplex *, int, const cuDoubleComplex *, int); +cuDoubleComplex CUBLASWINAPI cublasZdotu(int n, const cuDoubleComplex *x, + int incx, const cuDoubleComplex *y, + int incy) { + using FuncPtr = cuDoubleComplex(CUBLASWINAPI *)( + int, const cuDoubleComplex *, int, const cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdotu"); if (!func_ptr) LogFatalSymbolNotFound("cublasZdotu"); return func_ptr(n, x, incx, y, incy); } -cuDoubleComplex CUBLASWINAPI cublasZdotc (int n, const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, - int incy) { - using FuncPtr = cuDoubleComplex (CUBLASWINAPI *)(int, const cuDoubleComplex *, int, const cuDoubleComplex *, int); +cuDoubleComplex CUBLASWINAPI cublasZdotc(int n, const cuDoubleComplex *x, + int incx, const cuDoubleComplex *y, + int incy) { + using FuncPtr = cuDoubleComplex(CUBLASWINAPI *)( + int, const cuDoubleComplex *, int, const cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdotc"); if (!func_ptr) LogFatalSymbolNotFound("cublasZdotc"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasSscal (int n, float alpha, float *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(int, float, float *, int); +void CUBLASWINAPI cublasSscal(int n, float alpha, float *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSscal"); if (!func_ptr) LogFatalSymbolNotFound("cublasSscal"); return func_ptr(n, alpha, x, incx); } -void CUBLASWINAPI cublasDscal (int n, double alpha, double *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(int, double, double *, int); +void CUBLASWINAPI cublasDscal(int n, double alpha, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDscal"); if (!func_ptr) LogFatalSymbolNotFound("cublasDscal"); return func_ptr(n, alpha, x, incx); } -void CUBLASWINAPI cublasCscal (int n, cuComplex alpha, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasCscal(int n, cuComplex alpha, cuComplex *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCscal"); if (!func_ptr) LogFatalSymbolNotFound("cublasCscal"); return func_ptr(n, alpha, x, incx); } -void CUBLASWINAPI cublasZscal (int n, cuDoubleComplex alpha, cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZscal(int n, cuDoubleComplex alpha, cuDoubleComplex *x, + int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZscal"); if (!func_ptr) LogFatalSymbolNotFound("cublasZscal"); return func_ptr(n, alpha, x, incx); } -void CUBLASWINAPI cublasCsscal (int n, float alpha, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(int, float, cuComplex *, int); +void CUBLASWINAPI cublasCsscal(int n, float alpha, cuComplex *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(int, float, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsscal"); if (!func_ptr) LogFatalSymbolNotFound("cublasCsscal"); return func_ptr(n, alpha, x, incx); } -void CUBLASWINAPI cublasZdscal (int n, double alpha, cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(int, double, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZdscal(int n, double alpha, cuDoubleComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)(int, double, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdscal"); if (!func_ptr) LogFatalSymbolNotFound("cublasZdscal"); return func_ptr(n, alpha, x, incx); } -void CUBLASWINAPI cublasSaxpy (int n, float alpha, const float *x, int incx, - float *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, float, const float *, int, float *, int); +void CUBLASWINAPI cublasSaxpy(int n, float alpha, const float *x, int incx, + float *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(int, float, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSaxpy"); if (!func_ptr) LogFatalSymbolNotFound("cublasSaxpy"); return func_ptr(n, alpha, x, incx, y, incy); } -void CUBLASWINAPI cublasDaxpy (int n, double alpha, const double *x, - int incx, double *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, double, const double *, int, double *, int); +void CUBLASWINAPI cublasDaxpy(int n, double alpha, const double *x, int incx, + double *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(int, double, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDaxpy"); if (!func_ptr) LogFatalSymbolNotFound("cublasDaxpy"); return func_ptr(n, alpha, x, incx, y, incy); } -void CUBLASWINAPI cublasCaxpy (int n, cuComplex alpha, const cuComplex *x, - int incx, cuComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuComplex, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCaxpy(int n, cuComplex alpha, const cuComplex *x, + int incx, cuComplex *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuComplex, const cuComplex *, int, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCaxpy"); if (!func_ptr) LogFatalSymbolNotFound("cublasCaxpy"); return func_ptr(n, alpha, x, incx, y, incy); } -void CUBLASWINAPI cublasZaxpy (int n, cuDoubleComplex alpha, const cuDoubleComplex *x, - int incx, cuDoubleComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuDoubleComplex, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZaxpy(int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(int, cuDoubleComplex, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZaxpy"); if (!func_ptr) LogFatalSymbolNotFound("cublasZaxpy"); return func_ptr(n, alpha, x, incx, y, incy); } -void CUBLASWINAPI cublasScopy (int n, const float *x, int incx, float *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, const float *, int, float *, int); +void CUBLASWINAPI cublasScopy(int n, const float *x, int incx, float *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasScopy"); if (!func_ptr) LogFatalSymbolNotFound("cublasScopy"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasDcopy (int n, const double *x, int incx, double *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, const double *, int, double *, int); +void CUBLASWINAPI cublasDcopy(int n, const double *x, int incx, double *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDcopy"); if (!func_ptr) LogFatalSymbolNotFound("cublasDcopy"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasCcopy (int n, const cuComplex *x, int incx, cuComplex *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCcopy(int n, const cuComplex *x, int incx, cuComplex *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCcopy"); if (!func_ptr) LogFatalSymbolNotFound("cublasCcopy"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasZcopy (int n, const cuDoubleComplex *x, int incx, cuDoubleComplex *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZcopy(int n, const cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZcopy"); if (!func_ptr) LogFatalSymbolNotFound("cublasZcopy"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasSswap (int n, float *x, int incx, float *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, float *, int, float *, int); +void CUBLASWINAPI cublasSswap(int n, float *x, int incx, float *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSswap"); if (!func_ptr) LogFatalSymbolNotFound("cublasSswap"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasDswap (int n, double *x, int incx, double *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, double *, int, double *, int); +void CUBLASWINAPI cublasDswap(int n, double *x, int incx, double *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDswap"); if (!func_ptr) LogFatalSymbolNotFound("cublasDswap"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasCswap (int n, cuComplex *x, int incx, cuComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCswap(int n, cuComplex *x, int incx, cuComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCswap"); if (!func_ptr) LogFatalSymbolNotFound("cublasCswap"); return func_ptr(n, x, incx, y, incy); } -void CUBLASWINAPI cublasZswap (int n, cuDoubleComplex *x, int incx, cuDoubleComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZswap(int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(int, cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZswap"); if (!func_ptr) LogFatalSymbolNotFound("cublasZswap"); return func_ptr(n, x, incx, y, incy); } -int CUBLASWINAPI cublasIsamax (int n, const float *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const float *, int); +int CUBLASWINAPI cublasIsamax(int n, const float *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIsamax"); if (!func_ptr) LogFatalSymbolNotFound("cublasIsamax"); return func_ptr(n, x, incx); } -int CUBLASWINAPI cublasIdamax (int n, const double *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const double *, int); +int CUBLASWINAPI cublasIdamax(int n, const double *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIdamax"); if (!func_ptr) LogFatalSymbolNotFound("cublasIdamax"); return func_ptr(n, x, incx); } -int CUBLASWINAPI cublasIcamax (int n, const cuComplex *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const cuComplex *, int); +int CUBLASWINAPI cublasIcamax(int n, const cuComplex *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIcamax"); if (!func_ptr) LogFatalSymbolNotFound("cublasIcamax"); return func_ptr(n, x, incx); } -int CUBLASWINAPI cublasIzamax (int n, const cuDoubleComplex *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const cuDoubleComplex *, int); +int CUBLASWINAPI cublasIzamax(int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIzamax"); if (!func_ptr) LogFatalSymbolNotFound("cublasIzamax"); return func_ptr(n, x, incx); } -int CUBLASWINAPI cublasIsamin (int n, const float *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const float *, int); +int CUBLASWINAPI cublasIsamin(int n, const float *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIsamin"); if (!func_ptr) LogFatalSymbolNotFound("cublasIsamin"); return func_ptr(n, x, incx); } -int CUBLASWINAPI cublasIdamin (int n, const double *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const double *, int); +int CUBLASWINAPI cublasIdamin(int n, const double *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIdamin"); if (!func_ptr) LogFatalSymbolNotFound("cublasIdamin"); return func_ptr(n, x, incx); } -int CUBLASWINAPI cublasIcamin (int n, const cuComplex *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const cuComplex *, int); +int CUBLASWINAPI cublasIcamin(int n, const cuComplex *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIcamin"); if (!func_ptr) LogFatalSymbolNotFound("cublasIcamin"); return func_ptr(n, x, incx); } -int CUBLASWINAPI cublasIzamin (int n, const cuDoubleComplex *x, int incx) { - using FuncPtr = int (CUBLASWINAPI *)(int, const cuDoubleComplex *, int); +int CUBLASWINAPI cublasIzamin(int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = int(CUBLASWINAPI *)(int, const cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasIzamin"); if (!func_ptr) LogFatalSymbolNotFound("cublasIzamin"); return func_ptr(n, x, incx); } -float CUBLASWINAPI cublasSasum (int n, const float *x, int incx) { - using FuncPtr = float (CUBLASWINAPI *)(int, const float *, int); +float CUBLASWINAPI cublasSasum(int n, const float *x, int incx) { + using FuncPtr = float(CUBLASWINAPI *)(int, const float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSasum"); if (!func_ptr) LogFatalSymbolNotFound("cublasSasum"); return func_ptr(n, x, incx); } -double CUBLASWINAPI cublasDasum (int n, const double *x, int incx) { - using FuncPtr = double (CUBLASWINAPI *)(int, const double *, int); +double CUBLASWINAPI cublasDasum(int n, const double *x, int incx) { + using FuncPtr = double(CUBLASWINAPI *)(int, const double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDasum"); if (!func_ptr) LogFatalSymbolNotFound("cublasDasum"); return func_ptr(n, x, incx); } -float CUBLASWINAPI cublasScasum (int n, const cuComplex *x, int incx) { - using FuncPtr = float (CUBLASWINAPI *)(int, const cuComplex *, int); +float CUBLASWINAPI cublasScasum(int n, const cuComplex *x, int incx) { + using FuncPtr = float(CUBLASWINAPI *)(int, const cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasScasum"); if (!func_ptr) LogFatalSymbolNotFound("cublasScasum"); return func_ptr(n, x, incx); } -double CUBLASWINAPI cublasDzasum (int n, const cuDoubleComplex *x, int incx) { - using FuncPtr = double (CUBLASWINAPI *)(int, const cuDoubleComplex *, int); +double CUBLASWINAPI cublasDzasum(int n, const cuDoubleComplex *x, int incx) { + using FuncPtr = double(CUBLASWINAPI *)(int, const cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDzasum"); if (!func_ptr) LogFatalSymbolNotFound("cublasDzasum"); return func_ptr(n, x, incx); } -void CUBLASWINAPI cublasSrot (int n, float *x, int incx, float *y, int incy, - float sc, float ss) { - using FuncPtr = void (CUBLASWINAPI *)(int, float *, int, float *, int, float, float); +void CUBLASWINAPI cublasSrot(int n, float *x, int incx, float *y, int incy, + float sc, float ss) { + using FuncPtr = + void(CUBLASWINAPI *)(int, float *, int, float *, int, float, float); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrot"); if (!func_ptr) LogFatalSymbolNotFound("cublasSrot"); return func_ptr(n, x, incx, y, incy, sc, ss); } -void CUBLASWINAPI cublasDrot (int n, double *x, int incx, double *y, int incy, - double sc, double ss) { - using FuncPtr = void (CUBLASWINAPI *)(int, double *, int, double *, int, double, double); +void CUBLASWINAPI cublasDrot(int n, double *x, int incx, double *y, int incy, + double sc, double ss) { + using FuncPtr = + void(CUBLASWINAPI *)(int, double *, int, double *, int, double, double); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrot"); if (!func_ptr) LogFatalSymbolNotFound("cublasDrot"); return func_ptr(n, x, incx, y, incy, sc, ss); } -void CUBLASWINAPI cublasCrot (int n, cuComplex *x, int incx, cuComplex *y, - int incy, float c, cuComplex s) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int, float, cuComplex); +void CUBLASWINAPI cublasCrot(int n, cuComplex *x, int incx, cuComplex *y, + int incy, float c, cuComplex s) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int, + float, cuComplex); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCrot"); if (!func_ptr) LogFatalSymbolNotFound("cublasCrot"); return func_ptr(n, x, incx, y, incy, c, s); } -void CUBLASWINAPI cublasZrot (int n, cuDoubleComplex *x, int incx, - cuDoubleComplex *y, int incy, double sc, - cuDoubleComplex cs) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuDoubleComplex *, int, cuDoubleComplex *, int, double, cuDoubleComplex); +void CUBLASWINAPI cublasZrot(int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy, double sc, + cuDoubleComplex cs) { + using FuncPtr = + void(CUBLASWINAPI *)(int, cuDoubleComplex *, int, cuDoubleComplex *, int, + double, cuDoubleComplex); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZrot"); if (!func_ptr) LogFatalSymbolNotFound("cublasZrot"); return func_ptr(n, x, incx, y, incy, sc, cs); } -void CUBLASWINAPI cublasCsrot (int n, cuComplex *x, int incx, cuComplex *y, - int incy, float c, float s) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int, float, float); +void CUBLASWINAPI cublasCsrot(int n, cuComplex *x, int incx, cuComplex *y, + int incy, float c, float s) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuComplex *, int, cuComplex *, int, + float, float); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsrot"); if (!func_ptr) LogFatalSymbolNotFound("cublasCsrot"); return func_ptr(n, x, incx, y, incy, c, s); } -void CUBLASWINAPI cublasZdrot (int n, cuDoubleComplex *x, int incx, - cuDoubleComplex *y, int incy, double c, double s) { - using FuncPtr = void (CUBLASWINAPI *)(int, cuDoubleComplex *, int, cuDoubleComplex *, int, double, double); +void CUBLASWINAPI cublasZdrot(int n, cuDoubleComplex *x, int incx, + cuDoubleComplex *y, int incy, double c, + double s) { + using FuncPtr = void(CUBLASWINAPI *)(int, cuDoubleComplex *, int, + cuDoubleComplex *, int, double, double); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZdrot"); if (!func_ptr) LogFatalSymbolNotFound("cublasZdrot"); return func_ptr(n, x, incx, y, incy, c, s); } -void CUBLASWINAPI cublasSrotg (float *sa, float *sb, float *sc, float *ss) { - using FuncPtr = void (CUBLASWINAPI *)(float *, float *, float *, float *); +void CUBLASWINAPI cublasSrotg(float *sa, float *sb, float *sc, float *ss) { + using FuncPtr = void(CUBLASWINAPI *)(float *, float *, float *, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotg"); if (!func_ptr) LogFatalSymbolNotFound("cublasSrotg"); return func_ptr(sa, sb, sc, ss); } -void CUBLASWINAPI cublasDrotg (double *sa, double *sb, double *sc, double *ss) { - using FuncPtr = void (CUBLASWINAPI *)(double *, double *, double *, double *); +void CUBLASWINAPI cublasDrotg(double *sa, double *sb, double *sc, double *ss) { + using FuncPtr = void(CUBLASWINAPI *)(double *, double *, double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotg"); if (!func_ptr) LogFatalSymbolNotFound("cublasDrotg"); return func_ptr(sa, sb, sc, ss); } -void CUBLASWINAPI cublasCrotg (cuComplex *ca, cuComplex cb, float *sc, - cuComplex *cs) { - using FuncPtr = void (CUBLASWINAPI *)(cuComplex *, cuComplex, float *, cuComplex *); +void CUBLASWINAPI cublasCrotg(cuComplex *ca, cuComplex cb, float *sc, + cuComplex *cs) { + using FuncPtr = + void(CUBLASWINAPI *)(cuComplex *, cuComplex, float *, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCrotg"); if (!func_ptr) LogFatalSymbolNotFound("cublasCrotg"); return func_ptr(ca, cb, sc, cs); } -void CUBLASWINAPI cublasZrotg (cuDoubleComplex *ca, cuDoubleComplex cb, double *sc, - cuDoubleComplex *cs) { - using FuncPtr = void (CUBLASWINAPI *)(cuDoubleComplex *, cuDoubleComplex, double *, cuDoubleComplex *); +void CUBLASWINAPI cublasZrotg(cuDoubleComplex *ca, cuDoubleComplex cb, + double *sc, cuDoubleComplex *cs) { + using FuncPtr = void(CUBLASWINAPI *)(cuDoubleComplex *, cuDoubleComplex, + double *, cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZrotg"); if (!func_ptr) LogFatalSymbolNotFound("cublasZrotg"); return func_ptr(ca, cb, sc, cs); } -void CUBLASWINAPI cublasSrotm(int n, float *x, int incx, float *y, int incy, - const float* sparam) { - using FuncPtr = void (CUBLASWINAPI *)(int, float *, int, float *, int, const float *); +void CUBLASWINAPI cublasSrotm(int n, float *x, int incx, float *y, int incy, + const float *sparam) { + using FuncPtr = + void(CUBLASWINAPI *)(int, float *, int, float *, int, const float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotm"); if (!func_ptr) LogFatalSymbolNotFound("cublasSrotm"); return func_ptr(n, x, incx, y, incy, sparam); } -void CUBLASWINAPI cublasDrotm(int n, double *x, int incx, double *y, int incy, - const double* sparam) { - using FuncPtr = void (CUBLASWINAPI *)(int, double *, int, double *, int, const double *); +void CUBLASWINAPI cublasDrotm(int n, double *x, int incx, double *y, int incy, + const double *sparam) { + using FuncPtr = + void(CUBLASWINAPI *)(int, double *, int, double *, int, const double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotm"); if (!func_ptr) LogFatalSymbolNotFound("cublasDrotm"); return func_ptr(n, x, incx, y, incy, sparam); } -void CUBLASWINAPI cublasSrotmg (float *sd1, float *sd2, float *sx1, - const float *sy1, float* sparam) { - using FuncPtr = void (CUBLASWINAPI *)(float *, float *, float *, const float *, float *); +void CUBLASWINAPI cublasSrotmg(float *sd1, float *sd2, float *sx1, + const float *sy1, float *sparam) { + using FuncPtr = + void(CUBLASWINAPI *)(float *, float *, float *, const float *, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSrotmg"); if (!func_ptr) LogFatalSymbolNotFound("cublasSrotmg"); return func_ptr(sd1, sd2, sx1, sy1, sparam); } -void CUBLASWINAPI cublasDrotmg (double *sd1, double *sd2, double *sx1, - const double *sy1, double* sparam) { - using FuncPtr = void (CUBLASWINAPI *)(double *, double *, double *, const double *, double *); +void CUBLASWINAPI cublasDrotmg(double *sd1, double *sd2, double *sx1, + const double *sy1, double *sparam) { + using FuncPtr = void(CUBLASWINAPI *)(double *, double *, double *, + const double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDrotmg"); if (!func_ptr) LogFatalSymbolNotFound("cublasDrotmg"); return func_ptr(sd1, sd2, sx1, sy1, sparam); } -void CUBLASWINAPI cublasSgemv (char trans, int m, int n, float alpha, - const float *A, int lda, const float *x, int incx, - float beta, float *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, float, const float *, int, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSgemv(char trans, int m, int n, float alpha, + const float *A, int lda, const float *x, int incx, + float beta, float *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, float, const float *, int, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemv"); if (!func_ptr) LogFatalSymbolNotFound("cublasSgemv"); return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasDgemv (char trans, int m, int n, double alpha, - const double *A, int lda, const double *x, int incx, - double beta, double *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, double, const double *, int, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDgemv(char trans, int m, int n, double alpha, + const double *A, int lda, const double *x, + int incx, double beta, double *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, double, const double *, int, + const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDgemv"); return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasCgemv (char trans, int m, int n, cuComplex alpha, - const cuComplex *A, int lda, const cuComplex *x, int incx, - cuComplex beta, cuComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasCgemv(char trans, int m, int n, cuComplex alpha, + const cuComplex *A, int lda, const cuComplex *x, + int incx, cuComplex beta, cuComplex *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCgemv"); return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasZgemv (char trans, int m, int n, cuDoubleComplex alpha, - const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, - cuDoubleComplex beta, cuDoubleComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZgemv(char trans, int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZgemv"); return func_ptr(trans, m, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasSgbmv (char trans, int m, int n, int kl, int ku, - float alpha, const float *A, int lda, - const float *x, int incx, float beta, float *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, int, int, float, const float *, int, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSgbmv(char trans, int m, int n, int kl, int ku, + float alpha, const float *A, int lda, + const float *x, int incx, float beta, float *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, int, int, float, const float *, int, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasSgbmv"); return func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasDgbmv (char trans, int m, int n, int kl, int ku, - double alpha, const double *A, int lda, - const double *x, int incx, double beta, double *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, int, int, double, const double *, int, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDgbmv(char trans, int m, int n, int kl, int ku, + double alpha, const double *A, int lda, + const double *x, int incx, double beta, double *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, int, int, double, const double *, + int, const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDgbmv"); return func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasCgbmv (char trans, int m, int n, int kl, int ku, - cuComplex alpha, const cuComplex *A, int lda, - const cuComplex *x, int incx, cuComplex beta, cuComplex *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasCgbmv(char trans, int m, int n, int kl, int ku, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *x, int incx, cuComplex beta, + cuComplex *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, int, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCgbmv"); return func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasZgbmv (char trans, int m, int n, int kl, int ku, - cuDoubleComplex alpha, const cuDoubleComplex *A, int lda, - const cuDoubleComplex *x, int incx, cuDoubleComplex beta, cuDoubleComplex *y, - int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZgbmv(char trans, int m, int n, int kl, int ku, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, int, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZgbmv"); return func_ptr(trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasStrmv (char uplo, char trans, char diag, int n, - const float *A, int lda, float *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const float *, int, float *, int); +void CUBLASWINAPI cublasStrmv(char uplo, char trans, char diag, int n, + const float *A, int lda, float *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const float *, + int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasStrmv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasDtrmv (char uplo, char trans, char diag, int n, - const double *A, int lda, double *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const double *, int, double *, int); +void CUBLASWINAPI cublasDtrmv(char uplo, char trans, char diag, int n, + const double *A, int lda, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const double *, + int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtrmv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasCtrmv (char uplo, char trans, char diag, int n, - const cuComplex *A, int lda, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCtrmv(char uplo, char trans, char diag, int n, + const cuComplex *A, int lda, cuComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const cuComplex *, + int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtrmv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasZtrmv (char uplo, char trans, char diag, int n, - const cuDoubleComplex *A, int lda, cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZtrmv(char uplo, char trans, char diag, int n, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtrmv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasStbmv (char uplo, char trans, char diag, int n, int k, - const float *A, int lda, float *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const float *, int, float *, int); +void CUBLASWINAPI cublasStbmv(char uplo, char trans, char diag, int n, int k, + const float *A, int lda, float *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasStbmv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasDtbmv (char uplo, char trans, char diag, int n, int k, - const double *A, int lda, double *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const double *, int, double *, int); +void CUBLASWINAPI cublasDtbmv(char uplo, char trans, char diag, int n, int k, + const double *A, int lda, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, int, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtbmv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasCtbmv (char uplo, char trans, char diag, int n, int k, - const cuComplex *A, int lda, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCtbmv(char uplo, char trans, char diag, int n, int k, + const cuComplex *A, int lda, cuComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, char, int, int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtbmv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasZtbmv (char uplo, char trans, char diag, int n, int k, - const cuDoubleComplex *A, int lda, cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZtbmv(char uplo, char trans, char diag, int n, int k, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, int, const cuDoubleComplex *, + int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtbmv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasStpmv(char uplo, char trans, char diag, int n, const float *AP, float *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const float *, float *, int); +void CUBLASWINAPI cublasStpmv(char uplo, char trans, char diag, int n, + const float *AP, float *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStpmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasStpmv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasDtpmv(char uplo, char trans, char diag, int n, const double *AP, double *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const double *, double *, int); +void CUBLASWINAPI cublasDtpmv(char uplo, char trans, char diag, int n, + const double *AP, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const double *, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtpmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtpmv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasCtpmv(char uplo, char trans, char diag, int n, const cuComplex *AP, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuComplex *, cuComplex *, int); +void CUBLASWINAPI cublasCtpmv(char uplo, char trans, char diag, int n, + const cuComplex *AP, cuComplex *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const cuComplex *, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtpmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtpmv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasZtpmv(char uplo, char trans, char diag, int n, const cuDoubleComplex *AP, cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZtpmv(char uplo, char trans, char diag, int n, + const cuDoubleComplex *AP, cuDoubleComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, char, int, const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtpmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtpmv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasStrsv(char uplo, char trans, char diag, int n, const float *A, int lda, float *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const float *, int, float *, int); +void CUBLASWINAPI cublasStrsv(char uplo, char trans, char diag, int n, + const float *A, int lda, float *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const float *, + int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasStrsv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasDtrsv(char uplo, char trans, char diag, int n, const double *A, int lda, double *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const double *, int, double *, int); +void CUBLASWINAPI cublasDtrsv(char uplo, char trans, char diag, int n, + const double *A, int lda, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const double *, + int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtrsv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasCtrsv(char uplo, char trans, char diag, int n, const cuComplex *A, int lda, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCtrsv(char uplo, char trans, char diag, int n, + const cuComplex *A, int lda, cuComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const cuComplex *, + int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtrsv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasZtrsv(char uplo, char trans, char diag, int n, const cuDoubleComplex *A, int lda, +void CUBLASWINAPI cublasZtrsv(char uplo, char trans, char diag, int n, + const cuDoubleComplex *A, int lda, cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, int, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtrsv"); return func_ptr(uplo, trans, diag, n, A, lda, x, incx); } -void CUBLASWINAPI cublasStpsv(char uplo, char trans, char diag, int n, const float *AP, - float *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const float *, float *, int); +void CUBLASWINAPI cublasStpsv(char uplo, char trans, char diag, int n, + const float *AP, float *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, const float *, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStpsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasStpsv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasDtpsv(char uplo, char trans, char diag, int n, const double *AP, double *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const double *, double *, int); +void CUBLASWINAPI cublasDtpsv(char uplo, char trans, char diag, int n, + const double *AP, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const double *, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtpsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtpsv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasCtpsv(char uplo, char trans, char diag, int n, const cuComplex *AP, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuComplex *, cuComplex *, int); +void CUBLASWINAPI cublasCtpsv(char uplo, char trans, char diag, int n, + const cuComplex *AP, cuComplex *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, const cuComplex *, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtpsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtpsv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasZtpsv(char uplo, char trans, char diag, int n, const cuDoubleComplex *AP, - cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, const cuDoubleComplex *, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZtpsv(char uplo, char trans, char diag, int n, + const cuDoubleComplex *AP, cuDoubleComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, char, int, const cuDoubleComplex *, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtpsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtpsv"); return func_ptr(uplo, trans, diag, n, AP, x, incx); } -void CUBLASWINAPI cublasStbsv(char uplo, char trans, - char diag, int n, int k, const float *A, - int lda, float *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const float *, int, float *, int); +void CUBLASWINAPI cublasStbsv(char uplo, char trans, char diag, int n, int k, + const float *A, int lda, float *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStbsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasStbsv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasDtbsv(char uplo, char trans, - char diag, int n, int k, const double *A, - int lda, double *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const double *, int, double *, int); +void CUBLASWINAPI cublasDtbsv(char uplo, char trans, char diag, int n, int k, + const double *A, int lda, double *x, int incx) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, int, int, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtbsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtbsv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasCtbsv(char uplo, char trans, - char diag, int n, int k, const cuComplex *A, - int lda, cuComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCtbsv(char uplo, char trans, char diag, int n, int k, + const cuComplex *A, int lda, cuComplex *x, + int incx) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, char, int, int, const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtbsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtbsv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasZtbsv(char uplo, char trans, - char diag, int n, int k, const cuDoubleComplex *A, - int lda, cuDoubleComplex *x, int incx) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, int, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZtbsv(char uplo, char trans, char diag, int n, int k, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *x, int incx) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, int, int, const cuDoubleComplex *, + int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtbsv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtbsv"); return func_ptr(uplo, trans, diag, n, k, A, lda, x, incx); } -void CUBLASWINAPI cublasSsymv (char uplo, int n, float alpha, const float *A, - int lda, const float *x, int incx, float beta, - float *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, int, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSsymv(char uplo, int n, float alpha, const float *A, + int lda, const float *x, int incx, float beta, + float *y, int incy) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const float *, int, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsymv"); if (!func_ptr) LogFatalSymbolNotFound("cublasSsymv"); return func_ptr(uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasDsymv (char uplo, int n, double alpha, const double *A, - int lda, const double *x, int incx, double beta, - double *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, int, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDsymv(char uplo, int n, double alpha, const double *A, + int lda, const double *x, int incx, double beta, + double *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, double, const double *, int, + const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsymv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDsymv"); return func_ptr(uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasChemv (char uplo, int n, cuComplex alpha, const cuComplex *A, - int lda, const cuComplex *x, int incx, cuComplex beta, - cuComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasChemv(char uplo, int n, cuComplex alpha, + const cuComplex *A, int lda, const cuComplex *x, + int incx, cuComplex beta, cuComplex *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChemv"); if (!func_ptr) LogFatalSymbolNotFound("cublasChemv"); return func_ptr(uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasZhemv (char uplo, int n, cuDoubleComplex alpha, const cuDoubleComplex *A, - int lda, const cuDoubleComplex *x, int incx, cuDoubleComplex beta, - cuDoubleComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZhemv(char uplo, int n, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhemv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZhemv"); return func_ptr(uplo, n, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasSsbmv (char uplo, int n, int k, float alpha, - const float *A, int lda, const float *x, int incx, - float beta, float *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, float, const float *, int, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSsbmv(char uplo, int n, int k, float alpha, + const float *A, int lda, const float *x, int incx, + float beta, float *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, float, const float *, int, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasSsbmv"); return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasDsbmv (char uplo, int n, int k, double alpha, - const double *A, int lda, const double *x, int incx, - double beta, double *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, double, const double *, int, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDsbmv(char uplo, int n, int k, double alpha, + const double *A, int lda, const double *x, + int incx, double beta, double *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, double, const double *, int, + const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDsbmv"); return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasChbmv (char uplo, int n, int k, cuComplex alpha, - const cuComplex *A, int lda, const cuComplex *x, int incx, - cuComplex beta, cuComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasChbmv(char uplo, int n, int k, cuComplex alpha, + const cuComplex *A, int lda, const cuComplex *x, + int incx, cuComplex beta, cuComplex *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasChbmv"); return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasZhbmv (char uplo, int n, int k, cuDoubleComplex alpha, - const cuDoubleComplex *A, int lda, const cuDoubleComplex *x, int incx, - cuDoubleComplex beta, cuDoubleComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZhbmv(char uplo, int n, int k, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhbmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZhbmv"); return func_ptr(uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasSspmv(char uplo, int n, float alpha, - const float *AP, const float *x, - int incx, float beta, float *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSspmv(char uplo, int n, float alpha, const float *AP, + const float *x, int incx, float beta, float *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const float *, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasSspmv"); return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasDspmv(char uplo, int n, double alpha, - const double *AP, const double *x, - int incx, double beta, double *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDspmv(char uplo, int n, double alpha, const double *AP, + const double *x, int incx, double beta, double *y, + int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, double, const double *, const double *, + int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasDspmv"); return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); } void CUBLASWINAPI cublasChpmv(char uplo, int n, cuComplex alpha, - const cuComplex *AP, const cuComplex *x, - int incx, cuComplex beta, cuComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, const cuComplex *, int, cuComplex, cuComplex *, int); + const cuComplex *AP, const cuComplex *x, int incx, + cuComplex beta, cuComplex *y, int incy) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasChpmv"); return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); } void CUBLASWINAPI cublasZhpmv(char uplo, int n, cuDoubleComplex alpha, - const cuDoubleComplex *AP, const cuDoubleComplex *x, - int incx, cuDoubleComplex beta, cuDoubleComplex *y, int incy) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuDoubleComplex, const cuDoubleComplex *, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); + const cuDoubleComplex *AP, + const cuDoubleComplex *x, int incx, + cuDoubleComplex beta, cuDoubleComplex *y, + int incy) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, cuDoubleComplex, const cuDoubleComplex *, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpmv"); if (!func_ptr) LogFatalSymbolNotFound("cublasZhpmv"); return func_ptr(uplo, n, alpha, AP, x, incx, beta, y, incy); } -void CUBLASWINAPI cublasSger (int m, int n, float alpha, const float *x, int incx, - const float *y, int incy, float *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(int, int, float, const float *, int, const float *, int, float *, int); +void CUBLASWINAPI cublasSger(int m, int n, float alpha, const float *x, + int incx, const float *y, int incy, float *A, + int lda) { + using FuncPtr = void(CUBLASWINAPI *)(int, int, float, const float *, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSger"); if (!func_ptr) LogFatalSymbolNotFound("cublasSger"); return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasDger (int m, int n, double alpha, const double *x, int incx, - const double *y, int incy, double *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(int, int, double, const double *, int, const double *, int, double *, int); +void CUBLASWINAPI cublasDger(int m, int n, double alpha, const double *x, + int incx, const double *y, int incy, double *A, + int lda) { + using FuncPtr = void(CUBLASWINAPI *)(int, int, double, const double *, int, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDger"); if (!func_ptr) LogFatalSymbolNotFound("cublasDger"); return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasCgeru (int m, int n, cuComplex alpha, const cuComplex *x, - int incx, const cuComplex *y, int incy, - cuComplex *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCgeru(int m, int n, cuComplex alpha, const cuComplex *x, + int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = + void(CUBLASWINAPI *)(int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgeru"); if (!func_ptr) LogFatalSymbolNotFound("cublasCgeru"); return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasCgerc (int m, int n, cuComplex alpha, const cuComplex *x, - int incx, const cuComplex *y, int incy, - cuComplex *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCgerc(int m, int n, cuComplex alpha, const cuComplex *x, + int incx, const cuComplex *y, int incy, + cuComplex *A, int lda) { + using FuncPtr = + void(CUBLASWINAPI *)(int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgerc"); if (!func_ptr) LogFatalSymbolNotFound("cublasCgerc"); return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasZgeru (int m, int n, cuDoubleComplex alpha, const cuDoubleComplex *x, - int incx, const cuDoubleComplex *y, int incy, - cuDoubleComplex *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZgeru(int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)( + int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgeru"); if (!func_ptr) LogFatalSymbolNotFound("cublasZgeru"); return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasZgerc (int m, int n, cuDoubleComplex alpha, const cuDoubleComplex *x, - int incx, const cuDoubleComplex *y, int incy, - cuDoubleComplex *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZgerc(int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)( + int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgerc"); if (!func_ptr) LogFatalSymbolNotFound("cublasZgerc"); return func_ptr(m, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasSsyr (char uplo, int n, float alpha, const float *x, - int incx, float *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, int, float *, int); +void CUBLASWINAPI cublasSsyr(char uplo, int n, float alpha, const float *x, + int incx, float *A, int lda) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, float, const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr"); if (!func_ptr) LogFatalSymbolNotFound("cublasSsyr"); return func_ptr(uplo, n, alpha, x, incx, A, lda); } -void CUBLASWINAPI cublasDsyr (char uplo, int n, double alpha, const double *x, - int incx, double *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, int, double *, int); +void CUBLASWINAPI cublasDsyr(char uplo, int n, double alpha, const double *x, + int incx, double *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, double, const double *, int, + double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr"); if (!func_ptr) LogFatalSymbolNotFound("cublasDsyr"); return func_ptr(uplo, n, alpha, x, incx, A, lda); } -void CUBLASWINAPI cublasCher (char uplo, int n, float alpha, - const cuComplex *x, int incx, cuComplex *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCher(char uplo, int n, float alpha, const cuComplex *x, + int incx, cuComplex *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const cuComplex *, int, + cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher"); if (!func_ptr) LogFatalSymbolNotFound("cublasCher"); return func_ptr(uplo, n, alpha, x, incx, A, lda); } -void CUBLASWINAPI cublasZher (char uplo, int n, double alpha, - const cuDoubleComplex *x, int incx, cuDoubleComplex *A, int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZher(char uplo, int n, double alpha, + const cuDoubleComplex *x, int incx, + cuDoubleComplex *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, double, const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher"); if (!func_ptr) LogFatalSymbolNotFound("cublasZher"); return func_ptr(uplo, n, alpha, x, incx, A, lda); } -void CUBLASWINAPI cublasSspr (char uplo, int n, float alpha, const float *x, - int incx, float *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, int, float *); +void CUBLASWINAPI cublasSspr(char uplo, int n, float alpha, const float *x, + int incx, float *AP) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, float, const float *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspr"); if (!func_ptr) LogFatalSymbolNotFound("cublasSspr"); return func_ptr(uplo, n, alpha, x, incx, AP); } -void CUBLASWINAPI cublasDspr (char uplo, int n, double alpha, const double *x, - int incx, double *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, int, double *); +void CUBLASWINAPI cublasDspr(char uplo, int n, double alpha, const double *x, + int incx, double *AP) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, double, const double *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspr"); if (!func_ptr) LogFatalSymbolNotFound("cublasDspr"); return func_ptr(uplo, n, alpha, x, incx, AP); } -void CUBLASWINAPI cublasChpr (char uplo, int n, float alpha, const cuComplex *x, - int incx, cuComplex *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const cuComplex *, int, cuComplex *); +void CUBLASWINAPI cublasChpr(char uplo, int n, float alpha, const cuComplex *x, + int incx, cuComplex *AP) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const cuComplex *, int, + cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpr"); if (!func_ptr) LogFatalSymbolNotFound("cublasChpr"); return func_ptr(uplo, n, alpha, x, incx, AP); } -void CUBLASWINAPI cublasZhpr (char uplo, int n, double alpha, const cuDoubleComplex *x, - int incx, cuDoubleComplex *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const cuDoubleComplex *, int, cuDoubleComplex *); +void CUBLASWINAPI cublasZhpr(char uplo, int n, double alpha, + const cuDoubleComplex *x, int incx, + cuDoubleComplex *AP) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, double, const cuDoubleComplex *, int, cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpr"); if (!func_ptr) LogFatalSymbolNotFound("cublasZhpr"); return func_ptr(uplo, n, alpha, x, incx, AP); } -void CUBLASWINAPI cublasSsyr2 (char uplo, int n, float alpha, const float *x, - int incx, const float *y, int incy, float *A, - int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, int, const float *, int, float *, int); +void CUBLASWINAPI cublasSsyr2(char uplo, int n, float alpha, const float *x, + int incx, const float *y, int incy, float *A, + int lda) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const float *, int, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr2"); if (!func_ptr) LogFatalSymbolNotFound("cublasSsyr2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasDsyr2 (char uplo, int n, double alpha, const double *x, - int incx, const double *y, int incy, double *A, - int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, int, const double *, int, double *, int); +void CUBLASWINAPI cublasDsyr2(char uplo, int n, double alpha, const double *x, + int incx, const double *y, int incy, double *A, + int lda) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, double, const double *, int, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr2"); if (!func_ptr) LogFatalSymbolNotFound("cublasDsyr2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasCher2 (char uplo, int n, cuComplex alpha, const cuComplex *x, - int incx, const cuComplex *y, int incy, cuComplex *A, - int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCher2(char uplo, int n, cuComplex alpha, + const cuComplex *x, int incx, const cuComplex *y, + int incy, cuComplex *A, int lda) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher2"); if (!func_ptr) LogFatalSymbolNotFound("cublasCher2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasZher2 (char uplo, int n, cuDoubleComplex alpha, const cuDoubleComplex *x, - int incx, const cuDoubleComplex *y, int incy, cuDoubleComplex *A, - int lda) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZher2(char uplo, int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *A, int lda) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher2"); if (!func_ptr) LogFatalSymbolNotFound("cublasZher2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, A, lda); } -void CUBLASWINAPI cublasSspr2 (char uplo, int n, float alpha, const float *x, - int incx, const float *y, int incy, float *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, float, const float *, int, const float *, int, float *); +void CUBLASWINAPI cublasSspr2(char uplo, int n, float alpha, const float *x, + int incx, const float *y, int incy, float *AP) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, float, const float *, int, + const float *, int, float *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSspr2"); if (!func_ptr) LogFatalSymbolNotFound("cublasSspr2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); } -void CUBLASWINAPI cublasDspr2 (char uplo, int n, double alpha, - const double *x, int incx, const double *y, - int incy, double *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, double, const double *, int, const double *, int, double *); +void CUBLASWINAPI cublasDspr2(char uplo, int n, double alpha, const double *x, + int incx, const double *y, int incy, double *AP) { + using FuncPtr = void(CUBLASWINAPI *)(char, int, double, const double *, int, + const double *, int, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDspr2"); if (!func_ptr) LogFatalSymbolNotFound("cublasDspr2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); } -void CUBLASWINAPI cublasChpr2 (char uplo, int n, cuComplex alpha, - const cuComplex *x, int incx, const cuComplex *y, - int incy, cuComplex *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex *); +void CUBLASWINAPI cublasChpr2(char uplo, int n, cuComplex alpha, + const cuComplex *x, int incx, const cuComplex *y, + int incy, cuComplex *AP) { + using FuncPtr = + void(CUBLASWINAPI *)(char, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChpr2"); if (!func_ptr) LogFatalSymbolNotFound("cublasChpr2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); } -void CUBLASWINAPI cublasZhpr2 (char uplo, int n, cuDoubleComplex alpha, - const cuDoubleComplex *x, int incx, const cuDoubleComplex *y, - int incy, cuDoubleComplex *AP) { - using FuncPtr = void (CUBLASWINAPI *)(char, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex *); +void CUBLASWINAPI cublasZhpr2(char uplo, int n, cuDoubleComplex alpha, + const cuDoubleComplex *x, int incx, + const cuDoubleComplex *y, int incy, + cuDoubleComplex *AP) { + using FuncPtr = void(CUBLASWINAPI *)( + char, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex *); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhpr2"); if (!func_ptr) LogFatalSymbolNotFound("cublasZhpr2"); return func_ptr(uplo, n, alpha, x, incx, y, incy, AP); } -void CUBLASWINAPI cublasSgemm (char transa, char transb, int m, int n, int k, - float alpha, const float *A, int lda, - const float *B, int ldb, float beta, float *C, - int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, int, float, const float *, int, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSgemm(char transa, char transb, int m, int n, int k, + float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, + int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, int, float, const float *, int, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSgemm"); if (!func_ptr) LogFatalSymbolNotFound("cublasSgemm"); return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasDgemm (char transa, char transb, int m, int n, int k, - double alpha, const double *A, int lda, - const double *B, int ldb, double beta, double *C, - int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, int, double, const double *, int, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDgemm(char transa, char transb, int m, int n, int k, + double alpha, const double *A, int lda, + const double *B, int ldb, double beta, double *C, + int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, int, double, const double *, + int, const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDgemm"); if (!func_ptr) LogFatalSymbolNotFound("cublasDgemm"); return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasCgemm (char transa, char transb, int m, int n, int k, - cuComplex alpha, const cuComplex *A, int lda, - const cuComplex *B, int ldb, cuComplex beta, - cuComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasCgemm(char transa, char transb, int m, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, cuComplex beta, + cuComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCgemm"); if (!func_ptr) LogFatalSymbolNotFound("cublasCgemm"); return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasZgemm (char transa, char transb, int m, int n, - int k, cuDoubleComplex alpha, - const cuDoubleComplex *A, int lda, - const cuDoubleComplex *B, int ldb, - cuDoubleComplex beta, cuDoubleComplex *C, - int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZgemm(char transa, char transb, int m, int n, int k, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZgemm"); if (!func_ptr) LogFatalSymbolNotFound("cublasZgemm"); return func_ptr(transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasSsyrk (char uplo, char trans, int n, int k, float alpha, - const float *A, int lda, float beta, float *C, - int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, float, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSsyrk(char uplo, char trans, int n, int k, float alpha, + const float *A, int lda, float beta, float *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, int, int, float, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyrk"); if (!func_ptr) LogFatalSymbolNotFound("cublasSsyrk"); return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -void CUBLASWINAPI cublasDsyrk (char uplo, char trans, int n, int k, - double alpha, const double *A, int lda, - double beta, double *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, double, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDsyrk(char uplo, char trans, int n, int k, double alpha, + const double *A, int lda, double beta, double *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, double, const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyrk"); if (!func_ptr) LogFatalSymbolNotFound("cublasDsyrk"); return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -void CUBLASWINAPI cublasCsyrk (char uplo, char trans, int n, int k, - cuComplex alpha, const cuComplex *A, int lda, - cuComplex beta, cuComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasCsyrk(char uplo, char trans, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + cuComplex beta, cuComplex *C, int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, + int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyrk"); if (!func_ptr) LogFatalSymbolNotFound("cublasCsyrk"); return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -void CUBLASWINAPI cublasZsyrk (char uplo, char trans, int n, int k, - cuDoubleComplex alpha, - const cuDoubleComplex *A, int lda, - cuDoubleComplex beta, - cuDoubleComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZsyrk(char uplo, char trans, int n, int k, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, + const cuDoubleComplex *, int, + cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyrk"); if (!func_ptr) LogFatalSymbolNotFound("cublasZsyrk"); return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -void CUBLASWINAPI cublasCherk (char uplo, char trans, int n, int k, - float alpha, const cuComplex *A, int lda, - float beta, cuComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, float, const cuComplex *, int, float, cuComplex *, int); +void CUBLASWINAPI cublasCherk(char uplo, char trans, int n, int k, float alpha, + const cuComplex *A, int lda, float beta, + cuComplex *C, int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, float, const cuComplex *, int, + float, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCherk"); if (!func_ptr) LogFatalSymbolNotFound("cublasCherk"); return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -void CUBLASWINAPI cublasZherk (char uplo, char trans, int n, int k, - double alpha, - const cuDoubleComplex *A, int lda, - double beta, - cuDoubleComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, double, const cuDoubleComplex *, int, double, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZherk(char uplo, char trans, int n, int k, double alpha, + const cuDoubleComplex *A, int lda, double beta, + cuDoubleComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, int, int, double, + const cuDoubleComplex *, int, double, + cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZherk"); if (!func_ptr) LogFatalSymbolNotFound("cublasZherk"); return func_ptr(uplo, trans, n, k, alpha, A, lda, beta, C, ldc); } -void CUBLASWINAPI cublasSsyr2k (char uplo, char trans, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, float, const float *, int, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSsyr2k(char uplo, char trans, int n, int k, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, float, const float *, int, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsyr2k"); if (!func_ptr) LogFatalSymbolNotFound("cublasSsyr2k"); return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasDsyr2k (char uplo, char trans, int n, int k, - double alpha, const double *A, int lda, - const double *B, int ldb, double beta, - double *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, double, const double *, int, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDsyr2k(char uplo, char trans, int n, int k, + double alpha, const double *A, int lda, + const double *B, int ldb, double beta, double *C, + int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, double, const double *, int, + const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsyr2k"); if (!func_ptr) LogFatalSymbolNotFound("cublasDsyr2k"); return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasCsyr2k (char uplo, char trans, int n, int k, - cuComplex alpha, const cuComplex *A, int lda, - const cuComplex *B, int ldb, cuComplex beta, - cuComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasCsyr2k(char uplo, char trans, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, cuComplex beta, + cuComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsyr2k"); if (!func_ptr) LogFatalSymbolNotFound("cublasCsyr2k"); return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasZsyr2k (char uplo, char trans, int n, int k, - cuDoubleComplex alpha, const cuDoubleComplex *A, int lda, - const cuDoubleComplex *B, int ldb, cuDoubleComplex beta, - cuDoubleComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZsyr2k(char uplo, char trans, int n, int k, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsyr2k"); if (!func_ptr) LogFatalSymbolNotFound("cublasZsyr2k"); return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasCher2k (char uplo, char trans, int n, int k, - cuComplex alpha, const cuComplex *A, int lda, - const cuComplex *B, int ldb, float beta, - cuComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, float, cuComplex *, int); +void CUBLASWINAPI cublasCher2k(char uplo, char trans, int n, int k, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, float beta, + cuComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, float, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCher2k"); if (!func_ptr) LogFatalSymbolNotFound("cublasCher2k"); return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasZher2k (char uplo, char trans, int n, int k, - cuDoubleComplex alpha, const cuDoubleComplex *A, int lda, - const cuDoubleComplex *B, int ldb, double beta, - cuDoubleComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, double, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZher2k(char uplo, char trans, int n, int k, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *B, int ldb, + double beta, cuDoubleComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, double, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZher2k"); if (!func_ptr) LogFatalSymbolNotFound("cublasZher2k"); return func_ptr(uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasSsymm (char side, char uplo, int m, int n, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, float, const float *, int, const float *, int, float, float *, int); +void CUBLASWINAPI cublasSsymm(char side, char uplo, int m, int n, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, float, const float *, int, + const float *, int, float, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasSsymm"); if (!func_ptr) LogFatalSymbolNotFound("cublasSsymm"); return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasDsymm (char side, char uplo, int m, int n, double alpha, - const double *A, int lda, const double *B, int ldb, - double beta, double *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, double, const double *, int, const double *, int, double, double *, int); +void CUBLASWINAPI cublasDsymm(char side, char uplo, int m, int n, double alpha, + const double *A, int lda, const double *B, + int ldb, double beta, double *C, int ldc) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, int, int, double, const double *, int, + const double *, int, double, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDsymm"); if (!func_ptr) LogFatalSymbolNotFound("cublasDsymm"); return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasCsymm (char side, char uplo, int m, int n, cuComplex alpha, - const cuComplex *A, int lda, const cuComplex *B, int ldb, - cuComplex beta, cuComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasCsymm(char side, char uplo, int m, int n, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, cuComplex beta, + cuComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCsymm"); if (!func_ptr) LogFatalSymbolNotFound("cublasCsymm"); return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasZsymm (char side, char uplo, int m, int n, cuDoubleComplex alpha, - const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, - cuDoubleComplex beta, cuDoubleComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZsymm(char side, char uplo, int m, int n, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZsymm"); if (!func_ptr) LogFatalSymbolNotFound("cublasZsymm"); return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasChemm (char side, char uplo, int m, int n, - cuComplex alpha, const cuComplex *A, int lda, - const cuComplex *B, int ldb, cuComplex beta, - cuComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuComplex, const cuComplex *, int, const cuComplex *, int, cuComplex, cuComplex *, int); +void CUBLASWINAPI cublasChemm(char side, char uplo, int m, int n, + cuComplex alpha, const cuComplex *A, int lda, + const cuComplex *B, int ldb, cuComplex beta, + cuComplex *C, int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuComplex, const cuComplex *, int, + const cuComplex *, int, cuComplex, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasChemm"); if (!func_ptr) LogFatalSymbolNotFound("cublasChemm"); return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasZhemm (char side, char uplo, int m, int n, - cuDoubleComplex alpha, const cuDoubleComplex *A, int lda, - const cuDoubleComplex *B, int ldb, cuDoubleComplex beta, - cuDoubleComplex *C, int ldc) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZhemm(char side, char uplo, int m, int n, + cuDoubleComplex alpha, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *B, int ldb, + cuDoubleComplex beta, cuDoubleComplex *C, + int ldc) { + using FuncPtr = void(CUBLASWINAPI *)( + char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, cuDoubleComplex, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZhemm"); if (!func_ptr) LogFatalSymbolNotFound("cublasZhemm"); return func_ptr(side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); } -void CUBLASWINAPI cublasStrsm (char side, char uplo, char transa, char diag, - int m, int n, float alpha, const float *A, int lda, - float *B, int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, float, const float *, int, float *, int); +void CUBLASWINAPI cublasStrsm(char side, char uplo, char transa, char diag, + int m, int n, float alpha, const float *A, + int lda, float *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, float, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrsm"); if (!func_ptr) LogFatalSymbolNotFound("cublasStrsm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); } -void CUBLASWINAPI cublasDtrsm (char side, char uplo, char transa, - char diag, int m, int n, double alpha, - const double *A, int lda, double *B, - int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, double, const double *, int, double *, int); +void CUBLASWINAPI cublasDtrsm(char side, char uplo, char transa, char diag, + int m, int n, double alpha, const double *A, + int lda, double *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, double, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrsm"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtrsm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); } -void CUBLASWINAPI cublasCtrsm (char side, char uplo, char transa, char diag, - int m, int n, cuComplex alpha, const cuComplex *A, - int lda, cuComplex *B, int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, cuComplex, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCtrsm(char side, char uplo, char transa, char diag, + int m, int n, cuComplex alpha, const cuComplex *A, + int lda, cuComplex *B, int ldb) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, char, int, int, cuComplex, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrsm"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtrsm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); } -void CUBLASWINAPI cublasZtrsm (char side, char uplo, char transa, - char diag, int m, int n, cuDoubleComplex alpha, - const cuDoubleComplex *A, int lda, - cuDoubleComplex *B, int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZtrsm(char side, char uplo, char transa, char diag, + int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, + cuDoubleComplex, const cuDoubleComplex *, + int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrsm"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtrsm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); } -void CUBLASWINAPI cublasStrmm (char side, char uplo, char transa, char diag, - int m, int n, float alpha, const float *A, int lda, - float *B, int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, float, const float *, int, float *, int); +void CUBLASWINAPI cublasStrmm(char side, char uplo, char transa, char diag, + int m, int n, float alpha, const float *A, + int lda, float *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, float, + const float *, int, float *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasStrmm"); if (!func_ptr) LogFatalSymbolNotFound("cublasStrmm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); } -void CUBLASWINAPI cublasDtrmm (char side, char uplo, char transa, - char diag, int m, int n, double alpha, - const double *A, int lda, double *B, - int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, double, const double *, int, double *, int); +void CUBLASWINAPI cublasDtrmm(char side, char uplo, char transa, char diag, + int m, int n, double alpha, const double *A, + int lda, double *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, double, + const double *, int, double *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasDtrmm"); if (!func_ptr) LogFatalSymbolNotFound("cublasDtrmm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); } -void CUBLASWINAPI cublasCtrmm (char side, char uplo, char transa, char diag, - int m, int n, cuComplex alpha, const cuComplex *A, - int lda, cuComplex *B, int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, cuComplex, const cuComplex *, int, cuComplex *, int); +void CUBLASWINAPI cublasCtrmm(char side, char uplo, char transa, char diag, + int m, int n, cuComplex alpha, const cuComplex *A, + int lda, cuComplex *B, int ldb) { + using FuncPtr = + void(CUBLASWINAPI *)(char, char, char, char, int, int, cuComplex, + const cuComplex *, int, cuComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasCtrmm"); if (!func_ptr) LogFatalSymbolNotFound("cublasCtrmm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); } -void CUBLASWINAPI cublasZtrmm (char side, char uplo, char transa, - char diag, int m, int n, cuDoubleComplex alpha, - const cuDoubleComplex *A, int lda, cuDoubleComplex *B, - int ldb) { - using FuncPtr = void (CUBLASWINAPI *)(char, char, char, char, int, int, cuDoubleComplex, const cuDoubleComplex *, int, cuDoubleComplex *, int); +void CUBLASWINAPI cublasZtrmm(char side, char uplo, char transa, char diag, + int m, int n, cuDoubleComplex alpha, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *B, int ldb) { + using FuncPtr = void(CUBLASWINAPI *)(char, char, char, char, int, int, + cuDoubleComplex, const cuDoubleComplex *, + int, cuDoubleComplex *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cublasZtrmm"); if (!func_ptr) LogFatalSymbolNotFound("cublasZtrmm"); return func_ptr(side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb); diff --git a/tensorflow/stream_executor/cuda/cuda_11_0.inc b/tensorflow/stream_executor/cuda/cuda_11_0.inc new file mode 100644 index 00000000000..18f3ff4cd57 --- /dev/null +++ b/tensorflow/stream_executor/cuda/cuda_11_0.inc @@ -0,0 +1,2430 @@ +// Auto-generated, do not edit. + +extern "C" { + +CUresult CUDAAPI cuGetErrorString(CUresult error, const char **pStr) { + using FuncPtr = CUresult(CUDAAPI *)(CUresult, const char **); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGetErrorString"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(error, pStr); +} + +CUresult CUDAAPI cuGetErrorName(CUresult error, const char **pStr) { + using FuncPtr = CUresult(CUDAAPI *)(CUresult, const char **); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGetErrorName"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(error, pStr); +} + +CUresult CUDAAPI cuInit(unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuInit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(Flags); +} + +CUresult CUDAAPI cuDriverGetVersion(int *driverVersion) { + using FuncPtr = CUresult(CUDAAPI *)(int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDriverGetVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(driverVersion); +} + +CUresult CUDAAPI cuDeviceGet(CUdevice *device, int ordinal) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceGet"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device, ordinal); +} + +CUresult CUDAAPI cuDeviceGetCount(int *count) { + using FuncPtr = CUresult(CUDAAPI *)(int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceGetCount"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(count); +} + +CUresult CUDAAPI cuDeviceGetName(char *name, int len, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(char *, int, CUdevice); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceGetName"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(name, len, dev); +} + +CUresult CUDAAPI cuDeviceGetUuid(CUuuid *uuid, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUuuid *, CUdevice); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceGetUuid"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(uuid, dev); +} + +CUresult CUDAAPI cuDeviceTotalMem(size_t *bytes, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, CUdevice); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceTotalMem_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(bytes, dev); +} + +CUresult CUDAAPI cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, + CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUdevice_attribute, CUdevice); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pi, attrib, dev); +} + +CUresult CUDAAPI cuDeviceGetNvSciSyncAttributes(void *nvSciSyncAttrList, + CUdevice dev, int flags) { + using FuncPtr = CUresult(CUDAAPI *)(void *, CUdevice, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceGetNvSciSyncAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(nvSciSyncAttrList, dev, flags); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuDeviceGetProperties(CUdevprop *prop, + CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevprop *, CUdevice); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceGetProperties"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(prop, dev); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuDeviceComputeCapability(int *major, + int *minor, + CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(int *, int *, CUdevice); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceComputeCapability"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(major, minor, dev); +} + +CUresult CUDAAPI cuDevicePrimaryCtxRetain(CUcontext *pctx, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *, CUdevice); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDevicePrimaryCtxRetain"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx, dev); +} + +CUresult CUDAAPI cuDevicePrimaryCtxRelease(CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDevicePrimaryCtxRelease_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev); +} + +CUresult CUDAAPI cuDevicePrimaryCtxSetFlags(CUdevice dev, unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDevicePrimaryCtxSetFlags_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev, flags); +} + +CUresult CUDAAPI cuDevicePrimaryCtxGetState(CUdevice dev, unsigned int *flags, + int *active) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice, unsigned int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDevicePrimaryCtxGetState"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev, flags, active); +} + +CUresult CUDAAPI cuDevicePrimaryCtxReset(CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDevicePrimaryCtxReset_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev); +} + +CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, + CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *, unsigned int, CUdevice); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx, flags, dev); +} + +CUresult CUDAAPI cuCtxDestroy(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxDestroy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx); +} + +CUresult CUDAAPI cuCtxPushCurrent(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxPushCurrent_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx); +} + +CUresult CUDAAPI cuCtxPopCurrent(CUcontext *pctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxPopCurrent_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx); +} + +CUresult CUDAAPI cuCtxSetCurrent(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxSetCurrent"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx); +} + +CUresult CUDAAPI cuCtxGetCurrent(CUcontext *pctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxGetCurrent"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx); +} + +CUresult CUDAAPI cuCtxGetDevice(CUdevice *device) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxGetDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device); +} + +CUresult CUDAAPI cuCtxGetFlags(unsigned int *flags) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(flags); +} + +CUresult CUDAAPI cuCtxSynchronize(void) { + using FuncPtr = CUresult(CUDAAPI *)(); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +CUresult CUDAAPI cuCtxSetLimit(CUlimit limit, size_t value) { + using FuncPtr = CUresult(CUDAAPI *)(CUlimit, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxSetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(limit, value); +} + +CUresult CUDAAPI cuCtxGetLimit(size_t *pvalue, CUlimit limit) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, CUlimit); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxGetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pvalue, limit); +} + +CUresult CUDAAPI cuCtxGetCacheConfig(CUfunc_cache *pconfig) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunc_cache *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxGetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pconfig); +} + +CUresult CUDAAPI cuCtxSetCacheConfig(CUfunc_cache config) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunc_cache); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxSetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(config); +} + +CUresult CUDAAPI cuCtxGetSharedMemConfig(CUsharedconfig *pConfig) { + using FuncPtr = CUresult(CUDAAPI *)(CUsharedconfig *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxGetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pConfig); +} + +CUresult CUDAAPI cuCtxSetSharedMemConfig(CUsharedconfig config) { + using FuncPtr = CUresult(CUDAAPI *)(CUsharedconfig); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxSetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(config); +} + +CUresult CUDAAPI cuCtxGetApiVersion(CUcontext ctx, unsigned int *version) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext, unsigned int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxGetApiVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx, version); +} + +CUresult CUDAAPI cuCtxGetStreamPriorityRange(int *leastPriority, + int *greatestPriority) { + using FuncPtr = CUresult(CUDAAPI *)(int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxGetStreamPriorityRange"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(leastPriority, greatestPriority); +} + +CUresult CUDAAPI cuCtxResetPersistingL2Cache(void) { + using FuncPtr = CUresult(CUDAAPI *)(); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxResetPersistingL2Cache"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuCtxAttach(CUcontext *pctx, + unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext *, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxAttach"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pctx, flags); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuCtxDetach(CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxDetach"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ctx); +} + +CUresult CUDAAPI cuModuleLoad(CUmodule *module, const char *fname) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const char *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuModuleLoad"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, fname); +} + +CUresult CUDAAPI cuModuleLoadData(CUmodule *module, const void *image) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuModuleLoadData"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, image); +} + +CUresult CUDAAPI cuModuleLoadDataEx(CUmodule *module, const void *image, + unsigned int numOptions, + CUjit_option *options, + void **optionValues) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const void *, unsigned int, + CUjit_option *, void **); + static auto func_ptr = LoadSymbol<FuncPtr>("cuModuleLoadDataEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, image, numOptions, options, optionValues); +} + +CUresult CUDAAPI cuModuleLoadFatBinary(CUmodule *module, const void *fatCubin) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, const void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuModuleLoadFatBinary"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(module, fatCubin); +} + +CUresult CUDAAPI cuModuleUnload(CUmodule hmod) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule); + static auto func_ptr = LoadSymbol<FuncPtr>("cuModuleUnload"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hmod); +} + +CUresult CUDAAPI cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, + const char *name) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction *, CUmodule, const char *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuModuleGetFunction"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, hmod, name); +} + +CUresult CUDAAPI cuModuleGetGlobal(CUdeviceptr *dptr, size_t *bytes, + CUmodule hmod, const char *name) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, CUmodule, const char *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuModuleGetGlobal_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, bytes, hmod, name); +} + +CUresult CUDAAPI cuModuleGetTexRef(CUtexref *pTexRef, CUmodule hmod, + const char *name) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref *, CUmodule, const char *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuModuleGetTexRef"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexRef, hmod, name); +} + +CUresult CUDAAPI cuModuleGetSurfRef(CUsurfref *pSurfRef, CUmodule hmod, + const char *name) { + using FuncPtr = CUresult(CUDAAPI *)(CUsurfref *, CUmodule, const char *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuModuleGetSurfRef"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pSurfRef, hmod, name); +} + +CUresult CUDAAPI cuLinkCreate(unsigned int numOptions, CUjit_option *options, + void **optionValues, CUlinkState *stateOut) { + using FuncPtr = + CUresult(CUDAAPI *)(unsigned int, CUjit_option *, void **, CUlinkState *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuLinkCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numOptions, options, optionValues, stateOut); +} + +CUresult CUDAAPI cuLinkAddData(CUlinkState state, CUjitInputType type, + void *data, size_t size, const char *name, + unsigned int numOptions, CUjit_option *options, + void **optionValues) { + using FuncPtr = + CUresult(CUDAAPI *)(CUlinkState, CUjitInputType, void *, size_t, + const char *, unsigned int, CUjit_option *, void **); + static auto func_ptr = LoadSymbol<FuncPtr>("cuLinkAddData_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(state, type, data, size, name, numOptions, options, + optionValues); +} + +CUresult CUDAAPI cuLinkAddFile(CUlinkState state, CUjitInputType type, + const char *path, unsigned int numOptions, + CUjit_option *options, void **optionValues) { + using FuncPtr = CUresult(CUDAAPI *)(CUlinkState, CUjitInputType, const char *, + unsigned int, CUjit_option *, void **); + static auto func_ptr = LoadSymbol<FuncPtr>("cuLinkAddFile_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(state, type, path, numOptions, options, optionValues); +} + +CUresult CUDAAPI cuLinkComplete(CUlinkState state, void **cubinOut, + size_t *sizeOut) { + using FuncPtr = CUresult(CUDAAPI *)(CUlinkState, void **, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuLinkComplete"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(state, cubinOut, sizeOut); +} + +CUresult CUDAAPI cuLinkDestroy(CUlinkState state) { + using FuncPtr = CUresult(CUDAAPI *)(CUlinkState); + static auto func_ptr = LoadSymbol<FuncPtr>("cuLinkDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(state); +} + +CUresult CUDAAPI cuMemGetInfo(size_t *free, size_t *total) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemGetInfo_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(free, total); +} + +CUresult CUDAAPI cuMemAlloc(CUdeviceptr *dptr, size_t bytesize) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemAlloc_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, bytesize); +} + +CUresult CUDAAPI cuMemAllocPitch(CUdeviceptr *dptr, size_t *pPitch, + size_t WidthInBytes, size_t Height, + unsigned int ElementSizeBytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, size_t, size_t, + unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemAllocPitch_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, pPitch, WidthInBytes, Height, ElementSizeBytes); +} + +CUresult CUDAAPI cuMemFree(CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemFree_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr); +} + +CUresult CUDAAPI cuMemGetAddressRange(CUdeviceptr *pbase, size_t *psize, + CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, CUdeviceptr); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemGetAddressRange_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pbase, psize, dptr); +} + +CUresult CUDAAPI cuMemAllocHost(void **pp, size_t bytesize) { + using FuncPtr = CUresult(CUDAAPI *)(void **, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemAllocHost_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pp, bytesize); +} + +CUresult CUDAAPI cuMemFreeHost(void *p) { + using FuncPtr = CUresult(CUDAAPI *)(void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemFreeHost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p); +} + +CUresult CUDAAPI cuMemHostAlloc(void **pp, size_t bytesize, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(void **, size_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemHostAlloc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pp, bytesize, Flags); +} + +CUresult CUDAAPI cuMemHostGetDevicePointer(CUdeviceptr *pdptr, void *p, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, void *, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemHostGetDevicePointer_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pdptr, p, Flags); +} + +CUresult CUDAAPI cuMemHostGetFlags(unsigned int *pFlags, void *p) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemHostGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pFlags, p); +} + +CUresult CUDAAPI cuMemAllocManaged(CUdeviceptr *dptr, size_t bytesize, + unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemAllocManaged"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr, bytesize, flags); +} + +CUresult CUDAAPI cuDeviceGetByPCIBusId(CUdevice *dev, const char *pciBusId) { + using FuncPtr = CUresult(CUDAAPI *)(CUdevice *, const char *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceGetByPCIBusId"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dev, pciBusId); +} + +CUresult CUDAAPI cuDeviceGetPCIBusId(char *pciBusId, int len, CUdevice dev) { + using FuncPtr = CUresult(CUDAAPI *)(char *, int, CUdevice); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceGetPCIBusId"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pciBusId, len, dev); +} + +CUresult CUDAAPI cuIpcGetEventHandle(CUipcEventHandle *pHandle, CUevent event) { + using FuncPtr = CUresult(CUDAAPI *)(CUipcEventHandle *, CUevent); + static auto func_ptr = LoadSymbol<FuncPtr>("cuIpcGetEventHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, event); +} + +CUresult CUDAAPI cuIpcOpenEventHandle(CUevent *phEvent, + CUipcEventHandle handle) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent *, CUipcEventHandle); + static auto func_ptr = LoadSymbol<FuncPtr>("cuIpcOpenEventHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phEvent, handle); +} + +CUresult CUDAAPI cuIpcGetMemHandle(CUipcMemHandle *pHandle, CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUipcMemHandle *, CUdeviceptr); + static auto func_ptr = LoadSymbol<FuncPtr>("cuIpcGetMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, dptr); +} + +CUresult CUDAAPI cuIpcOpenMemHandle(CUdeviceptr *pdptr, CUipcMemHandle handle, + unsigned int Flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr *, CUipcMemHandle, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuIpcOpenMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pdptr, handle, Flags); +} + +CUresult CUDAAPI cuIpcCloseMemHandle(CUdeviceptr dptr) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr); + static auto func_ptr = LoadSymbol<FuncPtr>("cuIpcCloseMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dptr); +} + +CUresult CUDAAPI cuMemHostRegister(void *p, size_t bytesize, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(void *, size_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemHostRegister_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p, bytesize, Flags); +} + +CUresult CUDAAPI cuMemHostUnregister(void *p) { + using FuncPtr = CUresult(CUDAAPI *)(void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemHostUnregister"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p); +} + +CUresult CUDAAPI cuMemcpy(CUdeviceptr dst, CUdeviceptr src, size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src, ByteCount); +} + +CUresult CUDAAPI cuMemcpyPeer(CUdeviceptr dstDevice, CUcontext dstContext, + CUdeviceptr srcDevice, CUcontext srcContext, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUcontext, CUdeviceptr, + CUcontext, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstContext, srcDevice, srcContext, ByteCount); +} + +CUresult CUDAAPI cuMemcpyHtoD(CUdeviceptr dstDevice, const void *srcHost, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, const void *, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyHtoD_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, srcHost, ByteCount); +} + +CUresult CUDAAPI cuMemcpyDtoH(void *dstHost, CUdeviceptr srcDevice, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(void *, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyDtoH_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstHost, srcDevice, ByteCount); +} + +CUresult CUDAAPI cuMemcpyDtoD(CUdeviceptr dstDevice, CUdeviceptr srcDevice, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyDtoD_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, srcDevice, ByteCount); +} + +CUresult CUDAAPI cuMemcpyDtoA(CUarray dstArray, size_t dstOffset, + CUdeviceptr srcDevice, size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray, size_t, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyDtoA_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstArray, dstOffset, srcDevice, ByteCount); +} + +CUresult CUDAAPI cuMemcpyAtoD(CUdeviceptr dstDevice, CUarray srcArray, + size_t srcOffset, size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUarray, size_t, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyAtoD_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, srcArray, srcOffset, ByteCount); +} + +CUresult CUDAAPI cuMemcpyHtoA(CUarray dstArray, size_t dstOffset, + const void *srcHost, size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray, size_t, const void *, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyHtoA_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstArray, dstOffset, srcHost, ByteCount); +} + +CUresult CUDAAPI cuMemcpyAtoH(void *dstHost, CUarray srcArray, size_t srcOffset, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(void *, CUarray, size_t, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyAtoH_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstHost, srcArray, srcOffset, ByteCount); +} + +CUresult CUDAAPI cuMemcpyAtoA(CUarray dstArray, size_t dstOffset, + CUarray srcArray, size_t srcOffset, + size_t ByteCount) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray, size_t, CUarray, size_t, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyAtoA_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstArray, dstOffset, srcArray, srcOffset, ByteCount); +} + +CUresult CUDAAPI cuMemcpy2D(const CUDA_MEMCPY2D *pCopy) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY2D *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpy2D_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy); +} + +CUresult CUDAAPI cuMemcpy2DUnaligned(const CUDA_MEMCPY2D *pCopy) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY2D *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpy2DUnaligned_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy); +} + +CUresult CUDAAPI cuMemcpy3D(const CUDA_MEMCPY3D *pCopy) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpy3D_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy); +} + +CUresult CUDAAPI cuMemcpy3DPeer(const CUDA_MEMCPY3D_PEER *pCopy) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D_PEER *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpy3DPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy); +} + +CUresult CUDAAPI cuMemcpyAsync(CUdeviceptr dst, CUdeviceptr src, + size_t ByteCount, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpyPeerAsync(CUdeviceptr dstDevice, CUcontext dstContext, + CUdeviceptr srcDevice, CUcontext srcContext, + size_t ByteCount, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, CUcontext, CUdeviceptr, + CUcontext, size_t, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyPeerAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstContext, srcDevice, srcContext, ByteCount, + hStream); +} + +CUresult CUDAAPI cuMemcpyHtoDAsync(CUdeviceptr dstDevice, const void *srcHost, + size_t ByteCount, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, const void *, size_t, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyHtoDAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, srcHost, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpyDtoHAsync(void *dstHost, CUdeviceptr srcDevice, + size_t ByteCount, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(void *, CUdeviceptr, size_t, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyDtoHAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstHost, srcDevice, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpyDtoDAsync(CUdeviceptr dstDevice, CUdeviceptr srcDevice, + size_t ByteCount, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, CUdeviceptr, size_t, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyDtoDAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, srcDevice, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpyHtoAAsync(CUarray dstArray, size_t dstOffset, + const void *srcHost, size_t ByteCount, + CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUarray, size_t, const void *, size_t, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyHtoAAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstArray, dstOffset, srcHost, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpyAtoHAsync(void *dstHost, CUarray srcArray, + size_t srcOffset, size_t ByteCount, + CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(void *, CUarray, size_t, size_t, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpyAtoHAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstHost, srcArray, srcOffset, ByteCount, hStream); +} + +CUresult CUDAAPI cuMemcpy2DAsync(const CUDA_MEMCPY2D *pCopy, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY2D *, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpy2DAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy, hStream); +} + +CUresult CUDAAPI cuMemcpy3DAsync(const CUDA_MEMCPY3D *pCopy, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D *, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpy3DAsync_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy, hStream); +} + +CUresult CUDAAPI cuMemcpy3DPeerAsync(const CUDA_MEMCPY3D_PEER *pCopy, + CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(const CUDA_MEMCPY3D_PEER *, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemcpy3DPeerAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCopy, hStream); +} + +CUresult CUDAAPI cuMemsetD8(CUdeviceptr dstDevice, unsigned char uc, size_t N) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, unsigned char, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemsetD8_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, uc, N); +} + +CUresult CUDAAPI cuMemsetD16(CUdeviceptr dstDevice, unsigned short us, + size_t N) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, unsigned short, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemsetD16_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, us, N); +} + +CUresult CUDAAPI cuMemsetD32(CUdeviceptr dstDevice, unsigned int ui, size_t N) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, unsigned int, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemsetD32_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, ui, N); +} + +CUresult CUDAAPI cuMemsetD2D8(CUdeviceptr dstDevice, size_t dstPitch, + unsigned char uc, size_t Width, size_t Height) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned char, size_t, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemsetD2D8_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, uc, Width, Height); +} + +CUresult CUDAAPI cuMemsetD2D16(CUdeviceptr dstDevice, size_t dstPitch, + unsigned short us, size_t Width, size_t Height) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned short, size_t, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemsetD2D16_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, us, Width, Height); +} + +CUresult CUDAAPI cuMemsetD2D32(CUdeviceptr dstDevice, size_t dstPitch, + unsigned int ui, size_t Width, size_t Height) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned int, size_t, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemsetD2D32_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, ui, Width, Height); +} + +CUresult CUDAAPI cuMemsetD8Async(CUdeviceptr dstDevice, unsigned char uc, + size_t N, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, unsigned char, size_t, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemsetD8Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, uc, N, hStream); +} + +CUresult CUDAAPI cuMemsetD16Async(CUdeviceptr dstDevice, unsigned short us, + size_t N, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, unsigned short, size_t, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemsetD16Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, us, N, hStream); +} + +CUresult CUDAAPI cuMemsetD32Async(CUdeviceptr dstDevice, unsigned int ui, + size_t N, CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, unsigned int, size_t, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemsetD32Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, ui, N, hStream); +} + +CUresult CUDAAPI cuMemsetD2D8Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned char uc, size_t Width, + size_t Height, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned char, + size_t, size_t, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemsetD2D8Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, uc, Width, Height, hStream); +} + +CUresult CUDAAPI cuMemsetD2D16Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned short us, size_t Width, + size_t Height, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned short, + size_t, size_t, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemsetD2D16Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, us, Width, Height, hStream); +} + +CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, + unsigned int ui, size_t Width, + size_t Height, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, unsigned int, size_t, + size_t, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemsetD2D32Async"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dstDevice, dstPitch, ui, Width, Height, hStream); +} + +CUresult CUDAAPI cuArrayCreate(CUarray *pHandle, + const CUDA_ARRAY_DESCRIPTOR *pAllocateArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, const CUDA_ARRAY_DESCRIPTOR *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuArrayCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, pAllocateArray); +} + +CUresult CUDAAPI cuArrayGetDescriptor(CUDA_ARRAY_DESCRIPTOR *pArrayDescriptor, + CUarray hArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_ARRAY_DESCRIPTOR *, CUarray); + static auto func_ptr = LoadSymbol<FuncPtr>("cuArrayGetDescriptor_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pArrayDescriptor, hArray); +} + +CUresult CUDAAPI cuArrayDestroy(CUarray hArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray); + static auto func_ptr = LoadSymbol<FuncPtr>("cuArrayDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hArray); +} + +CUresult CUDAAPI cuArray3DCreate( + CUarray *pHandle, const CUDA_ARRAY3D_DESCRIPTOR *pAllocateArray) { + using FuncPtr = + CUresult(CUDAAPI *)(CUarray *, const CUDA_ARRAY3D_DESCRIPTOR *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuArray3DCreate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, pAllocateArray); +} + +CUresult CUDAAPI cuArray3DGetDescriptor( + CUDA_ARRAY3D_DESCRIPTOR *pArrayDescriptor, CUarray hArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_ARRAY3D_DESCRIPTOR *, CUarray); + static auto func_ptr = LoadSymbol<FuncPtr>("cuArray3DGetDescriptor_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pArrayDescriptor, hArray); +} + +CUresult CUDAAPI +cuMipmappedArrayCreate(CUmipmappedArray *pHandle, + const CUDA_ARRAY3D_DESCRIPTOR *pMipmappedArrayDesc, + unsigned int numMipmapLevels) { + using FuncPtr = CUresult(CUDAAPI *)( + CUmipmappedArray *, const CUDA_ARRAY3D_DESCRIPTOR *, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMipmappedArrayCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHandle, pMipmappedArrayDesc, numMipmapLevels); +} + +CUresult CUDAAPI cuMipmappedArrayGetLevel(CUarray *pLevelArray, + CUmipmappedArray hMipmappedArray, + unsigned int level) { + using FuncPtr = + CUresult(CUDAAPI *)(CUarray *, CUmipmappedArray, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMipmappedArrayGetLevel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pLevelArray, hMipmappedArray, level); +} + +CUresult CUDAAPI cuMipmappedArrayDestroy(CUmipmappedArray hMipmappedArray) { + using FuncPtr = CUresult(CUDAAPI *)(CUmipmappedArray); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMipmappedArrayDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hMipmappedArray); +} + +CUresult CUDAAPI cuMemAddressReserve(CUdeviceptr *ptr, size_t size, + size_t alignment, CUdeviceptr addr, + unsigned long long flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, size_t, size_t, + CUdeviceptr, unsigned long long); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemAddressReserve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr, size, alignment, addr, flags); +} + +CUresult CUDAAPI cuMemAddressFree(CUdeviceptr ptr, size_t size) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemAddressFree"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr, size); +} + +CUresult CUDAAPI cuMemCreate(CUmemGenericAllocationHandle *handle, size_t size, + const CUmemAllocationProp *prop, + unsigned long long flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUmemGenericAllocationHandle *, size_t, + const CUmemAllocationProp *, unsigned long long); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, size, prop, flags); +} + +CUresult CUDAAPI cuMemRelease(CUmemGenericAllocationHandle handle) { + using FuncPtr = CUresult(CUDAAPI *)(CUmemGenericAllocationHandle); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemRelease"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +CUresult CUDAAPI cuMemMap(CUdeviceptr ptr, size_t size, size_t offset, + CUmemGenericAllocationHandle handle, + unsigned long long flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, size_t, + CUmemGenericAllocationHandle, unsigned long long); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemMap"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr, size, offset, handle, flags); +} + +CUresult CUDAAPI cuMemUnmap(CUdeviceptr ptr, size_t size) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemUnmap"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr, size); +} + +CUresult CUDAAPI cuMemSetAccess(CUdeviceptr ptr, size_t size, + const CUmemAccessDesc *desc, size_t count) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, const CUmemAccessDesc *, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemSetAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr, size, desc, count); +} + +CUresult CUDAAPI cuMemGetAccess(unsigned long long *flags, + const CUmemLocation *location, + CUdeviceptr ptr) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned long long *, + const CUmemLocation *, CUdeviceptr); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemGetAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(flags, location, ptr); +} + +CUresult CUDAAPI cuMemExportToShareableHandle( + void *shareableHandle, CUmemGenericAllocationHandle handle, + CUmemAllocationHandleType handleType, unsigned long long flags) { + using FuncPtr = + CUresult(CUDAAPI *)(void *, CUmemGenericAllocationHandle, + CUmemAllocationHandleType, unsigned long long); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemExportToShareableHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(shareableHandle, handle, handleType, flags); +} + +CUresult CUDAAPI cuMemImportFromShareableHandle( + CUmemGenericAllocationHandle *handle, void *osHandle, + CUmemAllocationHandleType shHandleType) { + using FuncPtr = CUresult(CUDAAPI *)(CUmemGenericAllocationHandle *, void *, + CUmemAllocationHandleType); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemImportFromShareableHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, osHandle, shHandleType); +} + +CUresult CUDAAPI cuMemGetAllocationGranularity( + size_t *granularity, const CUmemAllocationProp *prop, + CUmemAllocationGranularity_flags option) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, const CUmemAllocationProp *, + CUmemAllocationGranularity_flags); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemGetAllocationGranularity"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(granularity, prop, option); +} + +CUresult CUDAAPI cuMemGetAllocationPropertiesFromHandle( + CUmemAllocationProp *prop, CUmemGenericAllocationHandle handle) { + using FuncPtr = + CUresult(CUDAAPI *)(CUmemAllocationProp *, CUmemGenericAllocationHandle); + static auto func_ptr = + LoadSymbol<FuncPtr>("cuMemGetAllocationPropertiesFromHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(prop, handle); +} + +CUresult CUDAAPI +cuMemRetainAllocationHandle(CUmemGenericAllocationHandle *handle, void *addr) { + using FuncPtr = CUresult(CUDAAPI *)(CUmemGenericAllocationHandle *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemRetainAllocationHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, addr); +} + +CUresult CUDAAPI cuPointerGetAttribute(void *data, + CUpointer_attribute attribute, + CUdeviceptr ptr) { + using FuncPtr = CUresult(CUDAAPI *)(void *, CUpointer_attribute, CUdeviceptr); + static auto func_ptr = LoadSymbol<FuncPtr>("cuPointerGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, attribute, ptr); +} + +CUresult CUDAAPI cuMemPrefetchAsync(CUdeviceptr devPtr, size_t count, + CUdevice dstDevice, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr, size_t, CUdevice, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemPrefetchAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, count, dstDevice, hStream); +} + +CUresult CUDAAPI cuMemAdvise(CUdeviceptr devPtr, size_t count, + CUmem_advise advice, CUdevice device) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr, size_t, CUmem_advise, CUdevice); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemAdvise"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, count, advice, device); +} + +CUresult CUDAAPI cuMemRangeGetAttribute(void *data, size_t dataSize, + CUmem_range_attribute attribute, + CUdeviceptr devPtr, size_t count) { + using FuncPtr = CUresult(CUDAAPI *)(void *, size_t, CUmem_range_attribute, + CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemRangeGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, dataSize, attribute, devPtr, count); +} + +CUresult CUDAAPI cuMemRangeGetAttributes(void **data, size_t *dataSizes, + CUmem_range_attribute *attributes, + size_t numAttributes, + CUdeviceptr devPtr, size_t count) { + using FuncPtr = CUresult(CUDAAPI *)( + void **, size_t *, CUmem_range_attribute *, size_t, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuMemRangeGetAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, dataSizes, attributes, numAttributes, devPtr, count); +} + +CUresult CUDAAPI cuPointerSetAttribute(const void *value, + CUpointer_attribute attribute, + CUdeviceptr ptr) { + using FuncPtr = + CUresult(CUDAAPI *)(const void *, CUpointer_attribute, CUdeviceptr); + static auto func_ptr = LoadSymbol<FuncPtr>("cuPointerSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(value, attribute, ptr); +} + +CUresult CUDAAPI cuPointerGetAttributes(unsigned int numAttributes, + CUpointer_attribute *attributes, + void **data, CUdeviceptr ptr) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int, CUpointer_attribute *, + void **, CUdeviceptr); + static auto func_ptr = LoadSymbol<FuncPtr>("cuPointerGetAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numAttributes, attributes, data, ptr); +} + +CUresult CUDAAPI cuStreamCreate(CUstream *phStream, unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream *, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phStream, Flags); +} + +CUresult CUDAAPI cuStreamCreateWithPriority(CUstream *phStream, + unsigned int flags, int priority) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream *, unsigned int, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamCreateWithPriority"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phStream, flags, priority); +} + +CUresult CUDAAPI cuStreamGetPriority(CUstream hStream, int *priority) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamGetPriority"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, priority); +} + +CUresult CUDAAPI cuStreamGetFlags(CUstream hStream, unsigned int *flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, unsigned int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, flags); +} + +CUresult CUDAAPI cuStreamGetCtx(CUstream hStream, CUcontext *pctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUcontext *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamGetCtx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, pctx); +} + +CUresult CUDAAPI cuStreamWaitEvent(CUstream hStream, CUevent hEvent, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUevent, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamWaitEvent"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, hEvent, Flags); +} + +CUresult CUDAAPI cuStreamAddCallback(CUstream hStream, + CUstreamCallback callback, void *userData, + unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUstreamCallback, void *, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamAddCallback"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, callback, userData, flags); +} + +CUresult CUDAAPI cuStreamBeginCapture(CUstream hStream, + CUstreamCaptureMode mode) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUstreamCaptureMode); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamBeginCapture_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, mode); +} + +CUresult CUDAAPI cuThreadExchangeStreamCaptureMode(CUstreamCaptureMode *mode) { + using FuncPtr = CUresult(CUDAAPI *)(CUstreamCaptureMode *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cuThreadExchangeStreamCaptureMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mode); +} + +CUresult CUDAAPI cuStreamEndCapture(CUstream hStream, CUgraph *phGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUgraph *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamEndCapture"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, phGraph); +} + +CUresult CUDAAPI cuStreamIsCapturing(CUstream hStream, + CUstreamCaptureStatus *captureStatus) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUstreamCaptureStatus *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamIsCapturing"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, captureStatus); +} + +CUresult CUDAAPI cuStreamGetCaptureInfo(CUstream hStream, + CUstreamCaptureStatus *captureStatus, + cuuint64_t *id) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUstreamCaptureStatus *, cuuint64_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamGetCaptureInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, captureStatus, id); +} + +CUresult CUDAAPI cuStreamAttachMemAsync(CUstream hStream, CUdeviceptr dptr, + size_t length, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, size_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamAttachMemAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, dptr, length, flags); +} + +CUresult CUDAAPI cuStreamQuery(CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamQuery"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream); +} + +CUresult CUDAAPI cuStreamSynchronize(CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream); +} + +CUresult CUDAAPI cuStreamDestroy(CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamDestroy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream); +} + +CUresult CUDAAPI cuStreamCopyAttributes(CUstream dst, CUstream src) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamCopyAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src); +} + +CUresult CUDAAPI cuStreamGetAttribute(CUstream hStream, CUstreamAttrID attr, + CUstreamAttrValue *value_out) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUstreamAttrID, CUstreamAttrValue *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, attr, value_out); +} + +CUresult CUDAAPI cuStreamSetAttribute(CUstream hStream, CUstreamAttrID attr, + const CUstreamAttrValue *value) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUstreamAttrID, const CUstreamAttrValue *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, attr, value); +} + +CUresult CUDAAPI cuEventCreate(CUevent *phEvent, unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent *, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuEventCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phEvent, Flags); +} + +CUresult CUDAAPI cuEventRecord(CUevent hEvent, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuEventRecord"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent, hStream); +} + +CUresult CUDAAPI cuEventQuery(CUevent hEvent) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent); + static auto func_ptr = LoadSymbol<FuncPtr>("cuEventQuery"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent); +} + +CUresult CUDAAPI cuEventSynchronize(CUevent hEvent) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent); + static auto func_ptr = LoadSymbol<FuncPtr>("cuEventSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent); +} + +CUresult CUDAAPI cuEventDestroy(CUevent hEvent) { + using FuncPtr = CUresult(CUDAAPI *)(CUevent); + static auto func_ptr = LoadSymbol<FuncPtr>("cuEventDestroy_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hEvent); +} + +CUresult CUDAAPI cuEventElapsedTime(float *pMilliseconds, CUevent hStart, + CUevent hEnd) { + using FuncPtr = CUresult(CUDAAPI *)(float *, CUevent, CUevent); + static auto func_ptr = LoadSymbol<FuncPtr>("cuEventElapsedTime"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pMilliseconds, hStart, hEnd); +} + +CUresult CUDAAPI +cuImportExternalMemory(CUexternalMemory *extMem_out, + const CUDA_EXTERNAL_MEMORY_HANDLE_DESC *memHandleDesc) { + using FuncPtr = CUresult(CUDAAPI *)(CUexternalMemory *, + const CUDA_EXTERNAL_MEMORY_HANDLE_DESC *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuImportExternalMemory"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extMem_out, memHandleDesc); +} + +CUresult CUDAAPI cuExternalMemoryGetMappedBuffer( + CUdeviceptr *devPtr, CUexternalMemory extMem, + const CUDA_EXTERNAL_MEMORY_BUFFER_DESC *bufferDesc) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, CUexternalMemory, + const CUDA_EXTERNAL_MEMORY_BUFFER_DESC *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuExternalMemoryGetMappedBuffer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, extMem, bufferDesc); +} + +CUresult CUDAAPI cuExternalMemoryGetMappedMipmappedArray( + CUmipmappedArray *mipmap, CUexternalMemory extMem, + const CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC *mipmapDesc) { + using FuncPtr = + CUresult(CUDAAPI *)(CUmipmappedArray *, CUexternalMemory, + const CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cuExternalMemoryGetMappedMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mipmap, extMem, mipmapDesc); +} + +CUresult CUDAAPI cuDestroyExternalMemory(CUexternalMemory extMem) { + using FuncPtr = CUresult(CUDAAPI *)(CUexternalMemory); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDestroyExternalMemory"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extMem); +} + +CUresult CUDAAPI cuImportExternalSemaphore( + CUexternalSemaphore *extSem_out, + const CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC *semHandleDesc) { + using FuncPtr = CUresult(CUDAAPI *)( + CUexternalSemaphore *, const CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuImportExternalSemaphore"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSem_out, semHandleDesc); +} + +CUresult CUDAAPI cuSignalExternalSemaphoresAsync( + const CUexternalSemaphore *extSemArray, + const CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS *paramsArray, + unsigned int numExtSems, CUstream stream) { + using FuncPtr = CUresult(CUDAAPI *)( + const CUexternalSemaphore *, + const CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS *, unsigned int, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuSignalExternalSemaphoresAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSemArray, paramsArray, numExtSems, stream); +} + +CUresult CUDAAPI cuWaitExternalSemaphoresAsync( + const CUexternalSemaphore *extSemArray, + const CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS *paramsArray, + unsigned int numExtSems, CUstream stream) { + using FuncPtr = CUresult(CUDAAPI *)( + const CUexternalSemaphore *, const CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS *, + unsigned int, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuWaitExternalSemaphoresAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSemArray, paramsArray, numExtSems, stream); +} + +CUresult CUDAAPI cuDestroyExternalSemaphore(CUexternalSemaphore extSem) { + using FuncPtr = CUresult(CUDAAPI *)(CUexternalSemaphore); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDestroyExternalSemaphore"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSem); +} + +CUresult CUDAAPI cuStreamWaitValue32(CUstream stream, CUdeviceptr addr, + cuuint32_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, cuuint32_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamWaitValue32"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamWaitValue64(CUstream stream, CUdeviceptr addr, + cuuint64_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, cuuint64_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamWaitValue64"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamWriteValue32(CUstream stream, CUdeviceptr addr, + cuuint32_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, cuuint32_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamWriteValue32"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamWriteValue64(CUstream stream, CUdeviceptr addr, + cuuint64_t value, unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUstream, CUdeviceptr, cuuint64_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamWriteValue64"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, addr, value, flags); +} + +CUresult CUDAAPI cuStreamBatchMemOp(CUstream stream, unsigned int count, + CUstreamBatchMemOpParams *paramArray, + unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, unsigned int, + CUstreamBatchMemOpParams *, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuStreamBatchMemOp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, count, paramArray, flags); +} + +CUresult CUDAAPI cuFuncGetAttribute(int *pi, CUfunction_attribute attrib, + CUfunction hfunc) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUfunction_attribute, CUfunction); + static auto func_ptr = LoadSymbol<FuncPtr>("cuFuncGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pi, attrib, hfunc); +} + +CUresult CUDAAPI cuFuncSetAttribute(CUfunction hfunc, + CUfunction_attribute attrib, int value) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, CUfunction_attribute, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuFuncSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, attrib, value); +} + +CUresult CUDAAPI cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, CUfunc_cache); + static auto func_ptr = LoadSymbol<FuncPtr>("cuFuncSetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, config); +} + +CUresult CUDAAPI cuFuncSetSharedMemConfig(CUfunction hfunc, + CUsharedconfig config) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, CUsharedconfig); + static auto func_ptr = LoadSymbol<FuncPtr>("cuFuncSetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, config); +} + +CUresult CUDAAPI cuLaunchKernel(CUfunction f, unsigned int gridDimX, + unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, + void **kernelParams, void **extra) { + using FuncPtr = CUresult(CUDAAPI *)( + CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, CUstream, void **, void **); + static auto func_ptr = LoadSymbol<FuncPtr>("cuLaunchKernel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, + blockDimZ, sharedMemBytes, hStream, kernelParams, extra); +} + +CUresult CUDAAPI cuLaunchCooperativeKernel( + CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, + unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, + void **kernelParams) { + using FuncPtr = CUresult(CUDAAPI *)( + CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, CUstream, void **); + static auto func_ptr = LoadSymbol<FuncPtr>("cuLaunchCooperativeKernel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, + blockDimZ, sharedMemBytes, hStream, kernelParams); +} + +CUresult CUDAAPI cuLaunchCooperativeKernelMultiDevice( + CUDA_LAUNCH_PARAMS *launchParamsList, unsigned int numDevices, + unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(CUDA_LAUNCH_PARAMS *, unsigned int, unsigned int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cuLaunchCooperativeKernelMultiDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(launchParamsList, numDevices, flags); +} + +CUresult CUDAAPI cuLaunchHostFunc(CUstream hStream, CUhostFn fn, + void *userData) { + using FuncPtr = CUresult(CUDAAPI *)(CUstream, CUhostFn, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuLaunchHostFunc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, fn, userData); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuFuncSetBlockShape(CUfunction hfunc, int x, + int y, int z) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, int, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuFuncSetBlockShape"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, x, y, z); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuFuncSetSharedSize(CUfunction hfunc, + unsigned int bytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuFuncSetSharedSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, bytes); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetSize(CUfunction hfunc, + unsigned int numbytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuParamSetSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, numbytes); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSeti(CUfunction hfunc, int offset, + unsigned int value) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuParamSeti"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, offset, value); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetf(CUfunction hfunc, int offset, + float value) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, float); + static auto func_ptr = LoadSymbol<FuncPtr>("cuParamSetf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, offset, value); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetv(CUfunction hfunc, int offset, + void *ptr, + unsigned int numbytes) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, void *, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuParamSetv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, offset, ptr, numbytes); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunch(CUfunction f) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction); + static auto func_ptr = LoadSymbol<FuncPtr>("cuLaunch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunchGrid(CUfunction f, int grid_width, + int grid_height) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuLaunchGrid"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, grid_width, grid_height); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunchGridAsync(CUfunction f, + int grid_width, + int grid_height, + CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, int, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuLaunchGridAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(f, grid_width, grid_height, hStream); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetTexRef(CUfunction hfunc, + int texunit, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUfunction, int, CUtexref); + static auto func_ptr = LoadSymbol<FuncPtr>("cuParamSetTexRef"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hfunc, texunit, hTexRef); +} + +CUresult CUDAAPI cuGraphCreate(CUgraph *phGraph, unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph *, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraph, flags); +} + +CUresult CUDAAPI cuGraphAddKernelNode( + CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, + size_t numDependencies, const CUDA_KERNEL_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, const CUgraphNode *, size_t, + const CUDA_KERNEL_NODE_PARAMS *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphAddKernelNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + nodeParams); +} + +CUresult CUDAAPI cuGraphKernelNodeGetParams( + CUgraphNode hNode, CUDA_KERNEL_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUDA_KERNEL_NODE_PARAMS *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphKernelNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphKernelNodeSetParams( + CUgraphNode hNode, const CUDA_KERNEL_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode, const CUDA_KERNEL_NODE_PARAMS *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphKernelNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphAddMemcpyNode(CUgraphNode *phGraphNode, CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + const CUDA_MEMCPY3D *copyParams, + CUcontext ctx) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, const CUgraphNode *, size_t, + const CUDA_MEMCPY3D *, CUcontext); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphAddMemcpyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + copyParams, ctx); +} + +CUresult CUDAAPI cuGraphMemcpyNodeGetParams(CUgraphNode hNode, + CUDA_MEMCPY3D *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUDA_MEMCPY3D *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphMemcpyNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphMemcpyNodeSetParams(CUgraphNode hNode, + const CUDA_MEMCPY3D *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, const CUDA_MEMCPY3D *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphMemcpyNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphAddMemsetNode( + CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, + size_t numDependencies, const CUDA_MEMSET_NODE_PARAMS *memsetParams, + CUcontext ctx) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, const CUgraphNode *, size_t, + const CUDA_MEMSET_NODE_PARAMS *, CUcontext); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphAddMemsetNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + memsetParams, ctx); +} + +CUresult CUDAAPI cuGraphMemsetNodeGetParams( + CUgraphNode hNode, CUDA_MEMSET_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUDA_MEMSET_NODE_PARAMS *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphMemsetNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphMemsetNodeSetParams( + CUgraphNode hNode, const CUDA_MEMSET_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode, const CUDA_MEMSET_NODE_PARAMS *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphMemsetNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphAddHostNode(CUgraphNode *phGraphNode, CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + const CUDA_HOST_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, const CUgraphNode *, size_t, + const CUDA_HOST_NODE_PARAMS *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphAddHostNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + nodeParams); +} + +CUresult CUDAAPI cuGraphHostNodeGetParams(CUgraphNode hNode, + CUDA_HOST_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUDA_HOST_NODE_PARAMS *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphHostNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphHostNodeSetParams( + CUgraphNode hNode, const CUDA_HOST_NODE_PARAMS *nodeParams) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode, const CUDA_HOST_NODE_PARAMS *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphHostNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphAddChildGraphNode(CUgraphNode *phGraphNode, + CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies, + CUgraph childGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, + const CUgraphNode *, size_t, CUgraph); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphAddChildGraphNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies, + childGraph); +} + +CUresult CUDAAPI cuGraphChildGraphNodeGetGraph(CUgraphNode hNode, + CUgraph *phGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraph *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphChildGraphNodeGetGraph"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, phGraph); +} + +CUresult CUDAAPI cuGraphAddEmptyNode(CUgraphNode *phGraphNode, CUgraph hGraph, + const CUgraphNode *dependencies, + size_t numDependencies) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraphNode *, CUgraph, const CUgraphNode *, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphAddEmptyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphNode, hGraph, dependencies, numDependencies); +} + +CUresult CUDAAPI cuGraphClone(CUgraph *phGraphClone, CUgraph originalGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph *, CUgraph); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphClone"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphClone, originalGraph); +} + +CUresult CUDAAPI cuGraphNodeFindInClone(CUgraphNode *phNode, + CUgraphNode hOriginalNode, + CUgraph hClonedGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode *, CUgraphNode, CUgraph); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphNodeFindInClone"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phNode, hOriginalNode, hClonedGraph); +} + +CUresult CUDAAPI cuGraphNodeGetType(CUgraphNode hNode, CUgraphNodeType *type) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraphNodeType *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphNodeGetType"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, type); +} + +CUresult CUDAAPI cuGraphGetNodes(CUgraph hGraph, CUgraphNode *nodes, + size_t *numNodes) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphGetNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, nodes, numNodes); +} + +CUresult CUDAAPI cuGraphGetRootNodes(CUgraph hGraph, CUgraphNode *rootNodes, + size_t *numRootNodes) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphGetRootNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, rootNodes, numRootNodes); +} + +CUresult CUDAAPI cuGraphGetEdges(CUgraph hGraph, CUgraphNode *from, + CUgraphNode *to, size_t *numEdges) { + using FuncPtr = + CUresult(CUDAAPI *)(CUgraph, CUgraphNode *, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphGetEdges"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, from, to, numEdges); +} + +CUresult CUDAAPI cuGraphNodeGetDependencies(CUgraphNode hNode, + CUgraphNode *dependencies, + size_t *numDependencies) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphNodeGetDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, dependencies, numDependencies); +} + +CUresult CUDAAPI cuGraphNodeGetDependentNodes(CUgraphNode hNode, + CUgraphNode *dependentNodes, + size_t *numDependentNodes) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraphNode *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphNodeGetDependentNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, dependentNodes, numDependentNodes); +} + +CUresult CUDAAPI cuGraphAddDependencies(CUgraph hGraph, const CUgraphNode *from, + const CUgraphNode *to, + size_t numDependencies) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph, const CUgraphNode *, + const CUgraphNode *, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphAddDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, from, to, numDependencies); +} + +CUresult CUDAAPI cuGraphRemoveDependencies(CUgraph hGraph, + const CUgraphNode *from, + const CUgraphNode *to, + size_t numDependencies) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph, const CUgraphNode *, + const CUgraphNode *, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphRemoveDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph, from, to, numDependencies); +} + +CUresult CUDAAPI cuGraphDestroyNode(CUgraphNode hNode) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphDestroyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode); +} + +CUresult CUDAAPI cuGraphInstantiate(CUgraphExec *phGraphExec, CUgraph hGraph, + CUgraphNode *phErrorNode, char *logBuffer, + size_t bufferSize) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec *, CUgraph, CUgraphNode *, + char *, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphInstantiate_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phGraphExec, hGraph, phErrorNode, logBuffer, bufferSize); +} + +CUresult CUDAAPI +cuGraphExecKernelNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, + const CUDA_KERNEL_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec, CUgraphNode, + const CUDA_KERNEL_NODE_PARAMS *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphExecKernelNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphExecMemcpyNodeSetParams(CUgraphExec hGraphExec, + CUgraphNode hNode, + const CUDA_MEMCPY3D *copyParams, + CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec, CUgraphNode, + const CUDA_MEMCPY3D *, CUcontext); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphExecMemcpyNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hNode, copyParams, ctx); +} + +CUresult CUDAAPI cuGraphExecMemsetNodeSetParams( + CUgraphExec hGraphExec, CUgraphNode hNode, + const CUDA_MEMSET_NODE_PARAMS *memsetParams, CUcontext ctx) { + using FuncPtr = CUresult(CUDAAPI *)( + CUgraphExec, CUgraphNode, const CUDA_MEMSET_NODE_PARAMS *, CUcontext); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphExecMemsetNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hNode, memsetParams, ctx); +} + +CUresult CUDAAPI +cuGraphExecHostNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, + const CUDA_HOST_NODE_PARAMS *nodeParams) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec, CUgraphNode, + const CUDA_HOST_NODE_PARAMS *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphExecHostNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hNode, nodeParams); +} + +CUresult CUDAAPI cuGraphLaunch(CUgraphExec hGraphExec, CUstream hStream) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphLaunch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hStream); +} + +CUresult CUDAAPI cuGraphExecDestroy(CUgraphExec hGraphExec) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphExecDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec); +} + +CUresult CUDAAPI cuGraphDestroy(CUgraph hGraph) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraph); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraph); +} + +CUresult CUDAAPI cuGraphExecUpdate(CUgraphExec hGraphExec, CUgraph hGraph, + CUgraphNode *hErrorNode_out, + CUgraphExecUpdateResult *updateResult_out) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphExec, CUgraph, CUgraphNode *, + CUgraphExecUpdateResult *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphExecUpdate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hGraph, hErrorNode_out, updateResult_out); +} + +CUresult CUDAAPI cuGraphKernelNodeCopyAttributes(CUgraphNode dst, + CUgraphNode src) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUgraphNode); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphKernelNodeCopyAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src); +} + +CUresult CUDAAPI +cuGraphKernelNodeGetAttribute(CUgraphNode hNode, CUkernelNodeAttrID attr, + CUkernelNodeAttrValue *value_out) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUkernelNodeAttrID, + CUkernelNodeAttrValue *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphKernelNodeGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, attr, value_out); +} + +CUresult CUDAAPI +cuGraphKernelNodeSetAttribute(CUgraphNode hNode, CUkernelNodeAttrID attr, + const CUkernelNodeAttrValue *value) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphNode, CUkernelNodeAttrID, + const CUkernelNodeAttrValue *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphKernelNodeSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, attr, value); +} + +CUresult CUDAAPI cuOccupancyMaxActiveBlocksPerMultiprocessor( + int *numBlocks, CUfunction func, int blockSize, size_t dynamicSMemSize) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUfunction, int, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cuOccupancyMaxActiveBlocksPerMultiprocessor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numBlocks, func, blockSize, dynamicSMemSize); +} + +CUresult CUDAAPI cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + int *numBlocks, CUfunction func, int blockSize, size_t dynamicSMemSize, + unsigned int flags) { + using FuncPtr = + CUresult(CUDAAPI *)(int *, CUfunction, int, size_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>( + "cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numBlocks, func, blockSize, dynamicSMemSize, flags); +} + +CUresult CUDAAPI cuOccupancyMaxPotentialBlockSize( + int *minGridSize, int *blockSize, CUfunction func, + CUoccupancyB2DSize blockSizeToDynamicSMemSize, size_t dynamicSMemSize, + int blockSizeLimit) { + using FuncPtr = CUresult(CUDAAPI *)(int *, int *, CUfunction, + CUoccupancyB2DSize, size_t, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cuOccupancyMaxPotentialBlockSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(minGridSize, blockSize, func, blockSizeToDynamicSMemSize, + dynamicSMemSize, blockSizeLimit); +} + +CUresult CUDAAPI cuOccupancyMaxPotentialBlockSizeWithFlags( + int *minGridSize, int *blockSize, CUfunction func, + CUoccupancyB2DSize blockSizeToDynamicSMemSize, size_t dynamicSMemSize, + int blockSizeLimit, unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)( + int *, int *, CUfunction, CUoccupancyB2DSize, size_t, int, unsigned int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cuOccupancyMaxPotentialBlockSizeWithFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(minGridSize, blockSize, func, blockSizeToDynamicSMemSize, + dynamicSMemSize, blockSizeLimit, flags); +} + +CUresult CUDAAPI cuOccupancyAvailableDynamicSMemPerBlock( + size_t *dynamicSmemSize, CUfunction func, int numBlocks, int blockSize) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, CUfunction, int, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cuOccupancyAvailableDynamicSMemPerBlock"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dynamicSmemSize, func, numBlocks, blockSize); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetArray(CUtexref hTexRef, + CUarray hArray, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUarray, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefSetArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, hArray, Flags); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmappedArray( + CUtexref hTexRef, CUmipmappedArray hMipmappedArray, unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUmipmappedArray, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefSetMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, hMipmappedArray, Flags); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddress(size_t *ByteOffset, + CUtexref hTexRef, + CUdeviceptr dptr, + size_t bytes) { + using FuncPtr = CUresult(CUDAAPI *)(size_t *, CUtexref, CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefSetAddress_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ByteOffset, hTexRef, dptr, bytes); +} + +__CUDA_DEPRECATED CUresult CUDAAPI +cuTexRefSetAddress2D(CUtexref hTexRef, const CUDA_ARRAY_DESCRIPTOR *desc, + CUdeviceptr dptr, size_t Pitch) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, const CUDA_ARRAY_DESCRIPTOR *, + CUdeviceptr, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefSetAddress2D_v3"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, desc, dptr, Pitch); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFormat(CUtexref hTexRef, + CUarray_format fmt, + int NumPackedComponents) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUarray_format, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefSetFormat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, fmt, NumPackedComponents); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddressMode(CUtexref hTexRef, + int dim, + CUaddress_mode am) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, int, CUaddress_mode); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefSetAddressMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, dim, am); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFilterMode(CUtexref hTexRef, + CUfilter_mode fm) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUfilter_mode); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefSetFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, fm); +} + +__CUDA_DEPRECATED CUresult CUDAAPI +cuTexRefSetMipmapFilterMode(CUtexref hTexRef, CUfilter_mode fm) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, CUfilter_mode); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefSetMipmapFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, fm); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapLevelBias(CUtexref hTexRef, + float bias) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, float); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefSetMipmapLevelBias"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, bias); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapLevelClamp( + CUtexref hTexRef, float minMipmapLevelClamp, float maxMipmapLevelClamp) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, float, float); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefSetMipmapLevelClamp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, minMipmapLevelClamp, maxMipmapLevelClamp); +} + +__CUDA_DEPRECATED CUresult CUDAAPI +cuTexRefSetMaxAnisotropy(CUtexref hTexRef, unsigned int maxAniso) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefSetMaxAnisotropy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, maxAniso); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetBorderColor(CUtexref hTexRef, + float *pBorderColor) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefSetBorderColor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, pBorderColor); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFlags(CUtexref hTexRef, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefSetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef, Flags); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetAddress(CUdeviceptr *pdptr, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUdeviceptr *, CUtexref); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefGetAddress_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pdptr, hTexRef); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetArray(CUarray *phArray, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, CUtexref); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefGetArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phArray, hTexRef); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetMipmappedArray( + CUmipmappedArray *phMipmappedArray, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUmipmappedArray *, CUtexref); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefGetMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phMipmappedArray, hTexRef); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetAddressMode(CUaddress_mode *pam, + CUtexref hTexRef, + int dim) { + using FuncPtr = CUresult(CUDAAPI *)(CUaddress_mode *, CUtexref, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefGetAddressMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pam, hTexRef, dim); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetFilterMode(CUfilter_mode *pfm, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUfilter_mode *, CUtexref); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefGetFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pfm, hTexRef); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetFormat(CUarray_format *pFormat, + int *pNumChannels, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray_format *, int *, CUtexref); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefGetFormat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pFormat, pNumChannels, hTexRef); +} + +__CUDA_DEPRECATED CUresult CUDAAPI +cuTexRefGetMipmapFilterMode(CUfilter_mode *pfm, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUfilter_mode *, CUtexref); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefGetMipmapFilterMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pfm, hTexRef); +} + +__CUDA_DEPRECATED CUresult CUDAAPI +cuTexRefGetMipmapLevelBias(float *pbias, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(float *, CUtexref); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefGetMipmapLevelBias"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pbias, hTexRef); +} + +__CUDA_DEPRECATED CUresult CUDAAPI +cuTexRefGetMipmapLevelClamp(float *pminMipmapLevelClamp, + float *pmaxMipmapLevelClamp, CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(float *, float *, CUtexref); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefGetMipmapLevelClamp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pminMipmapLevelClamp, pmaxMipmapLevelClamp, hTexRef); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetMaxAnisotropy(int *pmaxAniso, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUtexref); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefGetMaxAnisotropy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pmaxAniso, hTexRef); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetBorderColor(float *pBorderColor, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(float *, CUtexref); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefGetBorderColor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pBorderColor, hTexRef); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetFlags(unsigned int *pFlags, + CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(unsigned int *, CUtexref); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pFlags, hTexRef); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefCreate(CUtexref *pTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexRef); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefDestroy(CUtexref hTexRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexref); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexRefDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hTexRef); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuSurfRefSetArray(CUsurfref hSurfRef, + CUarray hArray, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUsurfref, CUarray, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuSurfRefSetArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hSurfRef, hArray, Flags); +} + +__CUDA_DEPRECATED CUresult CUDAAPI cuSurfRefGetArray(CUarray *phArray, + CUsurfref hSurfRef) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, CUsurfref); + static auto func_ptr = LoadSymbol<FuncPtr>("cuSurfRefGetArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(phArray, hSurfRef); +} + +CUresult CUDAAPI +cuTexObjectCreate(CUtexObject *pTexObject, const CUDA_RESOURCE_DESC *pResDesc, + const CUDA_TEXTURE_DESC *pTexDesc, + const CUDA_RESOURCE_VIEW_DESC *pResViewDesc) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexObject *, const CUDA_RESOURCE_DESC *, + const CUDA_TEXTURE_DESC *, + const CUDA_RESOURCE_VIEW_DESC *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexObjectCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexObject, pResDesc, pTexDesc, pResViewDesc); +} + +CUresult CUDAAPI cuTexObjectDestroy(CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUtexObject); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexObjectDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(texObject); +} + +CUresult CUDAAPI cuTexObjectGetResourceDesc(CUDA_RESOURCE_DESC *pResDesc, + CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_RESOURCE_DESC *, CUtexObject); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexObjectGetResourceDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResDesc, texObject); +} + +CUresult CUDAAPI cuTexObjectGetTextureDesc(CUDA_TEXTURE_DESC *pTexDesc, + CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_TEXTURE_DESC *, CUtexObject); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexObjectGetTextureDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexDesc, texObject); +} + +CUresult CUDAAPI cuTexObjectGetResourceViewDesc( + CUDA_RESOURCE_VIEW_DESC *pResViewDesc, CUtexObject texObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_RESOURCE_VIEW_DESC *, CUtexObject); + static auto func_ptr = LoadSymbol<FuncPtr>("cuTexObjectGetResourceViewDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResViewDesc, texObject); +} + +CUresult CUDAAPI cuSurfObjectCreate(CUsurfObject *pSurfObject, + const CUDA_RESOURCE_DESC *pResDesc) { + using FuncPtr = + CUresult(CUDAAPI *)(CUsurfObject *, const CUDA_RESOURCE_DESC *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuSurfObjectCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pSurfObject, pResDesc); +} + +CUresult CUDAAPI cuSurfObjectDestroy(CUsurfObject surfObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUsurfObject); + static auto func_ptr = LoadSymbol<FuncPtr>("cuSurfObjectDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(surfObject); +} + +CUresult CUDAAPI cuSurfObjectGetResourceDesc(CUDA_RESOURCE_DESC *pResDesc, + CUsurfObject surfObject) { + using FuncPtr = CUresult(CUDAAPI *)(CUDA_RESOURCE_DESC *, CUsurfObject); + static auto func_ptr = LoadSymbol<FuncPtr>("cuSurfObjectGetResourceDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResDesc, surfObject); +} + +CUresult CUDAAPI cuDeviceCanAccessPeer(int *canAccessPeer, CUdevice dev, + CUdevice peerDev) { + using FuncPtr = CUresult(CUDAAPI *)(int *, CUdevice, CUdevice); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceCanAccessPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(canAccessPeer, dev, peerDev); +} + +CUresult CUDAAPI cuCtxEnablePeerAccess(CUcontext peerContext, + unsigned int Flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxEnablePeerAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(peerContext, Flags); +} + +CUresult CUDAAPI cuCtxDisablePeerAccess(CUcontext peerContext) { + using FuncPtr = CUresult(CUDAAPI *)(CUcontext); + static auto func_ptr = LoadSymbol<FuncPtr>("cuCtxDisablePeerAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(peerContext); +} + +CUresult CUDAAPI cuDeviceGetP2PAttribute(int *value, + CUdevice_P2PAttribute attrib, + CUdevice srcDevice, + CUdevice dstDevice) { + using FuncPtr = + CUresult(CUDAAPI *)(int *, CUdevice_P2PAttribute, CUdevice, CUdevice); + static auto func_ptr = LoadSymbol<FuncPtr>("cuDeviceGetP2PAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(value, attrib, srcDevice, dstDevice); +} + +CUresult CUDAAPI cuGraphicsUnregisterResource(CUgraphicsResource resource) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphicsResource); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphicsUnregisterResource"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(resource); +} + +CUresult CUDAAPI cuGraphicsSubResourceGetMappedArray( + CUarray *pArray, CUgraphicsResource resource, unsigned int arrayIndex, + unsigned int mipLevel) { + using FuncPtr = CUresult(CUDAAPI *)(CUarray *, CUgraphicsResource, + unsigned int, unsigned int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cuGraphicsSubResourceGetMappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pArray, resource, arrayIndex, mipLevel); +} + +CUresult CUDAAPI cuGraphicsResourceGetMappedMipmappedArray( + CUmipmappedArray *pMipmappedArray, CUgraphicsResource resource) { + using FuncPtr = CUresult(CUDAAPI *)(CUmipmappedArray *, CUgraphicsResource); + static auto func_ptr = + LoadSymbol<FuncPtr>("cuGraphicsResourceGetMappedMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pMipmappedArray, resource); +} + +CUresult CUDAAPI cuGraphicsResourceGetMappedPointer( + CUdeviceptr *pDevPtr, size_t *pSize, CUgraphicsResource resource) { + using FuncPtr = + CUresult(CUDAAPI *)(CUdeviceptr *, size_t *, CUgraphicsResource); + static auto func_ptr = + LoadSymbol<FuncPtr>("cuGraphicsResourceGetMappedPointer_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pDevPtr, pSize, resource); +} + +CUresult CUDAAPI cuGraphicsResourceSetMapFlags(CUgraphicsResource resource, + unsigned int flags) { + using FuncPtr = CUresult(CUDAAPI *)(CUgraphicsResource, unsigned int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cuGraphicsResourceSetMapFlags_v2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(resource, flags); +} + +CUresult CUDAAPI cuGraphicsMapResources(unsigned int count, + CUgraphicsResource *resources, + CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(unsigned int, CUgraphicsResource *, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphicsMapResources"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(count, resources, hStream); +} + +CUresult CUDAAPI cuGraphicsUnmapResources(unsigned int count, + CUgraphicsResource *resources, + CUstream hStream) { + using FuncPtr = + CUresult(CUDAAPI *)(unsigned int, CUgraphicsResource *, CUstream); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGraphicsUnmapResources"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(count, resources, hStream); +} + +CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, + const CUuuid *pExportTableId) { + using FuncPtr = CUresult(CUDAAPI *)(const void **, const CUuuid *); + static auto func_ptr = LoadSymbol<FuncPtr>("cuGetExportTable"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ppExportTable, pExportTableId); +} + +CUresult CUDAAPI cuFuncGetModule(CUmodule *hmod, CUfunction hfunc) { + using FuncPtr = CUresult(CUDAAPI *)(CUmodule *, CUfunction); + static auto func_ptr = LoadSymbol<FuncPtr>("cuFuncGetModule"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hmod, hfunc); +} + +} // extern "C" diff --git a/tensorflow/stream_executor/cuda/cuda_runtime_11_0.inc b/tensorflow/stream_executor/cuda/cuda_runtime_11_0.inc new file mode 100644 index 00000000000..df3ada219e2 --- /dev/null +++ b/tensorflow/stream_executor/cuda/cuda_runtime_11_0.inc @@ -0,0 +1,1974 @@ +// Auto-generated, do not edit. + +extern "C" { + +extern __host__ cudaError_t CUDARTAPI cudaDeviceReset(void) { + using FuncPtr = cudaError_t(CUDARTAPI *)(); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDeviceReset"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaDeviceSynchronize(void) { + using FuncPtr = cudaError_t(CUDARTAPI *)(); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDeviceSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +extern __host__ cudaError_t CUDARTAPI cudaDeviceSetLimit(enum cudaLimit limit, + size_t value) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaLimit, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDeviceSetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(limit, value); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaDeviceGetLimit(size_t *pValue, enum cudaLimit limit) { + using FuncPtr = cudaError_t(CUDARTAPI *)(size_t *, enum cudaLimit); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDeviceGetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pValue, limit); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaDeviceGetCacheConfig(enum cudaFuncCache *pCacheConfig) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaFuncCache *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDeviceGetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCacheConfig); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaDeviceGetStreamPriorityRange(int *leastPriority, int *greatestPriority) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaDeviceGetStreamPriorityRange"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(leastPriority, greatestPriority); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDeviceSetCacheConfig(enum cudaFuncCache cacheConfig) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaFuncCache); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDeviceSetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(cacheConfig); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaDeviceGetSharedMemConfig(enum cudaSharedMemConfig *pConfig) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaSharedMemConfig *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDeviceGetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pConfig); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDeviceSetSharedMemConfig(enum cudaSharedMemConfig config) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaSharedMemConfig); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDeviceSetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(config); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDeviceGetByPCIBusId(int *device, const char *pciBusId) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *, const char *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDeviceGetByPCIBusId"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device, pciBusId); +} + +extern __host__ cudaError_t CUDARTAPI cudaDeviceGetPCIBusId(char *pciBusId, + int len, + int device) { + using FuncPtr = cudaError_t(CUDARTAPI *)(char *, int, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDeviceGetPCIBusId"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pciBusId, len, device); +} + +extern __host__ cudaError_t CUDARTAPI +cudaIpcGetEventHandle(cudaIpcEventHandle_t *handle, cudaEvent_t event) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaIpcEventHandle_t *, cudaEvent_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaIpcGetEventHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, event); +} + +extern __host__ cudaError_t CUDARTAPI +cudaIpcOpenEventHandle(cudaEvent_t *event, cudaIpcEventHandle_t handle) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaEvent_t *, cudaIpcEventHandle_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaIpcOpenEventHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(event, handle); +} + +extern __host__ cudaError_t CUDARTAPI +cudaIpcGetMemHandle(cudaIpcMemHandle_t *handle, void *devPtr) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaIpcMemHandle_t *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaIpcGetMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, devPtr); +} + +extern __host__ cudaError_t CUDARTAPI cudaIpcOpenMemHandle( + void **devPtr, cudaIpcMemHandle_t handle, unsigned int flags) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void **, cudaIpcMemHandle_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaIpcOpenMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, handle, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaIpcCloseMemHandle(void *devPtr) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaIpcCloseMemHandle"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI cudaThreadExit(void) { + using FuncPtr = cudaError_t(CUDARTAPI *)(); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaThreadExit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaThreadSynchronize(void) { + using FuncPtr = cudaError_t(CUDARTAPI *)(); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaThreadSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaThreadSetLimit(enum cudaLimit limit, size_t value) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaLimit, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaThreadSetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(limit, value); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaThreadGetLimit(size_t *pValue, enum cudaLimit limit) { + using FuncPtr = cudaError_t(CUDARTAPI *)(size_t *, enum cudaLimit); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaThreadGetLimit"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pValue, limit); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaThreadGetCacheConfig(enum cudaFuncCache *pCacheConfig) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaFuncCache *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaThreadGetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pCacheConfig); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaThreadSetCacheConfig(enum cudaFuncCache cacheConfig) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaFuncCache); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaThreadSetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(cacheConfig); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaGetLastError(void) { + using FuncPtr = cudaError_t(CUDARTAPI *)(); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetLastError"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaPeekAtLastError(void) { + using FuncPtr = cudaError_t(CUDARTAPI *)(); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaPeekAtLastError"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +extern __host__ __cudart_builtin__ const char *CUDARTAPI +cudaGetErrorName(cudaError_t error) { + using FuncPtr = const char *(CUDARTAPI *)(cudaError_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetErrorName"); + if (!func_ptr) return "cudaGetErrorName symbol not found."; + return func_ptr(error); +} + +extern __host__ __cudart_builtin__ const char *CUDARTAPI +cudaGetErrorString(cudaError_t error) { + using FuncPtr = const char *(CUDARTAPI *)(cudaError_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetErrorString"); + if (!func_ptr) return "cudaGetErrorString symbol not found."; + return func_ptr(error); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaGetDeviceCount(int *count) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetDeviceCount"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(count); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaGetDeviceProperties(struct cudaDeviceProp *prop, int device) { + using FuncPtr = cudaError_t(CUDARTAPI *)(struct cudaDeviceProp *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetDeviceProperties"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(prop, device); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaDeviceGetAttribute(int *value, enum cudaDeviceAttr attr, int device) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *, enum cudaDeviceAttr, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDeviceGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(value, attr, device); +} + +extern __host__ cudaError_t CUDARTAPI cudaDeviceGetNvSciSyncAttributes( + void *nvSciSyncAttrList, int device, int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, int, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaDeviceGetNvSciSyncAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(nvSciSyncAttrList, device, flags); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaDeviceGetP2PAttribute(int *value, enum cudaDeviceP2PAttr attr, + int srcDevice, int dstDevice) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(int *, enum cudaDeviceP2PAttr, int, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDeviceGetP2PAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(value, attr, srcDevice, dstDevice); +} + +extern __host__ cudaError_t CUDARTAPI +cudaChooseDevice(int *device, const struct cudaDeviceProp *prop) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(int *, const struct cudaDeviceProp *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaChooseDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device, prop); +} + +extern __host__ cudaError_t CUDARTAPI cudaSetDevice(int device) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaSetDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaGetDevice(int *device) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device); +} + +extern __host__ cudaError_t CUDARTAPI cudaSetValidDevices(int *device_arr, + int len) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaSetValidDevices"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(device_arr, len); +} + +extern __host__ cudaError_t CUDARTAPI cudaSetDeviceFlags(unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaSetDeviceFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetDeviceFlags(unsigned int *flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(unsigned int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetDeviceFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaStreamCreate(cudaStream_t *pStream) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pStream); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaStreamCreateWithFlags(cudaStream_t *pStream, unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t *, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamCreateWithFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pStream, flags); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaStreamCreateWithPriority(cudaStream_t *pStream, unsigned int flags, + int priority) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t *, unsigned int, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamCreateWithPriority"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pStream, flags, priority); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaStreamGetPriority(cudaStream_t hStream, int *priority) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamGetPriority"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, priority); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaStreamGetFlags(cudaStream_t hStream, unsigned int *flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t, unsigned int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaCtxResetPersistingL2Cache(void) { + using FuncPtr = cudaError_t(CUDARTAPI *)(); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaCtxResetPersistingL2Cache"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaStreamCopyAttributes(cudaStream_t dst, cudaStream_t src) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamCopyAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaStreamGetAttribute(cudaStream_t hStream, enum cudaStreamAttrID attr, + union cudaStreamAttrValue *value_out) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t, enum cudaStreamAttrID, + union cudaStreamAttrValue *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, attr, value_out); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaStreamSetAttribute(cudaStream_t hStream, enum cudaStreamAttrID attr, + const union cudaStreamAttrValue *value) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t, enum cudaStreamAttrID, + const union cudaStreamAttrValue *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hStream, attr, value); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaStreamDestroy(cudaStream_t stream) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI cudaStreamWaitEvent( + cudaStream_t stream, cudaEvent_t event, unsigned int flags) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaStream_t, cudaEvent_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamWaitEvent"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, event, flags); +} + +extern __host__ cudaError_t CUDARTAPI +cudaStreamAddCallback(cudaStream_t stream, cudaStreamCallback_t callback, + void *userData, unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t, cudaStreamCallback_t, + void *, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamAddCallback"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, callback, userData, flags); +} + +extern __host__ cudaError_t CUDARTAPI +cudaStreamSynchronize(cudaStream_t stream) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaStreamQuery(cudaStream_t stream) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamQuery"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaStreamAttachMemAsync(cudaStream_t stream, void *devPtr, + size_t length __dv(0), + unsigned int flags __dv(cudaMemAttachSingle)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaStream_t, void *, size_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamAttachMemAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, devPtr, length, flags); +} + +extern __host__ cudaError_t CUDARTAPI +cudaStreamBeginCapture(cudaStream_t stream, enum cudaStreamCaptureMode mode) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaStream_t, enum cudaStreamCaptureMode); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamBeginCapture"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, mode); +} + +extern __host__ cudaError_t CUDARTAPI +cudaThreadExchangeStreamCaptureMode(enum cudaStreamCaptureMode *mode) { + using FuncPtr = cudaError_t(CUDARTAPI *)(enum cudaStreamCaptureMode *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaThreadExchangeStreamCaptureMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mode); +} + +extern __host__ cudaError_t CUDARTAPI +cudaStreamEndCapture(cudaStream_t stream, cudaGraph_t *pGraph) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t, cudaGraph_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamEndCapture"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, pGraph); +} + +extern __host__ cudaError_t CUDARTAPI cudaStreamIsCapturing( + cudaStream_t stream, enum cudaStreamCaptureStatus *pCaptureStatus) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaStream_t, enum cudaStreamCaptureStatus *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamIsCapturing"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, pCaptureStatus); +} + +extern __host__ cudaError_t CUDARTAPI cudaStreamGetCaptureInfo( + cudaStream_t stream, enum cudaStreamCaptureStatus *pCaptureStatus, + unsigned long long *pId) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaStream_t, enum cudaStreamCaptureStatus *, unsigned long long *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamGetCaptureInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, pCaptureStatus, pId); +} + +extern __host__ cudaError_t CUDARTAPI cudaEventCreate(cudaEvent_t *event) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaEvent_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaEventCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(event); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaEventCreateWithFlags(cudaEvent_t *event, unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaEvent_t *, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaEventCreateWithFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(event, flags); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaEventRecord(cudaEvent_t event, cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaEvent_t, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaEventRecord"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(event, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaEventQuery(cudaEvent_t event) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaEvent_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaEventQuery"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(event); +} + +extern __host__ cudaError_t CUDARTAPI cudaEventSynchronize(cudaEvent_t event) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaEvent_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaEventSynchronize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(event); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaEventDestroy(cudaEvent_t event) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaEvent_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaEventDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(event); +} + +extern __host__ cudaError_t CUDARTAPI cudaEventElapsedTime(float *ms, + cudaEvent_t start, + cudaEvent_t end) { + using FuncPtr = cudaError_t(CUDARTAPI *)(float *, cudaEvent_t, cudaEvent_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaEventElapsedTime"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ms, start, end); +} + +extern __host__ cudaError_t CUDARTAPI cudaImportExternalMemory( + cudaExternalMemory_t *extMem_out, + const struct cudaExternalMemoryHandleDesc *memHandleDesc) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaExternalMemory_t *, const struct cudaExternalMemoryHandleDesc *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaImportExternalMemory"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extMem_out, memHandleDesc); +} + +extern __host__ cudaError_t CUDARTAPI cudaExternalMemoryGetMappedBuffer( + void **devPtr, cudaExternalMemory_t extMem, + const struct cudaExternalMemoryBufferDesc *bufferDesc) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void **, cudaExternalMemory_t, + const struct cudaExternalMemoryBufferDesc *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaExternalMemoryGetMappedBuffer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, extMem, bufferDesc); +} + +extern __host__ cudaError_t CUDARTAPI cudaExternalMemoryGetMappedMipmappedArray( + cudaMipmappedArray_t *mipmap, cudaExternalMemory_t extMem, + const struct cudaExternalMemoryMipmappedArrayDesc *mipmapDesc) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaMipmappedArray_t *, cudaExternalMemory_t, + const struct cudaExternalMemoryMipmappedArrayDesc *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaExternalMemoryGetMappedMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mipmap, extMem, mipmapDesc); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDestroyExternalMemory(cudaExternalMemory_t extMem) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaExternalMemory_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDestroyExternalMemory"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extMem); +} + +extern __host__ cudaError_t CUDARTAPI cudaImportExternalSemaphore( + cudaExternalSemaphore_t *extSem_out, + const struct cudaExternalSemaphoreHandleDesc *semHandleDesc) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaExternalSemaphore_t *, + const struct cudaExternalSemaphoreHandleDesc *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaImportExternalSemaphore"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSem_out, semHandleDesc); +} + +extern __host__ cudaError_t CUDARTAPI cudaSignalExternalSemaphoresAsync( + const cudaExternalSemaphore_t *extSemArray, + const struct cudaExternalSemaphoreSignalParams *paramsArray, + unsigned int numExtSems, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const cudaExternalSemaphore_t *, + const struct cudaExternalSemaphoreSignalParams *, + unsigned int, cudaStream_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaSignalExternalSemaphoresAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSemArray, paramsArray, numExtSems, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaWaitExternalSemaphoresAsync( + const cudaExternalSemaphore_t *extSemArray, + const struct cudaExternalSemaphoreWaitParams *paramsArray, + unsigned int numExtSems, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const cudaExternalSemaphore_t *, + const struct cudaExternalSemaphoreWaitParams *, + unsigned int, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaWaitExternalSemaphoresAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSemArray, paramsArray, numExtSems, stream); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDestroyExternalSemaphore(cudaExternalSemaphore_t extSem) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaExternalSemaphore_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDestroyExternalSemaphore"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(extSem); +} + +extern __host__ cudaError_t CUDARTAPI +cudaLaunchKernel(const void *func, dim3 gridDim, dim3 blockDim, void **args, + size_t sharedMem, cudaStream_t stream) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const void *, dim3, dim3, void **, + size_t, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaLaunchKernel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(func, gridDim, blockDim, args, sharedMem, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaLaunchCooperativeKernel( + const void *func, dim3 gridDim, dim3 blockDim, void **args, + size_t sharedMem, cudaStream_t stream) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const void *, dim3, dim3, void **, + size_t, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaLaunchCooperativeKernel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(func, gridDim, blockDim, args, sharedMem, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaLaunchCooperativeKernelMultiDevice( + struct cudaLaunchParams *launchParamsList, unsigned int numDevices, + unsigned int flags __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(struct cudaLaunchParams *, + unsigned int, unsigned int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaLaunchCooperativeKernelMultiDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(launchParamsList, numDevices, flags); +} + +extern __host__ cudaError_t CUDARTAPI +cudaFuncSetCacheConfig(const void *func, enum cudaFuncCache cacheConfig) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const void *, enum cudaFuncCache); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaFuncSetCacheConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(func, cacheConfig); +} + +extern __host__ cudaError_t CUDARTAPI +cudaFuncSetSharedMemConfig(const void *func, enum cudaSharedMemConfig config) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const void *, enum cudaSharedMemConfig); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaFuncSetSharedMemConfig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(func, config); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaFuncGetAttributes(struct cudaFuncAttributes *attr, const void *func) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(struct cudaFuncAttributes *, const void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaFuncGetAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(attr, func); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaFuncSetAttribute(const void *func, enum cudaFuncAttribute attr, int value) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const void *, enum cudaFuncAttribute, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaFuncSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(func, attr, value); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaSetDoubleForDevice(double *d) { + using FuncPtr = cudaError_t(CUDARTAPI *)(double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaSetDoubleForDevice"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(d); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaSetDoubleForHost(double *d) { + using FuncPtr = cudaError_t(CUDARTAPI *)(double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaSetDoubleForHost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(d); +} + +extern __host__ cudaError_t CUDARTAPI cudaLaunchHostFunc(cudaStream_t stream, + cudaHostFn_t fn, + void *userData) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaStream_t, cudaHostFn_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaLaunchHostFunc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(stream, fn, userData); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaOccupancyMaxActiveBlocksPerMultiprocessor(int *numBlocks, const void *func, + int blockSize, + size_t dynamicSMemSize) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *, const void *, int, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaOccupancyMaxActiveBlocksPerMultiprocessor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numBlocks, func, blockSize, dynamicSMemSize); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaOccupancyAvailableDynamicSMemPerBlock(size_t *dynamicSmemSize, + const void *func, int numBlocks, + int blockSize) { + using FuncPtr = cudaError_t(CUDARTAPI *)(size_t *, const void *, int, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaOccupancyAvailableDynamicSMemPerBlock"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dynamicSmemSize, func, numBlocks, blockSize); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(int *numBlocks, + const void *func, + int blockSize, + size_t dynamicSMemSize, + unsigned int flags) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(int *, const void *, int, size_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>( + "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(numBlocks, func, blockSize, dynamicSMemSize, flags); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI cudaMallocManaged( + void **devPtr, size_t size, unsigned int flags __dv(cudaMemAttachGlobal)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void **, size_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMallocManaged"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, size, flags); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaMalloc(void **devPtr, size_t size) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void **, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMalloc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, size); +} + +extern __host__ cudaError_t CUDARTAPI cudaMallocHost(void **ptr, size_t size) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void **, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMallocHost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr, size); +} + +extern __host__ cudaError_t CUDARTAPI cudaMallocPitch(void **devPtr, + size_t *pitch, + size_t width, + size_t height) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void **, size_t *, size_t, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMallocPitch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, pitch, width, height); +} + +extern __host__ cudaError_t CUDARTAPI cudaMallocArray( + cudaArray_t *array, const struct cudaChannelFormatDesc *desc, size_t width, + size_t height __dv(0), unsigned int flags __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaArray_t *, + const struct cudaChannelFormatDesc *, + size_t, size_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMallocArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(array, desc, width, height, flags); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaFree(void *devPtr) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaFree"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr); +} + +extern __host__ cudaError_t CUDARTAPI cudaFreeHost(void *ptr) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaFreeHost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr); +} + +extern __host__ cudaError_t CUDARTAPI cudaFreeArray(cudaArray_t array) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaArray_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaFreeArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(array); +} + +extern __host__ cudaError_t CUDARTAPI +cudaFreeMipmappedArray(cudaMipmappedArray_t mipmappedArray) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaMipmappedArray_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaFreeMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mipmappedArray); +} + +extern __host__ cudaError_t CUDARTAPI cudaHostAlloc(void **pHost, size_t size, + unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void **, size_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaHostAlloc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pHost, size, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaHostRegister(void *ptr, size_t size, + unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, size_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaHostRegister"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr, size, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaHostUnregister(void *ptr) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaHostUnregister"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ptr); +} + +extern __host__ cudaError_t CUDARTAPI +cudaHostGetDevicePointer(void **pDevice, void *pHost, unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void **, void *, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaHostGetDevicePointer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pDevice, pHost, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaHostGetFlags(unsigned int *pFlags, + void *pHost) { + using FuncPtr = cudaError_t(CUDARTAPI *)(unsigned int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaHostGetFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pFlags, pHost); +} + +extern __host__ cudaError_t CUDARTAPI +cudaMalloc3D(struct cudaPitchedPtr *pitchedDevPtr, struct cudaExtent extent) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(struct cudaPitchedPtr *, struct cudaExtent); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMalloc3D"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pitchedDevPtr, extent); +} + +extern __host__ cudaError_t CUDARTAPI +cudaMalloc3DArray(cudaArray_t *array, const struct cudaChannelFormatDesc *desc, + struct cudaExtent extent, unsigned int flags __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaArray_t *, + const struct cudaChannelFormatDesc *, + struct cudaExtent, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMalloc3DArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(array, desc, extent, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaMallocMipmappedArray( + cudaMipmappedArray_t *mipmappedArray, + const struct cudaChannelFormatDesc *desc, struct cudaExtent extent, + unsigned int numLevels, unsigned int flags __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaMipmappedArray_t *, const struct cudaChannelFormatDesc *, + struct cudaExtent, unsigned int, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMallocMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mipmappedArray, desc, extent, numLevels, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetMipmappedArrayLevel( + cudaArray_t *levelArray, cudaMipmappedArray_const_t mipmappedArray, + unsigned int level) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaArray_t *, cudaMipmappedArray_const_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetMipmappedArrayLevel"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(levelArray, mipmappedArray, level); +} + +extern __host__ cudaError_t CUDARTAPI +cudaMemcpy3D(const struct cudaMemcpy3DParms *p) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const struct cudaMemcpy3DParms *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpy3D"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p); +} + +extern __host__ cudaError_t CUDARTAPI +cudaMemcpy3DPeer(const struct cudaMemcpy3DPeerParms *p) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const struct cudaMemcpy3DPeerParms *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpy3DPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI cudaMemcpy3DAsync( + const struct cudaMemcpy3DParms *p, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const struct cudaMemcpy3DParms *, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpy3DAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy3DPeerAsync( + const struct cudaMemcpy3DPeerParms *p, cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const struct cudaMemcpy3DPeerParms *, + cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpy3DPeerAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(p, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemGetInfo(size_t *free, + size_t *total) { + using FuncPtr = cudaError_t(CUDARTAPI *)(size_t *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemGetInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(free, total); +} + +extern __host__ cudaError_t CUDARTAPI +cudaArrayGetInfo(struct cudaChannelFormatDesc *desc, struct cudaExtent *extent, + unsigned int *flags, cudaArray_t array) { + using FuncPtr = cudaError_t(CUDARTAPI *)(struct cudaChannelFormatDesc *, + struct cudaExtent *, unsigned int *, + cudaArray_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaArrayGetInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(desc, extent, flags, array); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy(void *dst, const void *src, + size_t count, + enum cudaMemcpyKind kind) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, const void *, size_t, + enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src, count, kind); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpyPeer(void *dst, int dstDevice, + const void *src, + int srcDevice, + size_t count) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void *, int, const void *, int, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpyPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, dstDevice, src, srcDevice, count); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy2D(void *dst, size_t dpitch, + const void *src, + size_t spitch, size_t width, + size_t height, + enum cudaMemcpyKind kind) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, size_t, const void *, size_t, + size_t, size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpy2D"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, dpitch, src, spitch, width, height, kind); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy2DToArray( + cudaArray_t dst, size_t wOffset, size_t hOffset, const void *src, + size_t spitch, size_t width, size_t height, enum cudaMemcpyKind kind) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaArray_t, size_t, size_t, const void *, + size_t, size_t, size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpy2DToArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, wOffset, hOffset, src, spitch, width, height, kind); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy2DFromArray( + void *dst, size_t dpitch, cudaArray_const_t src, size_t wOffset, + size_t hOffset, size_t width, size_t height, enum cudaMemcpyKind kind) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void *, size_t, cudaArray_const_t, size_t, + size_t, size_t, size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpy2DFromArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, dpitch, src, wOffset, hOffset, width, height, kind); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy2DArrayToArray( + cudaArray_t dst, size_t wOffsetDst, size_t hOffsetDst, + cudaArray_const_t src, size_t wOffsetSrc, size_t hOffsetSrc, size_t width, + size_t height, enum cudaMemcpyKind kind __dv(cudaMemcpyDeviceToDevice)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaArray_t, size_t, size_t, + cudaArray_const_t, size_t, size_t, + size_t, size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpy2DArrayToArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, wOffsetDst, hOffsetDst, src, wOffsetSrc, hOffsetSrc, + width, height, kind); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpyToSymbol( + const void *symbol, const void *src, size_t count, size_t offset __dv(0), + enum cudaMemcpyKind kind __dv(cudaMemcpyHostToDevice)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const void *, const void *, size_t, + size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpyToSymbol"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(symbol, src, count, offset, kind); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpyFromSymbol( + void *dst, const void *symbol, size_t count, size_t offset __dv(0), + enum cudaMemcpyKind kind __dv(cudaMemcpyDeviceToHost)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, const void *, size_t, size_t, + enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpyFromSymbol"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, symbol, count, offset, kind); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaMemcpyAsync(void *dst, const void *src, size_t count, + enum cudaMemcpyKind kind, cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, const void *, size_t, + enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpyAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src, count, kind, stream); +} + +extern __host__ cudaError_t CUDARTAPI +cudaMemcpyPeerAsync(void *dst, int dstDevice, const void *src, int srcDevice, + size_t count, cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, int, const void *, int, + size_t, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpyPeerAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, dstDevice, src, srcDevice, count, stream); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI cudaMemcpy2DAsync( + void *dst, size_t dpitch, const void *src, size_t spitch, size_t width, + size_t height, enum cudaMemcpyKind kind, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void *, size_t, const void *, size_t, size_t, + size_t, enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpy2DAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, dpitch, src, spitch, width, height, kind, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy2DToArrayAsync( + cudaArray_t dst, size_t wOffset, size_t hOffset, const void *src, + size_t spitch, size_t width, size_t height, enum cudaMemcpyKind kind, + cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaArray_t, size_t, size_t, + const void *, size_t, size_t, size_t, + enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpy2DToArrayAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, wOffset, hOffset, src, spitch, width, height, kind, + stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpy2DFromArrayAsync( + void *dst, size_t dpitch, cudaArray_const_t src, size_t wOffset, + size_t hOffset, size_t width, size_t height, enum cudaMemcpyKind kind, + cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, size_t, cudaArray_const_t, + size_t, size_t, size_t, size_t, + enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpy2DFromArrayAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, dpitch, src, wOffset, hOffset, width, height, kind, + stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpyToSymbolAsync( + const void *symbol, const void *src, size_t count, size_t offset, + enum cudaMemcpyKind kind, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const void *, const void *, size_t, size_t, + enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpyToSymbolAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(symbol, src, count, offset, kind, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemcpyFromSymbolAsync( + void *dst, const void *symbol, size_t count, size_t offset, + enum cudaMemcpyKind kind, cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, const void *, size_t, size_t, + enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpyFromSymbolAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, symbol, count, offset, kind, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemset(void *devPtr, int value, + size_t count) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, int, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemset"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, value, count); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemset2D(void *devPtr, size_t pitch, + int value, size_t width, + size_t height) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, size_t, int, size_t, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemset2D"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, pitch, value, width, height); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemset3D( + struct cudaPitchedPtr pitchedDevPtr, int value, struct cudaExtent extent) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(struct cudaPitchedPtr, int, struct cudaExtent); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemset3D"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pitchedDevPtr, value, extent); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI cudaMemsetAsync( + void *devPtr, int value, size_t count, cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, int, size_t, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemsetAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, value, count, stream); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaMemset2DAsync(void *devPtr, size_t pitch, int value, size_t width, + size_t height, cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, size_t, int, size_t, size_t, + cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemset2DAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, pitch, value, width, height, stream); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaMemset3DAsync(struct cudaPitchedPtr pitchedDevPtr, int value, + struct cudaExtent extent, cudaStream_t stream __dv(0)) { + using FuncPtr = cudaError_t(CUDARTAPI *)(struct cudaPitchedPtr, int, + struct cudaExtent, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemset3DAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pitchedDevPtr, value, extent, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetSymbolAddress(void **devPtr, + const void *symbol) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void **, const void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetSymbolAddress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, symbol); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetSymbolSize(size_t *size, + const void *symbol) { + using FuncPtr = cudaError_t(CUDARTAPI *)(size_t *, const void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetSymbolSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(size, symbol); +} + +extern __host__ cudaError_t CUDARTAPI +cudaMemPrefetchAsync(const void *devPtr, size_t count, int dstDevice, + cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const void *, size_t, int, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemPrefetchAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, count, dstDevice, stream); +} + +extern __host__ cudaError_t CUDARTAPI +cudaMemAdvise(const void *devPtr, size_t count, enum cudaMemoryAdvise advice, + int device) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const void *, size_t, + enum cudaMemoryAdvise, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemAdvise"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, count, advice, device); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemRangeGetAttribute( + void *data, size_t dataSize, enum cudaMemRangeAttribute attribute, + const void *devPtr, size_t count) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + void *, size_t, enum cudaMemRangeAttribute, const void *, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemRangeGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, dataSize, attribute, devPtr, count); +} + +extern __host__ cudaError_t CUDARTAPI cudaMemRangeGetAttributes( + void **data, size_t *dataSizes, enum cudaMemRangeAttribute *attributes, + size_t numAttributes, const void *devPtr, size_t count) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void **, size_t *, enum cudaMemRangeAttribute *, + size_t, const void *, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemRangeGetAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(data, dataSizes, attributes, numAttributes, devPtr, count); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaMemcpyToArray(cudaArray_t dst, size_t wOffset, size_t hOffset, + const void *src, size_t count, enum cudaMemcpyKind kind) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaArray_t, size_t, size_t, const void *, size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpyToArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, wOffset, hOffset, src, count, kind); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaMemcpyFromArray(void *dst, cudaArray_const_t src, size_t wOffset, + size_t hOffset, size_t count, enum cudaMemcpyKind kind) { + using FuncPtr = cudaError_t(CUDARTAPI *)(void *, cudaArray_const_t, size_t, + size_t, size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpyFromArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src, wOffset, hOffset, count, kind); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI cudaMemcpyArrayToArray( + cudaArray_t dst, size_t wOffsetDst, size_t hOffsetDst, + cudaArray_const_t src, size_t wOffsetSrc, size_t hOffsetSrc, size_t count, + enum cudaMemcpyKind kind __dv(cudaMemcpyDeviceToDevice)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaArray_t, size_t, size_t, cudaArray_const_t, + size_t, size_t, size_t, enum cudaMemcpyKind); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpyArrayToArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, wOffsetDst, hOffsetDst, src, wOffsetSrc, hOffsetSrc, + count, kind); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI cudaMemcpyToArrayAsync( + cudaArray_t dst, size_t wOffset, size_t hOffset, const void *src, + size_t count, enum cudaMemcpyKind kind, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaArray_t, size_t, size_t, const void *, + size_t, enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpyToArrayAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, wOffset, hOffset, src, count, kind, stream); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaMemcpyFromArrayAsync(void *dst, cudaArray_const_t src, size_t wOffset, + size_t hOffset, size_t count, enum cudaMemcpyKind kind, + cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void *, cudaArray_const_t, size_t, size_t, + size_t, enum cudaMemcpyKind, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaMemcpyFromArrayAsync"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dst, src, wOffset, hOffset, count, kind, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaPointerGetAttributes( + struct cudaPointerAttributes *attributes, const void *ptr) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(struct cudaPointerAttributes *, const void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaPointerGetAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(attributes, ptr); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDeviceCanAccessPeer(int *canAccessPeer, int device, int peerDevice) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *, int, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDeviceCanAccessPeer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(canAccessPeer, device, peerDevice); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDeviceEnablePeerAccess(int peerDevice, unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDeviceEnablePeerAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(peerDevice, flags); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDeviceDisablePeerAccess(int peerDevice) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDeviceDisablePeerAccess"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(peerDevice); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphicsUnregisterResource(cudaGraphicsResource_t resource) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphicsResource_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphicsUnregisterResource"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(resource); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphicsResourceSetMapFlags( + cudaGraphicsResource_t resource, unsigned int flags) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphicsResource_t, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphicsResourceSetMapFlags"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(resource, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphicsMapResources( + int count, cudaGraphicsResource_t *resources, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(int, cudaGraphicsResource_t *, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphicsMapResources"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(count, resources, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphicsUnmapResources( + int count, cudaGraphicsResource_t *resources, cudaStream_t stream __dv(0)) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(int, cudaGraphicsResource_t *, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphicsUnmapResources"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(count, resources, stream); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphicsResourceGetMappedPointer( + void **devPtr, size_t *size, cudaGraphicsResource_t resource) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(void **, size_t *, cudaGraphicsResource_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaGraphicsResourceGetMappedPointer"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(devPtr, size, resource); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphicsSubResourceGetMappedArray( + cudaArray_t *array, cudaGraphicsResource_t resource, + unsigned int arrayIndex, unsigned int mipLevel) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaArray_t *, cudaGraphicsResource_t, unsigned int, unsigned int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaGraphicsSubResourceGetMappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(array, resource, arrayIndex, mipLevel); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphicsResourceGetMappedMipmappedArray( + cudaMipmappedArray_t *mipmappedArray, cudaGraphicsResource_t resource) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaMipmappedArray_t *, cudaGraphicsResource_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaGraphicsResourceGetMappedMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(mipmappedArray, resource); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI cudaBindTexture( + size_t *offset, const struct textureReference *texref, const void *devPtr, + const struct cudaChannelFormatDesc *desc, size_t size __dv(UINT_MAX)) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + size_t *, const struct textureReference *, const void *, + const struct cudaChannelFormatDesc *, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaBindTexture"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(offset, texref, devPtr, desc, size); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaBindTexture2D(size_t *offset, const struct textureReference *texref, + const void *devPtr, const struct cudaChannelFormatDesc *desc, + size_t width, size_t height, size_t pitch) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + size_t *, const struct textureReference *, const void *, + const struct cudaChannelFormatDesc *, size_t, size_t, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaBindTexture2D"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(offset, texref, devPtr, desc, width, height, pitch); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI cudaBindTextureToArray( + const struct textureReference *texref, cudaArray_const_t array, + const struct cudaChannelFormatDesc *desc) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + const struct textureReference *, cudaArray_const_t, + const struct cudaChannelFormatDesc *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaBindTextureToArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(texref, array, desc); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaBindTextureToMipmappedArray(const struct textureReference *texref, + cudaMipmappedArray_const_t mipmappedArray, + const struct cudaChannelFormatDesc *desc) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + const struct textureReference *, cudaMipmappedArray_const_t, + const struct cudaChannelFormatDesc *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaBindTextureToMipmappedArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(texref, mipmappedArray, desc); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaUnbindTexture(const struct textureReference *texref) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const struct textureReference *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaUnbindTexture"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(texref); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI +cudaGetTextureAlignmentOffset(size_t *offset, + const struct textureReference *texref) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(size_t *, const struct textureReference *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetTextureAlignmentOffset"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(offset, texref); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI cudaGetTextureReference( + const struct textureReference **texref, const void *symbol) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const struct textureReference **, const void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetTextureReference"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(texref, symbol); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI cudaBindSurfaceToArray( + const struct surfaceReference *surfref, cudaArray_const_t array, + const struct cudaChannelFormatDesc *desc) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + const struct surfaceReference *, cudaArray_const_t, + const struct cudaChannelFormatDesc *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaBindSurfaceToArray"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(surfref, array, desc); +} + +extern __CUDA_DEPRECATED __host__ cudaError_t CUDARTAPI cudaGetSurfaceReference( + const struct surfaceReference **surfref, const void *symbol) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(const struct surfaceReference **, const void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetSurfaceReference"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(surfref, symbol); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetChannelDesc( + struct cudaChannelFormatDesc *desc, cudaArray_const_t array) { + using FuncPtr = cudaError_t(CUDARTAPI *)(struct cudaChannelFormatDesc *, + cudaArray_const_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetChannelDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(desc, array); +} + +extern __host__ cudaError_t CUDARTAPI cudaCreateTextureObject( + cudaTextureObject_t *pTexObject, const struct cudaResourceDesc *pResDesc, + const struct cudaTextureDesc *pTexDesc, + const struct cudaResourceViewDesc *pResViewDesc) { + using FuncPtr = cudaError_t(CUDARTAPI *)( + cudaTextureObject_t *, const struct cudaResourceDesc *, + const struct cudaTextureDesc *, const struct cudaResourceViewDesc *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaCreateTextureObject"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexObject, pResDesc, pTexDesc, pResViewDesc); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDestroyTextureObject(cudaTextureObject_t texObject) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaTextureObject_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDestroyTextureObject"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(texObject); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetTextureObjectResourceDesc( + struct cudaResourceDesc *pResDesc, cudaTextureObject_t texObject) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(struct cudaResourceDesc *, cudaTextureObject_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaGetTextureObjectResourceDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResDesc, texObject); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetTextureObjectTextureDesc( + struct cudaTextureDesc *pTexDesc, cudaTextureObject_t texObject) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(struct cudaTextureDesc *, cudaTextureObject_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetTextureObjectTextureDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pTexDesc, texObject); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetTextureObjectResourceViewDesc( + struct cudaResourceViewDesc *pResViewDesc, cudaTextureObject_t texObject) { + using FuncPtr = cudaError_t(CUDARTAPI *)(struct cudaResourceViewDesc *, + cudaTextureObject_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaGetTextureObjectResourceViewDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResViewDesc, texObject); +} + +extern __host__ cudaError_t CUDARTAPI cudaCreateSurfaceObject( + cudaSurfaceObject_t *pSurfObject, const struct cudaResourceDesc *pResDesc) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaSurfaceObject_t *, + const struct cudaResourceDesc *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaCreateSurfaceObject"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pSurfObject, pResDesc); +} + +extern __host__ cudaError_t CUDARTAPI +cudaDestroySurfaceObject(cudaSurfaceObject_t surfObject) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaSurfaceObject_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDestroySurfaceObject"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(surfObject); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetSurfaceObjectResourceDesc( + struct cudaResourceDesc *pResDesc, cudaSurfaceObject_t surfObject) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(struct cudaResourceDesc *, cudaSurfaceObject_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaGetSurfaceObjectResourceDesc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pResDesc, surfObject); +} + +extern __host__ cudaError_t CUDARTAPI cudaDriverGetVersion(int *driverVersion) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaDriverGetVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(driverVersion); +} + +extern __host__ __cudart_builtin__ cudaError_t CUDARTAPI +cudaRuntimeGetVersion(int *runtimeVersion) { + using FuncPtr = cudaError_t(CUDARTAPI *)(int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaRuntimeGetVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(runtimeVersion); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphCreate(cudaGraph_t *pGraph, + unsigned int flags) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t *, unsigned int); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraph, flags); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphAddKernelNode( + cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + const cudaGraphNode_t *pDependencies, size_t numDependencies, + const struct cudaKernelNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + const cudaGraphNode_t *, size_t, + const struct cudaKernelNodeParams *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphAddKernelNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies, + pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphKernelNodeGetParams( + cudaGraphNode_t node, struct cudaKernelNodeParams *pNodeParams) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, struct cudaKernelNodeParams *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphKernelNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphKernelNodeSetParams( + cudaGraphNode_t node, const struct cudaKernelNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, + const struct cudaKernelNodeParams *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphKernelNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphKernelNodeCopyAttributes(cudaGraphNode_t hSrc, cudaGraphNode_t hDst) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, cudaGraphNode_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaGraphKernelNodeCopyAttributes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hSrc, hDst); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphKernelNodeGetAttribute( + cudaGraphNode_t hNode, enum cudaKernelNodeAttrID attr, + union cudaKernelNodeAttrValue *value_out) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, enum cudaKernelNodeAttrID, + union cudaKernelNodeAttrValue *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphKernelNodeGetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, attr, value_out); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphKernelNodeSetAttribute( + cudaGraphNode_t hNode, enum cudaKernelNodeAttrID attr, + const union cudaKernelNodeAttrValue *value) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, enum cudaKernelNodeAttrID, + const union cudaKernelNodeAttrValue *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphKernelNodeSetAttribute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hNode, attr, value); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphAddMemcpyNode( + cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + const cudaGraphNode_t *pDependencies, size_t numDependencies, + const struct cudaMemcpy3DParms *pCopyParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + const cudaGraphNode_t *, size_t, + const struct cudaMemcpy3DParms *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphAddMemcpyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies, + pCopyParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphMemcpyNodeGetParams( + cudaGraphNode_t node, struct cudaMemcpy3DParms *pNodeParams) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, struct cudaMemcpy3DParms *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphMemcpyNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphMemcpyNodeSetParams( + cudaGraphNode_t node, const struct cudaMemcpy3DParms *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, + const struct cudaMemcpy3DParms *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphMemcpyNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphAddMemsetNode( + cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + const cudaGraphNode_t *pDependencies, size_t numDependencies, + const struct cudaMemsetParams *pMemsetParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + const cudaGraphNode_t *, size_t, + const struct cudaMemsetParams *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphAddMemsetNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies, + pMemsetParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphMemsetNodeGetParams( + cudaGraphNode_t node, struct cudaMemsetParams *pNodeParams) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, struct cudaMemsetParams *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphMemsetNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphMemsetNodeSetParams( + cudaGraphNode_t node, const struct cudaMemsetParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, + const struct cudaMemsetParams *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphMemsetNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphAddHostNode( + cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + const cudaGraphNode_t *pDependencies, size_t numDependencies, + const struct cudaHostNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + const cudaGraphNode_t *, size_t, + const struct cudaHostNodeParams *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphAddHostNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies, + pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphHostNodeGetParams( + cudaGraphNode_t node, struct cudaHostNodeParams *pNodeParams) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, struct cudaHostNodeParams *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphHostNodeGetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphHostNodeSetParams( + cudaGraphNode_t node, const struct cudaHostNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, + const struct cudaHostNodeParams *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphHostNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphAddChildGraphNode(cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + const cudaGraphNode_t *pDependencies, + size_t numDependencies, cudaGraph_t childGraph) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + const cudaGraphNode_t *, size_t, cudaGraph_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphAddChildGraphNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies, + childGraph); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphChildGraphNodeGetGraph(cudaGraphNode_t node, cudaGraph_t *pGraph) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t, cudaGraph_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphChildGraphNodeGetGraph"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pGraph); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphAddEmptyNode( + cudaGraphNode_t *pGraphNode, cudaGraph_t graph, + const cudaGraphNode_t *pDependencies, size_t numDependencies) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraph_t, + const cudaGraphNode_t *, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphAddEmptyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphNode, graph, pDependencies, numDependencies); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphClone(cudaGraph_t *pGraphClone, cudaGraph_t originalGraph) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t *, cudaGraph_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphClone"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphClone, originalGraph); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphNodeFindInClone(cudaGraphNode_t *pNode, cudaGraphNode_t originalNode, + cudaGraph_t clonedGraph) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t *, cudaGraphNode_t, cudaGraph_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphNodeFindInClone"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pNode, originalNode, clonedGraph); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphNodeGetType(cudaGraphNode_t node, enum cudaGraphNodeType *pType) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, enum cudaGraphNodeType *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphNodeGetType"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pType); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphGetNodes(cudaGraph_t graph, + cudaGraphNode_t *nodes, + size_t *numNodes) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraph_t, cudaGraphNode_t *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphGetNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph, nodes, numNodes); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphGetRootNodes( + cudaGraph_t graph, cudaGraphNode_t *pRootNodes, size_t *pNumRootNodes) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraph_t, cudaGraphNode_t *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphGetRootNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph, pRootNodes, pNumRootNodes); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphGetEdges(cudaGraph_t graph, + cudaGraphNode_t *from, + cudaGraphNode_t *to, + size_t *numEdges) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t, cudaGraphNode_t *, + cudaGraphNode_t *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphGetEdges"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph, from, to, numEdges); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphNodeGetDependencies( + cudaGraphNode_t node, cudaGraphNode_t *pDependencies, + size_t *pNumDependencies) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, cudaGraphNode_t *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphNodeGetDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pDependencies, pNumDependencies); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphNodeGetDependentNodes( + cudaGraphNode_t node, cudaGraphNode_t *pDependentNodes, + size_t *pNumDependentNodes) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphNode_t, cudaGraphNode_t *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphNodeGetDependentNodes"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node, pDependentNodes, pNumDependentNodes); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphAddDependencies(cudaGraph_t graph, const cudaGraphNode_t *from, + const cudaGraphNode_t *to, size_t numDependencies) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t, const cudaGraphNode_t *, + const cudaGraphNode_t *, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphAddDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph, from, to, numDependencies); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphRemoveDependencies(cudaGraph_t graph, const cudaGraphNode_t *from, + const cudaGraphNode_t *to, size_t numDependencies) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t, const cudaGraphNode_t *, + const cudaGraphNode_t *, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphRemoveDependencies"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph, from, to, numDependencies); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphDestroyNode(cudaGraphNode_t node) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphNode_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphDestroyNode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(node); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphInstantiate( + cudaGraphExec_t *pGraphExec, cudaGraph_t graph, cudaGraphNode_t *pErrorNode, + char *pLogBuffer, size_t bufferSize) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t *, cudaGraph_t, + cudaGraphNode_t *, char *, size_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphInstantiate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(pGraphExec, graph, pErrorNode, pLogBuffer, bufferSize); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphExecKernelNodeSetParams( + cudaGraphExec_t hGraphExec, cudaGraphNode_t node, + const struct cudaKernelNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t, cudaGraphNode_t, + const struct cudaKernelNodeParams *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaGraphExecKernelNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphExecMemcpyNodeSetParams( + cudaGraphExec_t hGraphExec, cudaGraphNode_t node, + const struct cudaMemcpy3DParms *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t, cudaGraphNode_t, + const struct cudaMemcpy3DParms *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaGraphExecMemcpyNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphExecMemsetNodeSetParams( + cudaGraphExec_t hGraphExec, cudaGraphNode_t node, + const struct cudaMemsetParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t, cudaGraphNode_t, + const struct cudaMemsetParams *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudaGraphExecMemsetNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphExecHostNodeSetParams(cudaGraphExec_t hGraphExec, cudaGraphNode_t node, + const struct cudaHostNodeParams *pNodeParams) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t, cudaGraphNode_t, + const struct cudaHostNodeParams *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphExecHostNodeSetParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, node, pNodeParams); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphExecUpdate(cudaGraphExec_t hGraphExec, cudaGraph_t hGraph, + cudaGraphNode_t *hErrorNode_out, + enum cudaGraphExecUpdateResult *updateResult_out) { + using FuncPtr = + cudaError_t(CUDARTAPI *)(cudaGraphExec_t, cudaGraph_t, cudaGraphNode_t *, + enum cudaGraphExecUpdateResult *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphExecUpdate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(hGraphExec, hGraph, hErrorNode_out, updateResult_out); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphLaunch(cudaGraphExec_t graphExec, + cudaStream_t stream) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphLaunch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graphExec, stream); +} + +extern __host__ cudaError_t CUDARTAPI +cudaGraphExecDestroy(cudaGraphExec_t graphExec) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraphExec_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphExecDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graphExec); +} + +extern __host__ cudaError_t CUDARTAPI cudaGraphDestroy(cudaGraph_t graph) { + using FuncPtr = cudaError_t(CUDARTAPI *)(cudaGraph_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGraphDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(graph); +} + +extern __host__ cudaError_t CUDARTAPI cudaGetExportTable( + const void **ppExportTable, const cudaUUID_t *pExportTableId) { + using FuncPtr = cudaError_t(CUDARTAPI *)(const void **, const cudaUUID_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetExportTable"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(ppExportTable, pExportTableId); +} + +} // extern "C" diff --git a/tensorflow/stream_executor/cuda/cudnn_6_0.inc b/tensorflow/stream_executor/cuda/cudnn_6_0.inc index 6ac7a695d9f..11288983a4a 100644 --- a/tensorflow/stream_executor/cuda/cudnn_6_0.inc +++ b/tensorflow/stream_executor/cuda/cudnn_6_0.inc @@ -3,1771 +3,1823 @@ extern "C" { size_t CUDNNWINAPI cudnnGetVersion(void) { - using FuncPtr = size_t (CUDNNWINAPI *)(); + using FuncPtr = size_t(CUDNNWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetVersion"); if (!func_ptr) return 0; return func_ptr(); } size_t CUDNNWINAPI cudnnGetCudartVersion(void) { - using FuncPtr = size_t (CUDNNWINAPI *)(); + using FuncPtr = size_t(CUDNNWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCudartVersion"); if (!func_ptr) return 0; return func_ptr(); } -const char * CUDNNWINAPI cudnnGetErrorString(cudnnStatus_t status) { - using FuncPtr = const char * (CUDNNWINAPI *)(cudnnStatus_t); +const char *CUDNNWINAPI cudnnGetErrorString(cudnnStatus_t status) { + using FuncPtr = const char *(CUDNNWINAPI *)(cudnnStatus_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetErrorString"); if (!func_ptr) return "cudnnGetErrorString symbol not found."; return func_ptr(status); } -cudnnStatus_t CUDNNWINAPI cudnnGetProperty(libraryPropertyType type, int *value) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(libraryPropertyType, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetProperty(libraryPropertyType type, + int *value) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(libraryPropertyType, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetProperty"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(type, value); } -cudnnStatus_t CUDNNWINAPI cudnnCreate (cudnnHandle_t *handle) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t *); +cudnnStatus_t CUDNNWINAPI cudnnCreate(cudnnHandle_t *handle) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreate"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cudnnStatus_t CUDNNWINAPI cudnnDestroy (cudnnHandle_t handle) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t); +cudnnStatus_t CUDNNWINAPI cudnnDestroy(cudnnHandle_t handle) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroy"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cudnnStatus_t CUDNNWINAPI cudnnSetStream (cudnnHandle_t handle, cudaStream_t streamId) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t); +cudnnStatus_t CUDNNWINAPI cudnnSetStream(cudnnHandle_t handle, + cudaStream_t streamId) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetStream"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); } -cudnnStatus_t CUDNNWINAPI cudnnGetStream (cudnnHandle_t handle, cudaStream_t *streamId) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetStream(cudnnHandle_t handle, + cudaStream_t *streamId) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetStream"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); } -cudnnStatus_t CUDNNWINAPI cudnnCreateTensorDescriptor( - cudnnTensorDescriptor_t *tensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateTensorDescriptor(cudnnTensorDescriptor_t *tensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetTensor4dDescriptor( - cudnnTensorDescriptor_t tensorDesc, - cudnnTensorFormat_t format, - cudnnDataType_t dataType, // image data type - int n, // number of inputs (batch size) - int c, // number of input feature maps - int h, // height of input section - int w ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, cudnnDataType_t, int, int, int, int); + cudnnTensorDescriptor_t tensorDesc, cudnnTensorFormat_t format, + cudnnDataType_t dataType, // image data type + int n, // number of inputs (batch size) + int c, // number of input feature maps + int h, // height of input section + int w) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, + cudnnDataType_t, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, format, dataType, n, c, h, w); } cudnnStatus_t CUDNNWINAPI cudnnSetTensor4dDescriptorEx( - cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t dataType, // image data type - int n, // number of inputs (batch size) - int c, // number of input feature maps - int h, // height of input section - int w, // width of input section - int nStride, - int cStride, - int hStride, - int wStride ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, int, int, int, int, int, int, int, int); + cudnnTensorDescriptor_t tensorDesc, + cudnnDataType_t dataType, // image data type + int n, // number of inputs (batch size) + int c, // number of input feature maps + int h, // height of input section + int w, // width of input section + int nStride, int cStride, int hStride, int wStride) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, + int, int, int, int, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor4dDescriptorEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, wStride); + return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, + wStride); } cudnnStatus_t CUDNNWINAPI cudnnGetTensor4dDescriptor( - const cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t *dataType, // image data type - int *n, // number of inputs (batch size) - int *c, // number of input feature maps - int *h, // height of input section - int *w, // width of input section - int *nStride, - int *cStride, - int *hStride, - int *wStride ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, cudnnDataType_t *, int *, int *, int *, int *, int *, int *, int *, int *); + const cudnnTensorDescriptor_t tensorDesc, + cudnnDataType_t *dataType, // image data type + int *n, // number of inputs (batch size) + int *c, // number of input feature maps + int *h, // height of input section + int *w, // width of input section + int *nStride, int *cStride, int *hStride, int *wStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnTensorDescriptor_t, cudnnDataType_t *, int *, int *, int *, + int *, int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensor4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, wStride); + return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, + wStride); } cudnnStatus_t CUDNNWINAPI cudnnSetTensorNdDescriptor( - cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t dataType, - int nbDims, - const int dimA[], - const int strideA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, int, const int [], const int []); + cudnnTensorDescriptor_t tensorDesc, cudnnDataType_t dataType, int nbDims, + const int dimA[], const int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnTensorDescriptor_t, cudnnDataType_t, int, const int[], const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, dataType, nbDims, dimA, strideA); } cudnnStatus_t CUDNNWINAPI cudnnSetTensorNdDescriptorEx( - cudnnTensorDescriptor_t tensorDesc, - cudnnTensorFormat_t format, - cudnnDataType_t dataType, - int nbDims, - const int dimA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, cudnnDataType_t, int, const int []); + cudnnTensorDescriptor_t tensorDesc, cudnnTensorFormat_t format, + cudnnDataType_t dataType, int nbDims, const int dimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, + cudnnDataType_t, int, const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorNdDescriptorEx"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, format, dataType, nbDims, dimA); } cudnnStatus_t CUDNNWINAPI cudnnGetTensorNdDescriptor( - const cudnnTensorDescriptor_t tensorDesc, - int nbDimsRequested, - cudnnDataType_t *dataType, - int *nbDims, - int dimA[], - int strideA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, int, cudnnDataType_t *, int *, int [], int []); + const cudnnTensorDescriptor_t tensorDesc, int nbDimsRequested, + cudnnDataType_t *dataType, int *nbDims, int dimA[], int strideA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnTensorDescriptor_t, int, + cudnnDataType_t *, int *, int[], int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, nbDimsRequested, dataType, nbDims, dimA, strideA); } cudnnStatus_t CUDNNWINAPI cudnnGetTensorSizeInBytes( - const cudnnTensorDescriptor_t tensorDesc, - size_t *size) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, size_t *); + const cudnnTensorDescriptor_t tensorDesc, size_t *size) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorSizeInBytes"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, size); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyTensorDescriptor( - cudnnTensorDescriptor_t tensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyTensorDescriptor(cudnnTensorDescriptor_t tensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnTransformTensor( - cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnTransformTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI cudnnAddTensor( - cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnAddTensor(cudnnHandle_t handle, + const void *alpha, + const cudnnTensorDescriptor_t aDesc, + const void *A, const void *beta, + const cudnnTensorDescriptor_t cDesc, + void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnAddTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, aDesc, A, beta, cDesc, C); } -cudnnStatus_t CUDNNWINAPI cudnnCreateOpTensorDescriptor( - cudnnOpTensorDescriptor_t *opTensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateOpTensorDescriptor(cudnnOpTensorDescriptor_t *opTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetOpTensorDescriptor( - cudnnOpTensorDescriptor_t opTensorDesc, - cudnnOpTensorOp_t opTensorOp, - cudnnDataType_t opTensorCompType, - cudnnNanPropagation_t opTensorNanOpt ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t, cudnnDataType_t, cudnnNanPropagation_t); + cudnnOpTensorDescriptor_t opTensorDesc, cudnnOpTensorOp_t opTensorOp, + cudnnDataType_t opTensorCompType, cudnnNanPropagation_t opTensorNanOpt) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t, + cudnnDataType_t, cudnnNanPropagation_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc, opTensorOp, opTensorCompType, opTensorNanOpt); } cudnnStatus_t CUDNNWINAPI cudnnGetOpTensorDescriptor( - const cudnnOpTensorDescriptor_t opTensorDesc, - cudnnOpTensorOp_t *opTensorOp, - cudnnDataType_t *opTensorCompType, - cudnnNanPropagation_t *opTensorNanOpt ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t *, cudnnDataType_t *, cudnnNanPropagation_t *); + const cudnnOpTensorDescriptor_t opTensorDesc, cudnnOpTensorOp_t *opTensorOp, + cudnnDataType_t *opTensorCompType, cudnnNanPropagation_t *opTensorNanOpt) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t *, cudnnDataType_t *, + cudnnNanPropagation_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc, opTensorOp, opTensorCompType, opTensorNanOpt); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyOpTensorDescriptor( - cudnnOpTensorDescriptor_t opTensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnOpTensor( - cudnnHandle_t handle, - const cudnnOpTensorDescriptor_t opTensorDesc, - const void *alpha1, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *alpha2, - const cudnnTensorDescriptor_t bDesc, - const void *B, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnOpTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const cudnnOpTensorDescriptor_t opTensorDesc, + const void *alpha1, const cudnnTensorDescriptor_t aDesc, const void *A, + const void *alpha2, const cudnnTensorDescriptor_t bDesc, const void *B, + const void *beta, const cudnnTensorDescriptor_t cDesc, void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnOpTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnOpTensor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, opTensorDesc, alpha1, aDesc, A, alpha2, bDesc, B, beta, cDesc, C); + return func_ptr(handle, opTensorDesc, alpha1, aDesc, A, alpha2, bDesc, B, + beta, cDesc, C); } cudnnStatus_t CUDNNWINAPI cudnnCreateReduceTensorDescriptor( - cudnnReduceTensorDescriptor_t *reduceTensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateReduceTensorDescriptor"); + cudnnReduceTensorDescriptor_t *reduceTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(reduceTensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetReduceTensorDescriptor( - cudnnReduceTensorDescriptor_t reduceTensorDesc, - cudnnReduceTensorOp_t reduceTensorOp, - cudnnDataType_t reduceTensorCompType, - cudnnNanPropagation_t reduceTensorNanOpt, - cudnnReduceTensorIndices_t reduceTensorIndices, - cudnnIndicesType_t reduceTensorIndicesType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t, cudnnDataType_t, cudnnNanPropagation_t, cudnnReduceTensorIndices_t, cudnnIndicesType_t); + cudnnReduceTensorDescriptor_t reduceTensorDesc, + cudnnReduceTensorOp_t reduceTensorOp, cudnnDataType_t reduceTensorCompType, + cudnnNanPropagation_t reduceTensorNanOpt, + cudnnReduceTensorIndices_t reduceTensorIndices, + cudnnIndicesType_t reduceTensorIndicesType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t, cudnnDataType_t, + cudnnNanPropagation_t, cudnnReduceTensorIndices_t, cudnnIndicesType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, reduceTensorNanOpt, reduceTensorIndices, reduceTensorIndicesType); + return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, + reduceTensorNanOpt, reduceTensorIndices, + reduceTensorIndicesType); } cudnnStatus_t CUDNNWINAPI cudnnGetReduceTensorDescriptor( - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - cudnnReduceTensorOp_t *reduceTensorOp, - cudnnDataType_t *reduceTensorCompType, - cudnnNanPropagation_t *reduceTensorNanOpt, - cudnnReduceTensorIndices_t *reduceTensorIndices, - cudnnIndicesType_t *reduceTensorIndicesType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t *, cudnnDataType_t *, cudnnNanPropagation_t *, cudnnReduceTensorIndices_t *, cudnnIndicesType_t *); + const cudnnReduceTensorDescriptor_t reduceTensorDesc, + cudnnReduceTensorOp_t *reduceTensorOp, + cudnnDataType_t *reduceTensorCompType, + cudnnNanPropagation_t *reduceTensorNanOpt, + cudnnReduceTensorIndices_t *reduceTensorIndices, + cudnnIndicesType_t *reduceTensorIndicesType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t *, + cudnnDataType_t *, cudnnNanPropagation_t *, cudnnReduceTensorIndices_t *, + cudnnIndicesType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, reduceTensorNanOpt, reduceTensorIndices, reduceTensorIndicesType); + return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, + reduceTensorNanOpt, reduceTensorIndices, + reduceTensorIndicesType); } cudnnStatus_t CUDNNWINAPI cudnnDestroyReduceTensorDescriptor( - cudnnReduceTensorDescriptor_t reduceTensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyReduceTensorDescriptor"); + cudnnReduceTensorDescriptor_t reduceTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(reduceTensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnGetReductionIndicesSize( - cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - const cudnnTensorDescriptor_t aDesc, - const cudnnTensorDescriptor_t cDesc, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + const cudnnTensorDescriptor_t aDesc, const cudnnTensorDescriptor_t cDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReductionIndicesSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, reduceTensorDesc, aDesc, cDesc, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetReductionWorkspaceSize( - cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - const cudnnTensorDescriptor_t aDesc, - const cudnnTensorDescriptor_t cDesc, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + const cudnnTensorDescriptor_t aDesc, const cudnnTensorDescriptor_t cDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReductionWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, reduceTensorDesc, aDesc, cDesc, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnReduceTensor( - cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - void *indices, - size_t indicesSizeInBytes, - void *workspace, - size_t workspaceSizeInBytes, - const void *alpha, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, void *, size_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + void *indices, size_t indicesSizeInBytes, void *workspace, + size_t workspaceSizeInBytes, const void *alpha, + const cudnnTensorDescriptor_t aDesc, const void *A, const void *beta, + const cudnnTensorDescriptor_t cDesc, void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, void *, size_t, + void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnReduceTensor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, reduceTensorDesc, indices, indicesSizeInBytes, workspace, workspaceSizeInBytes, alpha, aDesc, A, beta, cDesc, C); + return func_ptr(handle, reduceTensorDesc, indices, indicesSizeInBytes, + workspace, workspaceSizeInBytes, alpha, aDesc, A, beta, cDesc, + C); } -cudnnStatus_t CUDNNWINAPI cudnnSetTensor( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t yDesc, - void *y, - const void *valuePtr ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); +cudnnStatus_t CUDNNWINAPI cudnnSetTensor(cudnnHandle_t handle, + const cudnnTensorDescriptor_t yDesc, + void *y, const void *valuePtr) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, yDesc, y, valuePtr); } -cudnnStatus_t CUDNNWINAPI cudnnScaleTensor( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t yDesc, - void *y, - const void *alpha ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); +cudnnStatus_t CUDNNWINAPI cudnnScaleTensor(cudnnHandle_t handle, + const cudnnTensorDescriptor_t yDesc, + void *y, const void *alpha) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnScaleTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, yDesc, y, alpha); } -cudnnStatus_t CUDNNWINAPI cudnnCreateFilterDescriptor( - cudnnFilterDescriptor_t *filterDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateFilterDescriptor(cudnnFilterDescriptor_t *filterDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateFilterDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc); } -cudnnStatus_t CUDNNWINAPI cudnnSetFilter4dDescriptor( - cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, // image data type - cudnnTensorFormat_t format, - int k, // number of output feature maps - int c, // number of input feature maps - int h, // height of each input filter - int w ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, cudnnTensorFormat_t, int, int, int, int); +cudnnStatus_t CUDNNWINAPI +cudnnSetFilter4dDescriptor(cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t dataType, // image data type + cudnnTensorFormat_t format, + int k, // number of output feature maps + int c, // number of input feature maps + int h, // height of each input filter + int w) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, + cudnnTensorFormat_t, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFilter4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, k, c, h, w); } -cudnnStatus_t CUDNNWINAPI cudnnGetFilter4dDescriptor( - const cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t *dataType, // image data type - cudnnTensorFormat_t *format, - int *k, // number of output feature maps - int *c, // number of input feature maps - int *h, // height of each input filter - int *w ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, cudnnDataType_t *, cudnnTensorFormat_t *, int *, int *, int *, int *); +cudnnStatus_t CUDNNWINAPI +cudnnGetFilter4dDescriptor(const cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t *dataType, // image data type + cudnnTensorFormat_t *format, + int *k, // number of output feature maps + int *c, // number of input feature maps + int *h, // height of each input filter + int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnFilterDescriptor_t, cudnnDataType_t *, cudnnTensorFormat_t *, + int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilter4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, k, c, h, w); } cudnnStatus_t CUDNNWINAPI cudnnSetFilterNdDescriptor( - cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, // image data type - cudnnTensorFormat_t format, - int nbDims, - const int filterDimA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, cudnnTensorFormat_t, int, const int []); + cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t dataType, // image data type + cudnnTensorFormat_t format, int nbDims, const int filterDimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, + cudnnTensorFormat_t, int, const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFilterNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, nbDims, filterDimA); } cudnnStatus_t CUDNNWINAPI cudnnGetFilterNdDescriptor( - const cudnnFilterDescriptor_t filterDesc, - int nbDimsRequested, - cudnnDataType_t *dataType, // image data type - cudnnTensorFormat_t *format, - int *nbDims, - int filterDimA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, int, cudnnDataType_t *, cudnnTensorFormat_t *, int *, int []); + const cudnnFilterDescriptor_t filterDesc, int nbDimsRequested, + cudnnDataType_t *dataType, // image data type + cudnnTensorFormat_t *format, int *nbDims, int filterDimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnFilterDescriptor_t, int, cudnnDataType_t *, + cudnnTensorFormat_t *, int *, int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilterNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(filterDesc, nbDimsRequested, dataType, format, nbDims, filterDimA); + return func_ptr(filterDesc, nbDimsRequested, dataType, format, nbDims, + filterDimA); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyFilterDescriptor( - cudnnFilterDescriptor_t filterDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyFilterDescriptor(cudnnFilterDescriptor_t filterDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyFilterDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc); } -cudnnStatus_t CUDNNWINAPI cudnnCreateConvolutionDescriptor( - cudnnConvolutionDescriptor_t *convDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateConvolutionDescriptor"); +cudnnStatus_t CUDNNWINAPI +cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t *convDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateConvolutionDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc); } -cudnnStatus_t CUDNNWINAPI cudnnSetConvolution2dDescriptor( cudnnConvolutionDescriptor_t convDesc, - int pad_h, // zero-padding height - int pad_w, // zero-padding width - int u, // vertical filter stride - int v, // horizontal filter stride - int dilation_h, // filter dilation in the vertical dimension - int dilation_w, // filter dilation in the horizontal dimension - cudnnConvolutionMode_t mode, - cudnnDataType_t computeType - ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, int, int, int, int, int, cudnnConvolutionMode_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolution2dDescriptor( + cudnnConvolutionDescriptor_t convDesc, + int pad_h, // zero-padding height + int pad_w, // zero-padding width + int u, // vertical filter stride + int v, // horizontal filter stride + int dilation_h, // filter dilation in the vertical dimension + int dilation_w, // filter dilation in the horizontal dimension + cudnnConvolutionMode_t mode, cudnnDataType_t computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnConvolutionDescriptor_t, int, int, int, int, int, int, + cudnnConvolutionMode_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolution2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType); + return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, + computeType); } -cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dDescriptor( const cudnnConvolutionDescriptor_t convDesc, - int* pad_h, // zero-padding height - int* pad_w, // zero-padding width - int* u, // vertical filter stride - int* v, // horizontal filter stride - int* dilation_h, // filter dilation in the vertical dimension - int* dilation_w, // filter dilation in the horizontal dimension - cudnnConvolutionMode_t* mode, - cudnnDataType_t *computeType - ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, int *, cudnnConvolutionMode_t *, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dDescriptor( + const cudnnConvolutionDescriptor_t convDesc, + int *pad_h, // zero-padding height + int *pad_w, // zero-padding width + int *u, // vertical filter stride + int *v, // horizontal filter stride + int *dilation_h, // filter dilation in the vertical dimension + int *dilation_w, // filter dilation in the horizontal dimension + cudnnConvolutionMode_t *mode, cudnnDataType_t *computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, + int *, cudnnConvolutionMode_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType); + return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, + computeType); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dForwardOutputDim( - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - const cudnnFilterDescriptor_t filterDesc, - int *n, - int *c, - int *h, - int *w ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, int *, int *, int *, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dForwardOutputDim"); + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + const cudnnFilterDescriptor_t filterDesc, int *n, int *c, int *h, int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, int *, int *, int *, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolution2dForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, inputTensorDesc, filterDesc, n, c, h, w); } cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionNdDescriptor( - cudnnConvolutionDescriptor_t convDesc, - int arrayLength, /* nbDims-2 size */ - const int padA[], - const int filterStrideA[], - const int dilationA[], - cudnnConvolutionMode_t mode, - cudnnDataType_t computeType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, const int [], const int [], const int [], cudnnConvolutionMode_t, cudnnDataType_t); + cudnnConvolutionDescriptor_t convDesc, int arrayLength, /* nbDims-2 size */ + const int padA[], const int filterStrideA[], const int dilationA[], + cudnnConvolutionMode_t mode, cudnnDataType_t computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnConvolutionDescriptor_t, int, const int[], const int[], const int[], + cudnnConvolutionMode_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, arrayLength, padA, filterStrideA, dilationA, mode, computeType); + return func_ptr(convDesc, arrayLength, padA, filterStrideA, dilationA, mode, + computeType); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionNdDescriptor( - const cudnnConvolutionDescriptor_t convDesc, - int arrayLengthRequested, - int *arrayLength, - int padA[], - int strideA[], - int dilationA[], - cudnnConvolutionMode_t *mode, - cudnnDataType_t *computeType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int, int *, int [], int [], int [], cudnnConvolutionMode_t *, cudnnDataType_t *); + const cudnnConvolutionDescriptor_t convDesc, int arrayLengthRequested, + int *arrayLength, int padA[], int strideA[], int dilationA[], + cudnnConvolutionMode_t *mode, cudnnDataType_t *computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, int, int *, int[], int[], int[], + cudnnConvolutionMode_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, arrayLengthRequested, arrayLength, padA, strideA, dilationA, mode, computeType); + return func_ptr(convDesc, arrayLengthRequested, arrayLength, padA, strideA, + dilationA, mode, computeType); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionNdForwardOutputDim( - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - const cudnnFilterDescriptor_t filterDesc, - int nbDims, - int tensorOutputDimA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, int, int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionNdForwardOutputDim"); + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + const cudnnFilterDescriptor_t filterDesc, int nbDims, + int tensorOutputDimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, int, int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionNdForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, inputTensorDesc, filterDesc, nbDims, tensorOutputDimA); + return func_ptr(convDesc, inputTensorDesc, filterDesc, nbDims, + tensorOutputDimA); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyConvolutionDescriptor( - cudnnConvolutionDescriptor_t convDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyConvolutionDescriptor"); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyConvolutionDescriptor(cudnnConvolutionDescriptor_t convDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyConvolutionDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionForwardAlgorithm( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionFwdAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithm"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionFwdAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionForwardAlgorithmEx( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - void *y, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, const int, int *, cudnnConvolutionFwdAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithmEx"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, void *y, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults, + void *workSpace, size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, + const int, int *, cudnnConvolutionFwdAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, x, wDesc, w, convDesc, yDesc, y, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, xDesc, x, wDesc, w, convDesc, yDesc, y, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithm( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - cudnnConvolutionFwdPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionFwdAlgo_t *algo ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionFwdPreference_t, size_t, cudnnConvolutionFwdAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, + cudnnConvolutionFwdPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionFwdAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionFwdPreference_t, size_t, + cudnnConvolutionFwdAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, preference, + memoryLimitInBytes, algo); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardWorkspaceSize( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - cudnnConvolutionFwdAlgo_t algo, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionFwdAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardWorkspaceSize"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, cudnnConvolutionFwdAlgo_t algo, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionFwdAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, algo, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnConvolutionForward( - cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionFwdAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo, + void *workSpace, size_t workSpaceSizeInBytes, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, yDesc, y); + return func_ptr(handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace, + workSpaceSizeInBytes, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnConvolutionBiasActivationForward( - cudnnHandle_t handle, - const void *alpha1, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionFwdAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *alpha2, - const cudnnTensorDescriptor_t zDesc, - const void *z, - const cudnnTensorDescriptor_t biasDesc, - const void *bias, - const cudnnActivationDescriptor_t activationDesc, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBiasActivationForward"); + cudnnHandle_t handle, const void *alpha1, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo, + void *workSpace, size_t workSpaceSizeInBytes, const void *alpha2, + const cudnnTensorDescriptor_t zDesc, const void *z, + const cudnnTensorDescriptor_t biasDesc, const void *bias, + const cudnnActivationDescriptor_t activationDesc, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnConvolutionBiasActivationForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha1, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, alpha2, zDesc, z, biasDesc, bias, activationDesc, yDesc, y); + return func_ptr(handle, alpha1, xDesc, x, wDesc, w, convDesc, algo, workSpace, + workSpaceSizeInBytes, alpha2, zDesc, z, biasDesc, bias, + activationDesc, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardBias( - cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *beta, - const cudnnTensorDescriptor_t dbDesc, - void *db ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t dyDesc, const void *dy, const void *beta, + const cudnnTensorDescriptor_t dbDesc, void *db) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardBias"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, dyDesc, dy, beta, dbDesc, db); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardFilterAlgorithm( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithm"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, const int, int *, + cudnnConvolutionBwdFilterAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardFilterAlgorithmEx( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *y, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, void *, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithmEx"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t dyDesc, const void *y, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, void *dw, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnConvolutionBwdFilterAlgoPerf_t *perfResults, void *workSpace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, void *, + const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, x, dyDesc, y, convDesc, dwDesc, dw, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, xDesc, x, dyDesc, y, convDesc, dwDesc, dw, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithm( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - cudnnConvolutionBwdFilterPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionBwdFilterAlgo_t *algo ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterPreference_t, size_t, cudnnConvolutionBwdFilterAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, + cudnnConvolutionBwdFilterPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionBwdFilterAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterPreference_t, + size_t, cudnnConvolutionBwdFilterAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, preference, + memoryLimitInBytes, algo); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterWorkspaceSize( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t gradDesc, - cudnnConvolutionBwdFilterAlgo_t algo, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterWorkspaceSize"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t gradDesc, + cudnnConvolutionBwdFilterAlgo_t algo, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, dyDesc, convDesc, gradDesc, algo, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardFilter( - cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionBwdFilterAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnFilterDescriptor_t dwDesc, - void *dw ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, void *, size_t, const void *, const cudnnFilterDescriptor_t, void *); + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + cudnnConvolutionBwdFilterAlgo_t algo, void *workSpace, + size_t workSpaceSizeInBytes, const void *beta, + const cudnnFilterDescriptor_t dwDesc, void *dw) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, + void *, size_t, const void *, const cudnnFilterDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardFilter"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, xDesc, x, dyDesc, dy, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, dwDesc, dw); + return func_ptr(handle, alpha, xDesc, x, dyDesc, dy, convDesc, algo, + workSpace, workSpaceSizeInBytes, beta, dwDesc, dw); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardDataAlgorithm( - cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithm"); + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionBwdDataAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardDataAlgorithmEx( - cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithmEx"); + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, void *dx, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnConvolutionBwdDataAlgoPerf_t *perfResults, void *workSpace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, + const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithm( - cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - cudnnConvolutionBwdDataPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionBwdDataAlgo_t *algo ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataPreference_t, size_t, cudnnConvolutionBwdDataAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm"); + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, + cudnnConvolutionBwdDataPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionBwdDataAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataPreference_t, + size_t, cudnnConvolutionBwdDataAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, preference, + memoryLimitInBytes, algo); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataWorkspaceSize( - cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - cudnnConvolutionBwdDataAlgo_t algo, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataWorkspaceSize"); + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, cudnnConvolutionBwdDataAlgo_t algo, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, algo, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardData( - cudnnHandle_t handle, - const void *alpha, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionBwdDataAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdDataAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const void *alpha, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + cudnnConvolutionBwdDataAlgo_t algo, void *workSpace, + size_t workSpaceSizeInBytes, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdDataAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardData"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, wDesc, w, dyDesc, dy, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, dxDesc, dx); + return func_ptr(handle, alpha, wDesc, w, dyDesc, dy, convDesc, algo, + workSpace, workSpaceSizeInBytes, beta, dxDesc, dx); } -cudnnStatus_t CUDNNWINAPI cudnnIm2Col( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - void *colBuffer ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI +cudnnIm2Col(cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const void *x, const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, void *colBuffer) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, + const void *, const cudnnFilterDescriptor_t, + const cudnnConvolutionDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnIm2Col"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, x, wDesc, convDesc, colBuffer); } cudnnStatus_t CUDNNWINAPI cudnnSoftmaxForward( - cudnnHandle_t handle, - cudnnSoftmaxAlgorithm_t algo, - cudnnSoftmaxMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnSoftmaxAlgorithm_t algo, cudnnSoftmaxMode_t mode, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSoftmaxForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algo, mode, alpha, xDesc, x, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnSoftmaxBackward( - cudnnHandle_t handle, - cudnnSoftmaxAlgorithm_t algo, - cudnnSoftmaxMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnSoftmaxAlgorithm_t algo, cudnnSoftmaxMode_t mode, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSoftmaxBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, algo, mode, alpha, yDesc, y, dyDesc, dy, beta, dxDesc, dx); + return func_ptr(handle, algo, mode, alpha, yDesc, y, dyDesc, dy, beta, dxDesc, + dx); } -cudnnStatus_t CUDNNWINAPI cudnnCreatePoolingDescriptor( - cudnnPoolingDescriptor_t *poolingDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPoolingDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreatePoolingDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetPooling2dDescriptor( - cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t mode, - cudnnNanPropagation_t maxpoolingNanOpt, - int windowHeight, - int windowWidth, - int verticalPadding, - int horizontalPadding, - int verticalStride, - int horizontalStride ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t, cudnnPoolingMode_t, cudnnNanPropagation_t, int, int, int, int, int, int); + cudnnPoolingDescriptor_t poolingDesc, cudnnPoolingMode_t mode, + cudnnNanPropagation_t maxpoolingNanOpt, int windowHeight, int windowWidth, + int verticalPadding, int horizontalPadding, int verticalStride, + int horizontalStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnPoolingDescriptor_t, cudnnPoolingMode_t, cudnnNanPropagation_t, int, + int, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPooling2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, windowWidth, verticalPadding, horizontalPadding, verticalStride, horizontalStride); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, + windowWidth, verticalPadding, horizontalPadding, + verticalStride, horizontalStride); } cudnnStatus_t CUDNNWINAPI cudnnGetPooling2dDescriptor( - const cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t *mode, - cudnnNanPropagation_t *maxpoolingNanOpt, - int *windowHeight, - int *windowWidth, - int *verticalPadding, - int *horizontalPadding, - int *verticalStride, - int *horizontalStride ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, cudnnPoolingMode_t *, cudnnNanPropagation_t *, int *, int *, int *, int *, int *, int *); + const cudnnPoolingDescriptor_t poolingDesc, cudnnPoolingMode_t *mode, + cudnnNanPropagation_t *maxpoolingNanOpt, int *windowHeight, + int *windowWidth, int *verticalPadding, int *horizontalPadding, + int *verticalStride, int *horizontalStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnPoolingDescriptor_t, cudnnPoolingMode_t *, + cudnnNanPropagation_t *, int *, int *, int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPooling2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, windowWidth, verticalPadding, horizontalPadding, verticalStride, horizontalStride); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, + windowWidth, verticalPadding, horizontalPadding, + verticalStride, horizontalStride); } cudnnStatus_t CUDNNWINAPI cudnnSetPoolingNdDescriptor( - cudnnPoolingDescriptor_t poolingDesc, - const cudnnPoolingMode_t mode, - const cudnnNanPropagation_t maxpoolingNanOpt, - int nbDims, - const int windowDimA[], - const int paddingA[], - const int strideA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t, const cudnnPoolingMode_t, const cudnnNanPropagation_t, int, const int [], const int [], const int []); + cudnnPoolingDescriptor_t poolingDesc, const cudnnPoolingMode_t mode, + const cudnnNanPropagation_t maxpoolingNanOpt, int nbDims, + const int windowDimA[], const int paddingA[], const int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnPoolingDescriptor_t, const cudnnPoolingMode_t, + const cudnnNanPropagation_t, int, const int[], const int[], const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPoolingNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, + paddingA, strideA); } cudnnStatus_t CUDNNWINAPI cudnnGetPoolingNdDescriptor( - const cudnnPoolingDescriptor_t poolingDesc, - int nbDimsRequested, - cudnnPoolingMode_t *mode, - cudnnNanPropagation_t *maxpoolingNanOpt, - int *nbDims, - int windowDimA[], - int paddingA[], - int strideA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, int, cudnnPoolingMode_t *, cudnnNanPropagation_t *, int *, int [], int [], int []); + const cudnnPoolingDescriptor_t poolingDesc, int nbDimsRequested, + cudnnPoolingMode_t *mode, cudnnNanPropagation_t *maxpoolingNanOpt, + int *nbDims, int windowDimA[], int paddingA[], int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnPoolingDescriptor_t, int, cudnnPoolingMode_t *, + cudnnNanPropagation_t *, int *, int[], int[], int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPoolingNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, nbDimsRequested, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA); + return func_ptr(poolingDesc, nbDimsRequested, mode, maxpoolingNanOpt, nbDims, + windowDimA, paddingA, strideA); } -cudnnStatus_t CUDNNWINAPI cudnnGetPoolingNdForwardOutputDim( - const cudnnPoolingDescriptor_t poolingDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - int nbDims, - int outputTensorDimA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, const cudnnTensorDescriptor_t, int, int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPoolingNdForwardOutputDim"); +cudnnStatus_t CUDNNWINAPI +cudnnGetPoolingNdForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + int nbDims, int outputTensorDimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, + const cudnnTensorDescriptor_t, int, int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetPoolingNdForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc, inputTensorDesc, nbDims, outputTensorDimA); } -cudnnStatus_t CUDNNWINAPI cudnnGetPooling2dForwardOutputDim( - const cudnnPoolingDescriptor_t poolingDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - int *n, - int *c, - int *h, - int *w ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, const cudnnTensorDescriptor_t, int *, int *, int *, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPooling2dForwardOutputDim"); +cudnnStatus_t CUDNNWINAPI +cudnnGetPooling2dForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + int *n, int *c, int *h, int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, + const cudnnTensorDescriptor_t, + int *, int *, int *, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetPooling2dForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc, inputTensorDesc, n, c, h, w); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyPoolingDescriptor( - cudnnPoolingDescriptor_t poolingDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyPoolingDescriptor(cudnnPoolingDescriptor_t poolingDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPoolingDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyPoolingDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc); } cudnnStatus_t CUDNNWINAPI cudnnPoolingForward( - cudnnHandle_t handle, - const cudnnPoolingDescriptor_t poolingDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnPoolingForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, poolingDesc, alpha, xDesc, x, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnPoolingBackward( - cudnnHandle_t handle, - const cudnnPoolingDescriptor_t poolingDesc, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnPoolingBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, poolingDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, poolingDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, + beta, dxDesc, dx); } -cudnnStatus_t CUDNNWINAPI cudnnCreateActivationDescriptor( - cudnnActivationDescriptor_t *activationDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateActivationDescriptor(cudnnActivationDescriptor_t *activationDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetActivationDescriptor( - cudnnActivationDescriptor_t activationDesc, - cudnnActivationMode_t mode, - cudnnNanPropagation_t reluNanOpt, - double coef ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t, cudnnActivationMode_t, cudnnNanPropagation_t, double); + cudnnActivationDescriptor_t activationDesc, cudnnActivationMode_t mode, + cudnnNanPropagation_t reluNanOpt, double coef) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t, + cudnnActivationMode_t, + cudnnNanPropagation_t, double); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc, mode, reluNanOpt, coef); } -cudnnStatus_t CUDNNWINAPI cudnnGetActivationDescriptor( - const cudnnActivationDescriptor_t activationDesc, - cudnnActivationMode_t *mode, - cudnnNanPropagation_t *reluNanOpt, - double* coef ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnActivationDescriptor_t, cudnnActivationMode_t *, cudnnNanPropagation_t *, double *); +cudnnStatus_t CUDNNWINAPI +cudnnGetActivationDescriptor(const cudnnActivationDescriptor_t activationDesc, + cudnnActivationMode_t *mode, + cudnnNanPropagation_t *reluNanOpt, double *coef) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnActivationDescriptor_t, cudnnActivationMode_t *, + cudnnNanPropagation_t *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc, mode, reluNanOpt, coef); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyActivationDescriptor( - cudnnActivationDescriptor_t activationDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyActivationDescriptor"); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyActivationDescriptor(cudnnActivationDescriptor_t activationDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc); } cudnnStatus_t CUDNNWINAPI cudnnActivationForward( - cudnnHandle_t handle, - cudnnActivationDescriptor_t activationDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnActivationDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnActivationDescriptor_t activationDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnActivationDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnActivationForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, activationDesc, alpha, xDesc, x, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnActivationBackward( - cudnnHandle_t handle, - cudnnActivationDescriptor_t activationDesc, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnActivationDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnActivationDescriptor_t activationDesc, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnActivationDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnActivationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, activationDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, activationDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, + beta, dxDesc, dx); } -cudnnStatus_t CUDNNWINAPI cudnnCreateLRNDescriptor( - cudnnLRNDescriptor_t *normDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateLRNDescriptor(cudnnLRNDescriptor_t *normDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnLRNDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc); } -cudnnStatus_t CUDNNWINAPI cudnnSetLRNDescriptor( - cudnnLRNDescriptor_t normDesc, - unsigned lrnN, - double lrnAlpha, - double lrnBeta, - double lrnK ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t, unsigned int, double, double, double); +cudnnStatus_t CUDNNWINAPI cudnnSetLRNDescriptor(cudnnLRNDescriptor_t normDesc, + unsigned lrnN, double lrnAlpha, + double lrnBeta, double lrnK) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnLRNDescriptor_t, unsigned int, double, double, double); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK); } -cudnnStatus_t CUDNNWINAPI cudnnGetLRNDescriptor( - cudnnLRNDescriptor_t normDesc, - unsigned* lrnN, - double* lrnAlpha, - double* lrnBeta, - double* lrnK ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t, unsigned int *, double *, double *, double *); +cudnnStatus_t CUDNNWINAPI cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, + unsigned *lrnN, + double *lrnAlpha, + double *lrnBeta, double *lrnK) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnLRNDescriptor_t, unsigned int *, double *, double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyLRNDescriptor( cudnnLRNDescriptor_t lrnDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyLRNDescriptor(cudnnLRNDescriptor_t lrnDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnLRNDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(lrnDesc); } cudnnStatus_t CUDNNWINAPI cudnnLRNCrossChannelForward( - cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnLRNMode_t lrnMode, - const void* alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, cudnnLRNMode_t lrnMode, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnLRNCrossChannelForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, normDesc, lrnMode, alpha, xDesc, x, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnLRNCrossChannelBackward( - cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnLRNMode_t lrnMode, - const void* alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, cudnnLRNMode_t lrnMode, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnLRNCrossChannelBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, lrnMode, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, normDesc, lrnMode, alpha, yDesc, y, dyDesc, dy, xDesc, + x, beta, dxDesc, dx); } cudnnStatus_t CUDNNWINAPI cudnnDivisiveNormalizationForward( - cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnDivNormMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, // same desc for means, temp, temp2 - const void *x, - const void *means, // if NULL, means are assumed to be zero - void *temp, - void *temp2, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, const void *, const cudnnTensorDescriptor_t, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationForward"); + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, + cudnnDivNormMode_t mode, const void *alpha, + const cudnnTensorDescriptor_t xDesc, // same desc for means, temp, temp2 + const void *x, + const void *means, // if NULL, means are assumed to be zero + void *temp, void *temp2, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, + const void *, const cudnnTensorDescriptor_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, temp, temp2, beta, yDesc, y); + return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, temp, temp2, + beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnDivisiveNormalizationBackward( - cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnDivNormMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, // same desc for x, means, dy, temp, temp2 - const void *x, - const void *means, // if NULL, means are assumed to be zero - const void *dy, - void *temp, - void *temp2, - const void *beta, - const cudnnTensorDescriptor_t dXdMeansDesc, // same desc for dx, dMeans - void *dx, // output x differential - void *dMeans ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, void *, void *, const void *, const cudnnTensorDescriptor_t, void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationBackward"); + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, + cudnnDivNormMode_t mode, const void *alpha, + const cudnnTensorDescriptor_t + xDesc, // same desc for x, means, dy, temp, temp2 + const void *x, + const void *means, // if NULL, means are assumed to be zero + const void *dy, void *temp, void *temp2, const void *beta, + const cudnnTensorDescriptor_t dXdMeansDesc, // same desc for dx, dMeans + void *dx, // output x differential + void *dMeans) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + void *, void *, const void *, const cudnnTensorDescriptor_t, void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, dy, temp, temp2, beta, dXdMeansDesc, dx, dMeans); + return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, dy, temp, + temp2, beta, dXdMeansDesc, dx, dMeans); } cudnnStatus_t CUDNNWINAPI cudnnDeriveBNTensorDescriptor( - cudnnTensorDescriptor_t derivedBnDesc, - const cudnnTensorDescriptor_t xDesc, - cudnnBatchNormMode_t mode ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, cudnnBatchNormMode_t); + cudnnTensorDescriptor_t derivedBnDesc, const cudnnTensorDescriptor_t xDesc, + cudnnBatchNormMode_t mode) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, + cudnnBatchNormMode_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDeriveBNTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(derivedBnDesc, xDesc, mode); } cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationForwardTraining( - cudnnHandle_t handle, - cudnnBatchNormMode_t mode, + cudnnHandle_t handle, cudnnBatchNormMode_t mode, - const void *alpha, // alpha[0] = result blend factor - const void *beta, // beta[0] = dest layer blend factor + const void *alpha, // alpha[0] = result blend factor + const void *beta, // beta[0] = dest layer blend factor - const cudnnTensorDescriptor_t xDesc, - const void *x, // NxCxHxW - const cudnnTensorDescriptor_t yDesc, - void *y, // NxCxHxW + const cudnnTensorDescriptor_t xDesc, + const void *x, // NxCxHxW + const cudnnTensorDescriptor_t yDesc, + void *y, // NxCxHxW - /* Shared desc for the next 6 tensors in the argument list. - Data type to be set as follows: - type = (typeOf(x) == double) ? double : float - Dimensions for this descriptor depend on normalization mode - - Spatial Normalization : tensors are expected to have dims 1xCx1x1 - (normalization is performed across NxHxW) - - Per-Activation Normalization : tensors are expected to have dims of 1xCxHxW - (normalization is performed across N) */ - const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, + /* Shared desc for the next 6 tensors in the argument list. + Data type to be set as follows: + type = (typeOf(x) == double) ? double : float + Dimensions for this descriptor depend on normalization mode + - Spatial Normalization : tensors are expected to have dims 1xCx1x1 + (normalization is performed across NxHxW) + - Per-Activation Normalization : tensors are expected to have dims of + 1xCxHxW (normalization is performed across N) */ + const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - // 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation - const void *bnScale, - const void *bnBias, + // 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation + const void *bnScale, const void *bnBias, - /* MUST use factor=1 in the very first call of a complete training cycle. - Use a factor=1/(1+n) at N-th call to the function to get - Cumulative Moving Average (CMA) behavior - CMA[n] = (x[1]+...+x[n])/n - Since CMA[n+1] = (n*CMA[n]+x[n+1])/(n+1) = - ((n+1)*CMA[n]-CMA[n])/(n+1) + x[n+1]/(n+1) = - CMA[n]*(1-1/(n+1)) + x[n+1]*1/(n+1) */ - double exponentialAverageFactor, + /* MUST use factor=1 in the very first call of a complete training cycle. + Use a factor=1/(1+n) at N-th call to the function to get + Cumulative Moving Average (CMA) behavior + CMA[n] = (x[1]+...+x[n])/n + Since CMA[n+1] = (n*CMA[n]+x[n+1])/(n+1) = + ((n+1)*CMA[n]-CMA[n])/(n+1) + x[n+1]/(n+1) = + CMA[n]*(1-1/(n+1)) + x[n+1]*1/(n+1) */ + double exponentialAverageFactor, - /* Used in Training phase only. - runningMean = newMean*factor + runningMean*(1-factor) */ - void *resultRunningMean, - /* Output in training mode, input in inference. Is the moving average - of variance[x] (factor is applied in the same way as for runningMean) */ - void *resultRunningVariance, + /* Used in Training phase only. + runningMean = newMean*factor + runningMean*(1-factor) */ + void *resultRunningMean, + /* Output in training mode, input in inference. Is the moving average + of variance[x] (factor is applied in the same way as for runningMean) */ + void *resultRunningVariance, - /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */ - double epsilon, + /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and + backward functions. */ + double epsilon, - /* Optionally save intermediate results from the forward pass here - - can be reused to speed up backward pass. NULL if unused */ - void *resultSaveMean, - void *resultSaveInvVariance ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, double, void *, void *, double, void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTraining"); + /* Optionally save intermediate results from the forward pass here + - can be reused to speed up backward pass. NULL if unused */ + void *resultSaveMean, void *resultSaveInvVariance) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, const void *, double, void *, void *, double, void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTraining"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, bnScale, bnBias, exponentialAverageFactor, resultRunningMean, resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance); + return func_ptr( + handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, + bnScale, bnBias, exponentialAverageFactor, resultRunningMean, + resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance); } cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationForwardInference( - cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - const void *alpha, // alpha[0] = result blend factor - const void *beta, // beta[0] = dest layer blend factor - const cudnnTensorDescriptor_t xDesc, - const void *x, // NxCxHxW - const cudnnTensorDescriptor_t yDesc, - void *y, // NxCxHxW - const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - const void *bnScale, - const void *bnBias, - const void *estimatedMean, - const void *estimatedVariance, - double epsilon ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, const void *, double); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardInference"); + cudnnHandle_t handle, cudnnBatchNormMode_t mode, + const void *alpha, // alpha[0] = result blend factor + const void *beta, // beta[0] = dest layer blend factor + const cudnnTensorDescriptor_t xDesc, + const void *x, // NxCxHxW + const cudnnTensorDescriptor_t yDesc, + void *y, // NxCxHxW + const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, const void *bnScale, + const void *bnBias, const void *estimatedMean, + const void *estimatedVariance, double epsilon) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, const void *, const void *, const void *, double); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardInference"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, bnScale, bnBias, estimatedMean, estimatedVariance, epsilon); + return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, + bnScaleBiasMeanVarDesc, bnScale, bnBias, estimatedMean, + estimatedVariance, epsilon); } cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationBackward( - cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - const void *alphaDataDiff, - const void *betaDataDiff, - const void *alphaParamDiff, - const void *betaParamDiff, - const cudnnTensorDescriptor_t xDesc, // same desc for x, dx, dy - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - /* Shared tensor desc for the 4 tensors below */ - const cudnnTensorDescriptor_t dBnScaleBiasDesc, - const void *bnScale, // bnBias doesn't affect backpropagation - /* scale and bias diff are not backpropagated below this layer */ - void *dBnScaleResult, - void *dBnBiasResult, - /* Same epsilon as forward pass */ - double epsilon, + cudnnHandle_t handle, cudnnBatchNormMode_t mode, const void *alphaDataDiff, + const void *betaDataDiff, const void *alphaParamDiff, + const void *betaParamDiff, + const cudnnTensorDescriptor_t xDesc, // same desc for x, dx, dy + const void *x, const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t dxDesc, void *dx, + /* Shared tensor desc for the 4 tensors below */ + const cudnnTensorDescriptor_t dBnScaleBiasDesc, + const void *bnScale, // bnBias doesn't affect backpropagation + /* scale and bias diff are not backpropagated below this layer */ + void *dBnScaleResult, void *dBnBiasResult, + /* Same epsilon as forward pass */ + double epsilon, - /* Optionally cached intermediate results from - forward pass */ - const void *savedMean, - const void *savedInvVariance ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, void *, void *, double, const void *, const void *); + /* Optionally cached intermediate results from + forward pass */ + const void *savedMean, const void *savedInvVariance) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const void *, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, void *, void *, double, const void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alphaDataDiff, betaDataDiff, alphaParamDiff, betaParamDiff, xDesc, x, dyDesc, dy, dxDesc, dx, dBnScaleBiasDesc, bnScale, dBnScaleResult, dBnBiasResult, epsilon, savedMean, savedInvVariance); + return func_ptr(handle, mode, alphaDataDiff, betaDataDiff, alphaParamDiff, + betaParamDiff, xDesc, x, dyDesc, dy, dxDesc, dx, + dBnScaleBiasDesc, bnScale, dBnScaleResult, dBnBiasResult, + epsilon, savedMean, savedInvVariance); } -cudnnStatus_t CUDNNWINAPI cudnnCreateSpatialTransformerDescriptor( +cudnnStatus_t CUDNNWINAPI cudnnCreateSpatialTransformerDescriptor( - cudnnSpatialTransformerDescriptor_t *stDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateSpatialTransformerDescriptor"); + cudnnSpatialTransformerDescriptor_t *stDesc) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateSpatialTransformerDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetSpatialTransformerNdDescriptor( - cudnnSpatialTransformerDescriptor_t stDesc, - cudnnSamplerType_t samplerType, - cudnnDataType_t dataType, - const int nbDims, - const int dimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t, cudnnSamplerType_t, cudnnDataType_t, const int, const int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetSpatialTransformerNdDescriptor"); + cudnnSpatialTransformerDescriptor_t stDesc, cudnnSamplerType_t samplerType, + cudnnDataType_t dataType, const int nbDims, const int dimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnSpatialTransformerDescriptor_t, cudnnSamplerType_t, cudnnDataType_t, + const int, const int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSetSpatialTransformerNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc, samplerType, dataType, nbDims, dimA); } cudnnStatus_t CUDNNWINAPI cudnnDestroySpatialTransformerDescriptor( - cudnnSpatialTransformerDescriptor_t stDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroySpatialTransformerDescriptor"); + cudnnSpatialTransformerDescriptor_t stDesc) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroySpatialTransformerDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc); } cudnnStatus_t CUDNNWINAPI cudnnSpatialTfGridGeneratorForward( - cudnnHandle_t handle, - const cudnnSpatialTransformerDescriptor_t stDesc, - const void *theta, - void *grid) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorForward"); + cudnnHandle_t handle, const cudnnSpatialTransformerDescriptor_t stDesc, + const void *theta, void *grid) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, theta, grid); } cudnnStatus_t CUDNNWINAPI cudnnSpatialTfGridGeneratorBackward( - cudnnHandle_t handle, - const cudnnSpatialTransformerDescriptor_t stDesc, - const void *dgrid, - void *dtheta) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorBackward"); + cudnnHandle_t handle, const cudnnSpatialTransformerDescriptor_t stDesc, + const void *dgrid, void *dtheta) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorBackward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, dgrid, dtheta); } cudnnStatus_t CUDNNWINAPI cudnnSpatialTfSamplerForward( - cudnnHandle_t handle, - cudnnSpatialTransformerDescriptor_t stDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *grid, - const void *beta, - cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnSpatialTransformerDescriptor_t stDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *grid, const void *beta, cudnnTensorDescriptor_t yDesc, + void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfSamplerForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, alpha, xDesc, x, grid, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnSpatialTfSamplerBackward( - cudnnHandle_t handle, - cudnnSpatialTransformerDescriptor_t stDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - const void *alphaDgrid, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *grid, - const void *betaDgrid, - void *dgrid) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, void *); + cudnnHandle_t handle, cudnnSpatialTransformerDescriptor_t stDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t dxDesc, void *dx, + const void *alphaDgrid, const cudnnTensorDescriptor_t dyDesc, + const void *dy, const void *grid, const void *betaDgrid, void *dgrid) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfSamplerBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, stDesc, alpha, xDesc, x, beta, dxDesc, dx, alphaDgrid, dyDesc, dy, grid, betaDgrid, dgrid); + return func_ptr(handle, stDesc, alpha, xDesc, x, beta, dxDesc, dx, alphaDgrid, + dyDesc, dy, grid, betaDgrid, dgrid); } -cudnnStatus_t CUDNNWINAPI cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t * dropoutDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t *dropoutDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc); } -cudnnStatus_t CUDNNWINAPI cudnnDropoutGetStatesSize(cudnnHandle_t handle, size_t * sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnDropoutGetStatesSize(cudnnHandle_t handle, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutGetStatesSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnDropoutGetReserveSpaceSize(cudnnTensorDescriptor_t xdesc, size_t * sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnDropoutGetReserveSpaceSize( + cudnnTensorDescriptor_t xdesc, size_t *sizeInBytes) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutGetReserveSpaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(xdesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnSetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float dropout, - void * states, - size_t stateSizeInBytes, - unsigned long long seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float, void *, size_t, unsigned long long); +cudnnStatus_t CUDNNWINAPI cudnnSetDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float dropout, + void *states, size_t stateSizeInBytes, unsigned long long seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float, void *, size_t, unsigned long long); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, stateSizeInBytes, seed); } -cudnnStatus_t CUDNNWINAPI cudnnDropoutForward(cudnnHandle_t handle, - const cudnnDropoutDescriptor_t dropoutDesc, - const cudnnTensorDescriptor_t xdesc, - const void * x, - const cudnnTensorDescriptor_t ydesc, - void * y, - void * reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnDropoutDescriptor_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnDropoutForward( + cudnnHandle_t handle, const cudnnDropoutDescriptor_t dropoutDesc, + const cudnnTensorDescriptor_t xdesc, const void *x, + const cudnnTensorDescriptor_t ydesc, void *y, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnDropoutDescriptor_t, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, dropoutDesc, xdesc, x, ydesc, y, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, dropoutDesc, xdesc, x, ydesc, y, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnDropoutBackward(cudnnHandle_t handle, - const cudnnDropoutDescriptor_t dropoutDesc, - const cudnnTensorDescriptor_t dydesc, - const void * dy, - const cudnnTensorDescriptor_t dxdesc, - void * dx, - void * reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnDropoutDescriptor_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnDropoutBackward( + cudnnHandle_t handle, const cudnnDropoutDescriptor_t dropoutDesc, + const cudnnTensorDescriptor_t dydesc, const void *dy, + const cudnnTensorDescriptor_t dxdesc, void *dx, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnDropoutDescriptor_t, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, dropoutDesc, dydesc, dy, dxdesc, dx, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, dropoutDesc, dydesc, dy, dxdesc, dx, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t * rnnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc); } -cudnnStatus_t CUDNNWINAPI cudnnCreatePersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, - const int minibatch, - const cudnnDataType_t dataType, - cudnnPersistentRNNPlan_t * plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, const int, const cudnnDataType_t, cudnnPersistentRNNPlan_t *); +cudnnStatus_t CUDNNWINAPI cudnnCreatePersistentRNNPlan( + cudnnRNNDescriptor_t rnnDesc, const int minibatch, + const cudnnDataType_t dataType, cudnnPersistentRNNPlan_t *plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, const int, + const cudnnDataType_t, + cudnnPersistentRNNPlan_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreatePersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, minibatch, dataType, plan); } -cudnnStatus_t CUDNNWINAPI cudnnSetPersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, - cudnnPersistentRNNPlan_t plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnPersistentRNNPlan_t); +cudnnStatus_t CUDNNWINAPI cudnnSetPersistentRNNPlan( + cudnnRNNDescriptor_t rnnDesc, cudnnPersistentRNNPlan_t plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, + cudnnPersistentRNNPlan_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, plan); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPersistentRNNPlan_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPersistentRNNPlan_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyPersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(plan); } -cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v6(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - const int hiddenSize, - const int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, // Between layers, not between recurrent steps. - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnRNNAlgo_t algo, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v6( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, const int hiddenSize, + const int numLayers, + cudnnDropoutDescriptor_t + dropoutDesc, // Between layers, not between recurrent steps. + cudnnRNNInputMode_t inputMode, cudnnDirectionMode_t direction, + cudnnRNNMode_t mode, cudnnRNNAlgo_t algo, cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, + cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, + cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor_v6"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, dataType); + return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, dataType); } -cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor(cudnnRNNDescriptor_t rnnDesc, - int hiddenSize, - int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, // Between layers, not between recurrent steps. - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, int, int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor( + cudnnRNNDescriptor_t rnnDesc, int hiddenSize, int numLayers, + cudnnDropoutDescriptor_t + dropoutDesc, // Between layers, not between recurrent steps. + cudnnRNNInputMode_t inputMode, cudnnDirectionMode_t direction, + cudnnRNNMode_t mode, cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnRNNDescriptor_t, int, int, cudnnDropoutDescriptor_t, + cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, + cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, dataType); + return func_ptr(rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, + direction, mode, dataType); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNWorkspaceSize( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - size_t *sizeInBytes - ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNWorkspaceSize( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNTrainingReserveSize( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - size_t *sizeInBytes - ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNTrainingReserveSize( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNTrainingReserveSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNParamsSize( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnTensorDescriptor_t xDesc, - size_t *sizeInBytes, - cudnnDataType_t dataType - ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnTensorDescriptor_t, size_t *, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI +cudnnGetRNNParamsSize(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnTensorDescriptor_t xDesc, size_t *sizeInBytes, + cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnTensorDescriptor_t, + size_t *, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNParamsSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, xDesc, sizeInBytes, dataType); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerMatrixParams( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int layer, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const int linLayerID, - cudnnFilterDescriptor_t linLayerMatDesc, - void ** linLayerMat - ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const void *, const int, cudnnFilterDescriptor_t, void **); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerMatrixParams( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, const int layer, + const cudnnTensorDescriptor_t xDesc, const cudnnFilterDescriptor_t wDesc, + const void *w, const int linLayerID, + cudnnFilterDescriptor_t linLayerMatDesc, void **linLayerMat) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, + const void *, const int, cudnnFilterDescriptor_t, void **); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNLinLayerMatrixParams"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, layer, xDesc, wDesc, w, linLayerID, linLayerMatDesc, linLayerMat); + return func_ptr(handle, rnnDesc, layer, xDesc, wDesc, w, linLayerID, + linLayerMatDesc, linLayerMat); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerBiasParams( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int layer, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const int linLayerID, - cudnnFilterDescriptor_t linLayerBiasDesc, - void ** linLayerBias - ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const void *, const int, cudnnFilterDescriptor_t, void **); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerBiasParams( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, const int layer, + const cudnnTensorDescriptor_t xDesc, const cudnnFilterDescriptor_t wDesc, + const void *w, const int linLayerID, + cudnnFilterDescriptor_t linLayerBiasDesc, void **linLayerBias) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, + const void *, const int, cudnnFilterDescriptor_t, void **); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNLinLayerBiasParams"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, layer, xDesc, wDesc, w, linLayerID, linLayerBiasDesc, linLayerBias); + return func_ptr(handle, rnnDesc, layer, xDesc, wDesc, w, linLayerID, + linLayerBiasDesc, linLayerBias); } -cudnnStatus_t CUDNNWINAPI cudnnRNNForwardInference( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t * xDesc, - const void * x, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t cxDesc, - const void * cx, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const cudnnTensorDescriptor_t *yDesc, - void * y, - const cudnnTensorDescriptor_t hyDesc, - void * hy, - const cudnnTensorDescriptor_t cyDesc, - void * cy, - void * workspace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardInference( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, void *workspace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardInference"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, workSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnRNNForwardTraining( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void * x, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t cxDesc, - const void * cx, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const cudnnTensorDescriptor_t *yDesc, - void * y, - const cudnnTensorDescriptor_t hyDesc, - void * hy, - const cudnnTensorDescriptor_t cyDesc, - void * cy, - void * workspace, - size_t workSpaceSizeInBytes, - void * reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardTraining( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardTraining"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardData( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t * yDesc, - const void * y, - const cudnnTensorDescriptor_t * dyDesc, - const void * dy, - const cudnnTensorDescriptor_t dhyDesc, - const void * dhy, - const cudnnTensorDescriptor_t dcyDesc, - const void * dcy, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t cxDesc, - const void * cx, - const cudnnTensorDescriptor_t * dxDesc, - void * dx, - const cudnnTensorDescriptor_t dhxDesc, - void * dhx, - const cudnnTensorDescriptor_t dcxDesc, - void * dcx, - void * workspace, - size_t workSpaceSizeInBytes, - void * reserveSpace, - size_t reserveSpaceSizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI +cudnnRNNBackwardData(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *yDesc, + const void *y, const cudnnTensorDescriptor_t *dyDesc, + const void *dy, const cudnnTensorDescriptor_t dhyDesc, + const void *dhy, const cudnnTensorDescriptor_t dcyDesc, + const void *dcy, const cudnnFilterDescriptor_t wDesc, + const void *w, const cudnnTensorDescriptor_t hxDesc, + const void *hx, const cudnnTensorDescriptor_t cxDesc, + const void *cx, const cudnnTensorDescriptor_t *dxDesc, + void *dx, const cudnnTensorDescriptor_t dhxDesc, void *dhx, + const cudnnTensorDescriptor_t dcxDesc, void *dcx, + void *workspace, size_t workSpaceSizeInBytes, + void *reserveSpace, size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardData"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, + dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, + dx, dhxDesc, dhx, dcxDesc, dcx, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeights( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t * xDesc, - const void * x, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t * yDesc, - const void * y, - const void * workspace, - size_t workSpaceSizeInBytes, - const cudnnFilterDescriptor_t dwDesc, - void * dw, - const void * reserveSpace, - size_t reserveSpaceSizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, const void *, const void *, size_t, const cudnnFilterDescriptor_t, void *, const void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeights( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t *yDesc, const void *y, const void *workspace, + size_t workSpaceSizeInBytes, const cudnnFilterDescriptor_t dwDesc, void *dw, + const void *reserveSpace, size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, const void *, const void *, size_t, + const cudnnFilterDescriptor_t, void *, const void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardWeights"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, + workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, + reserveSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnSetConvolution2dDescriptor_v4( - cudnnConvolutionDescriptor_t convDesc, - int pad_h, // zero-padding height - int pad_w, // zero-padding width - int u, // vertical filter stride - int v, // horizontal filter stride - int dilation_h, // filter dilation in the vertical dimension - int dilation_w, // filter dilation in the horizontal dimension - cudnnConvolutionMode_t mode ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, int, int, int, int, int, cudnnConvolutionMode_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolution2dDescriptor_v4"); + cudnnConvolutionDescriptor_t convDesc, + int pad_h, // zero-padding height + int pad_w, // zero-padding width + int u, // vertical filter stride + int v, // horizontal filter stride + int dilation_h, // filter dilation in the vertical dimension + int dilation_w, // filter dilation in the horizontal dimension + cudnnConvolutionMode_t mode) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, int, int, + int, int, int, cudnnConvolutionMode_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSetConvolution2dDescriptor_v4"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode); } -cudnnStatus_t CUDNNWINAPI cudnnSetConvolution2dDescriptor_v5( cudnnConvolutionDescriptor_t convDesc, - int pad_h, // zero-padding height - int pad_w, // zero-padding width - int u, // vertical filter stride - int v, // horizontal filter stride - int dilation_h, // filter dilation in the vertical dimension - int dilation_w, // filter dilation in the horizontal dimension - cudnnConvolutionMode_t mode, - cudnnDataType_t computeType - ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, int, int, int, int, int, cudnnConvolutionMode_t, cudnnDataType_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolution2dDescriptor_v5"); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolution2dDescriptor_v5( + cudnnConvolutionDescriptor_t convDesc, + int pad_h, // zero-padding height + int pad_w, // zero-padding width + int u, // vertical filter stride + int v, // horizontal filter stride + int dilation_h, // filter dilation in the vertical dimension + int dilation_w, // filter dilation in the horizontal dimension + cudnnConvolutionMode_t mode, cudnnDataType_t computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnConvolutionDescriptor_t, int, int, int, int, int, int, + cudnnConvolutionMode_t, cudnnDataType_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSetConvolution2dDescriptor_v5"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType); + return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, + computeType); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dDescriptor_v4( - const cudnnConvolutionDescriptor_t convDesc, - int *pad_h, // zero-padding height - int *pad_w, // zero-padding width - int *u, // vertical filter stride - int *v, // horizontal filter stride - int *dilation_h, // filter dilation in the vertical dimension - int *dilation_w, // filter dilation in the horizontal dimension - cudnnConvolutionMode_t *mode ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, int *, cudnnConvolutionMode_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dDescriptor_v4"); + const cudnnConvolutionDescriptor_t convDesc, + int *pad_h, // zero-padding height + int *pad_w, // zero-padding width + int *u, // vertical filter stride + int *v, // horizontal filter stride + int *dilation_h, // filter dilation in the vertical dimension + int *dilation_w, // filter dilation in the horizontal dimension + cudnnConvolutionMode_t *mode) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, + int *, cudnnConvolutionMode_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolution2dDescriptor_v4"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode); } -cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dDescriptor_v5( const cudnnConvolutionDescriptor_t convDesc, - int* pad_h, // zero-padding height - int* pad_w, // zero-padding width - int* u, // vertical filter stride - int* v, // horizontal filter stride - int* dilation_h, // filter dilation in the vertical dimension - int* dilation_w, // filter dilation in the horizontal dimension - cudnnConvolutionMode_t* mode, - cudnnDataType_t *computeType - ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, int *, cudnnConvolutionMode_t *, cudnnDataType_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dDescriptor_v5"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dDescriptor_v5( + const cudnnConvolutionDescriptor_t convDesc, + int *pad_h, // zero-padding height + int *pad_w, // zero-padding width + int *u, // vertical filter stride + int *v, // horizontal filter stride + int *dilation_h, // filter dilation in the vertical dimension + int *dilation_w, // filter dilation in the horizontal dimension + cudnnConvolutionMode_t *mode, cudnnDataType_t *computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, + int *, cudnnConvolutionMode_t *, cudnnDataType_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolution2dDescriptor_v5"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType); + return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, + computeType); } } // extern "C" diff --git a/tensorflow/stream_executor/cuda/cudnn_7_0.inc b/tensorflow/stream_executor/cuda/cudnn_7_0.inc index d2ea31e366b..008ae9099c0 100644 --- a/tensorflow/stream_executor/cuda/cudnn_7_0.inc +++ b/tensorflow/stream_executor/cuda/cudnn_7_0.inc @@ -3,1944 +3,2025 @@ extern "C" { size_t CUDNNWINAPI cudnnGetVersion(void) { - using FuncPtr = size_t (CUDNNWINAPI *)(); + using FuncPtr = size_t(CUDNNWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetVersion"); if (!func_ptr) return 0; return func_ptr(); } size_t CUDNNWINAPI cudnnGetCudartVersion(void) { - using FuncPtr = size_t (CUDNNWINAPI *)(); + using FuncPtr = size_t(CUDNNWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCudartVersion"); if (!func_ptr) return 0; return func_ptr(); } -const char * CUDNNWINAPI cudnnGetErrorString(cudnnStatus_t status) { - using FuncPtr = const char * (CUDNNWINAPI *)(cudnnStatus_t); +const char *CUDNNWINAPI cudnnGetErrorString(cudnnStatus_t status) { + using FuncPtr = const char *(CUDNNWINAPI *)(cudnnStatus_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetErrorString"); if (!func_ptr) return "cudnnGetErrorString symbol not found."; return func_ptr(status); } -cudnnStatus_t CUDNNWINAPI cudnnQueryRuntimeError( - cudnnHandle_t handle, - cudnnStatus_t *rstatus, - cudnnErrQueryMode_t mode, - cudnnRuntimeTag_t *tag ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnStatus_t *, cudnnErrQueryMode_t, cudnnRuntimeTag_t *); +cudnnStatus_t CUDNNWINAPI cudnnQueryRuntimeError(cudnnHandle_t handle, + cudnnStatus_t *rstatus, + cudnnErrQueryMode_t mode, + cudnnRuntimeTag_t *tag) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnStatus_t *, cudnnErrQueryMode_t, cudnnRuntimeTag_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnQueryRuntimeError"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rstatus, mode, tag); } -cudnnStatus_t CUDNNWINAPI cudnnGetProperty(libraryPropertyType type, int *value) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(libraryPropertyType, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetProperty(libraryPropertyType type, + int *value) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(libraryPropertyType, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetProperty"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(type, value); } -cudnnStatus_t CUDNNWINAPI cudnnCreate (cudnnHandle_t *handle) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t *); +cudnnStatus_t CUDNNWINAPI cudnnCreate(cudnnHandle_t *handle) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreate"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cudnnStatus_t CUDNNWINAPI cudnnDestroy (cudnnHandle_t handle) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t); +cudnnStatus_t CUDNNWINAPI cudnnDestroy(cudnnHandle_t handle) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroy"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cudnnStatus_t CUDNNWINAPI cudnnSetStream (cudnnHandle_t handle, cudaStream_t streamId) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t); +cudnnStatus_t CUDNNWINAPI cudnnSetStream(cudnnHandle_t handle, + cudaStream_t streamId) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetStream"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); } -cudnnStatus_t CUDNNWINAPI cudnnGetStream (cudnnHandle_t handle, cudaStream_t *streamId) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetStream(cudnnHandle_t handle, + cudaStream_t *streamId) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetStream"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); } -cudnnStatus_t CUDNNWINAPI cudnnCreateTensorDescriptor( - cudnnTensorDescriptor_t *tensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateTensorDescriptor(cudnnTensorDescriptor_t *tensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetTensor4dDescriptor( - cudnnTensorDescriptor_t tensorDesc, - cudnnTensorFormat_t format, - cudnnDataType_t dataType, /* image data type */ - int n, /* number of inputs (batch size) */ - int c, /* number of input feature maps */ - int h, /* height of input section */ - int w ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, cudnnDataType_t, int, int, int, int); + cudnnTensorDescriptor_t tensorDesc, cudnnTensorFormat_t format, + cudnnDataType_t dataType, /* image data type */ + int n, /* number of inputs (batch size) */ + int c, /* number of input feature maps */ + int h, /* height of input section */ + int w) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, + cudnnDataType_t, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, format, dataType, n, c, h, w); } cudnnStatus_t CUDNNWINAPI cudnnSetTensor4dDescriptorEx( - cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t dataType, /* image data type */ - int n, /* number of inputs (batch size) */ - int c, /* number of input feature maps */ - int h, /* height of input section */ - int w, /* width of input section */ - int nStride, - int cStride, - int hStride, - int wStride ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, int, int, int, int, int, int, int, int); + cudnnTensorDescriptor_t tensorDesc, + cudnnDataType_t dataType, /* image data type */ + int n, /* number of inputs (batch size) */ + int c, /* number of input feature maps */ + int h, /* height of input section */ + int w, /* width of input section */ + int nStride, int cStride, int hStride, int wStride) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, + int, int, int, int, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor4dDescriptorEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, wStride); + return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, + wStride); } cudnnStatus_t CUDNNWINAPI cudnnGetTensor4dDescriptor( - const cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t *dataType, /* image data type */ - int *n, /* number of inputs (batch size) */ - int *c, /* number of input feature maps */ - int *h, /* height of input section */ - int *w, /* width of input section */ - int *nStride, - int *cStride, - int *hStride, - int *wStride ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, cudnnDataType_t *, int *, int *, int *, int *, int *, int *, int *, int *); + const cudnnTensorDescriptor_t tensorDesc, + cudnnDataType_t *dataType, /* image data type */ + int *n, /* number of inputs (batch size) */ + int *c, /* number of input feature maps */ + int *h, /* height of input section */ + int *w, /* width of input section */ + int *nStride, int *cStride, int *hStride, int *wStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnTensorDescriptor_t, cudnnDataType_t *, int *, int *, int *, + int *, int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensor4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, wStride); + return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, + wStride); } cudnnStatus_t CUDNNWINAPI cudnnSetTensorNdDescriptor( - cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t dataType, - int nbDims, - const int dimA[], - const int strideA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, int, const int [], const int []); + cudnnTensorDescriptor_t tensorDesc, cudnnDataType_t dataType, int nbDims, + const int dimA[], const int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnTensorDescriptor_t, cudnnDataType_t, int, const int[], const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, dataType, nbDims, dimA, strideA); } cudnnStatus_t CUDNNWINAPI cudnnSetTensorNdDescriptorEx( - cudnnTensorDescriptor_t tensorDesc, - cudnnTensorFormat_t format, - cudnnDataType_t dataType, - int nbDims, - const int dimA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, cudnnDataType_t, int, const int []); + cudnnTensorDescriptor_t tensorDesc, cudnnTensorFormat_t format, + cudnnDataType_t dataType, int nbDims, const int dimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, + cudnnDataType_t, int, const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorNdDescriptorEx"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, format, dataType, nbDims, dimA); } cudnnStatus_t CUDNNWINAPI cudnnGetTensorNdDescriptor( - const cudnnTensorDescriptor_t tensorDesc, - int nbDimsRequested, - cudnnDataType_t *dataType, - int *nbDims, - int dimA[], - int strideA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, int, cudnnDataType_t *, int *, int [], int []); + const cudnnTensorDescriptor_t tensorDesc, int nbDimsRequested, + cudnnDataType_t *dataType, int *nbDims, int dimA[], int strideA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnTensorDescriptor_t, int, + cudnnDataType_t *, int *, int[], int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, nbDimsRequested, dataType, nbDims, dimA, strideA); } cudnnStatus_t CUDNNWINAPI cudnnGetTensorSizeInBytes( - const cudnnTensorDescriptor_t tensorDesc, - size_t *size) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, size_t *); + const cudnnTensorDescriptor_t tensorDesc, size_t *size) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorSizeInBytes"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, size); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyTensorDescriptor( - cudnnTensorDescriptor_t tensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyTensorDescriptor(cudnnTensorDescriptor_t tensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnTransformTensor( - cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnTransformTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI cudnnAddTensor( - cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnAddTensor(cudnnHandle_t handle, + const void *alpha, + const cudnnTensorDescriptor_t aDesc, + const void *A, const void *beta, + const cudnnTensorDescriptor_t cDesc, + void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnAddTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, aDesc, A, beta, cDesc, C); } -cudnnStatus_t CUDNNWINAPI cudnnCreateOpTensorDescriptor( - cudnnOpTensorDescriptor_t *opTensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateOpTensorDescriptor(cudnnOpTensorDescriptor_t *opTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetOpTensorDescriptor( - cudnnOpTensorDescriptor_t opTensorDesc, - cudnnOpTensorOp_t opTensorOp, - cudnnDataType_t opTensorCompType, - cudnnNanPropagation_t opTensorNanOpt ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t, cudnnDataType_t, cudnnNanPropagation_t); + cudnnOpTensorDescriptor_t opTensorDesc, cudnnOpTensorOp_t opTensorOp, + cudnnDataType_t opTensorCompType, cudnnNanPropagation_t opTensorNanOpt) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t, + cudnnDataType_t, cudnnNanPropagation_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc, opTensorOp, opTensorCompType, opTensorNanOpt); } cudnnStatus_t CUDNNWINAPI cudnnGetOpTensorDescriptor( - const cudnnOpTensorDescriptor_t opTensorDesc, - cudnnOpTensorOp_t *opTensorOp, - cudnnDataType_t *opTensorCompType, - cudnnNanPropagation_t *opTensorNanOpt ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t *, cudnnDataType_t *, cudnnNanPropagation_t *); + const cudnnOpTensorDescriptor_t opTensorDesc, cudnnOpTensorOp_t *opTensorOp, + cudnnDataType_t *opTensorCompType, cudnnNanPropagation_t *opTensorNanOpt) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t *, cudnnDataType_t *, + cudnnNanPropagation_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc, opTensorOp, opTensorCompType, opTensorNanOpt); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyOpTensorDescriptor( - cudnnOpTensorDescriptor_t opTensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnOpTensor( - cudnnHandle_t handle, - const cudnnOpTensorDescriptor_t opTensorDesc, - const void *alpha1, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *alpha2, - const cudnnTensorDescriptor_t bDesc, - const void *B, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnOpTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const cudnnOpTensorDescriptor_t opTensorDesc, + const void *alpha1, const cudnnTensorDescriptor_t aDesc, const void *A, + const void *alpha2, const cudnnTensorDescriptor_t bDesc, const void *B, + const void *beta, const cudnnTensorDescriptor_t cDesc, void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnOpTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnOpTensor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, opTensorDesc, alpha1, aDesc, A, alpha2, bDesc, B, beta, cDesc, C); + return func_ptr(handle, opTensorDesc, alpha1, aDesc, A, alpha2, bDesc, B, + beta, cDesc, C); } cudnnStatus_t CUDNNWINAPI cudnnCreateReduceTensorDescriptor( - cudnnReduceTensorDescriptor_t *reduceTensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateReduceTensorDescriptor"); + cudnnReduceTensorDescriptor_t *reduceTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(reduceTensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetReduceTensorDescriptor( - cudnnReduceTensorDescriptor_t reduceTensorDesc, - cudnnReduceTensorOp_t reduceTensorOp, - cudnnDataType_t reduceTensorCompType, - cudnnNanPropagation_t reduceTensorNanOpt, - cudnnReduceTensorIndices_t reduceTensorIndices, - cudnnIndicesType_t reduceTensorIndicesType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t, cudnnDataType_t, cudnnNanPropagation_t, cudnnReduceTensorIndices_t, cudnnIndicesType_t); + cudnnReduceTensorDescriptor_t reduceTensorDesc, + cudnnReduceTensorOp_t reduceTensorOp, cudnnDataType_t reduceTensorCompType, + cudnnNanPropagation_t reduceTensorNanOpt, + cudnnReduceTensorIndices_t reduceTensorIndices, + cudnnIndicesType_t reduceTensorIndicesType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t, cudnnDataType_t, + cudnnNanPropagation_t, cudnnReduceTensorIndices_t, cudnnIndicesType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, reduceTensorNanOpt, reduceTensorIndices, reduceTensorIndicesType); + return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, + reduceTensorNanOpt, reduceTensorIndices, + reduceTensorIndicesType); } cudnnStatus_t CUDNNWINAPI cudnnGetReduceTensorDescriptor( - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - cudnnReduceTensorOp_t *reduceTensorOp, - cudnnDataType_t *reduceTensorCompType, - cudnnNanPropagation_t *reduceTensorNanOpt, - cudnnReduceTensorIndices_t *reduceTensorIndices, - cudnnIndicesType_t *reduceTensorIndicesType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t *, cudnnDataType_t *, cudnnNanPropagation_t *, cudnnReduceTensorIndices_t *, cudnnIndicesType_t *); + const cudnnReduceTensorDescriptor_t reduceTensorDesc, + cudnnReduceTensorOp_t *reduceTensorOp, + cudnnDataType_t *reduceTensorCompType, + cudnnNanPropagation_t *reduceTensorNanOpt, + cudnnReduceTensorIndices_t *reduceTensorIndices, + cudnnIndicesType_t *reduceTensorIndicesType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t *, + cudnnDataType_t *, cudnnNanPropagation_t *, cudnnReduceTensorIndices_t *, + cudnnIndicesType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, reduceTensorNanOpt, reduceTensorIndices, reduceTensorIndicesType); + return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, + reduceTensorNanOpt, reduceTensorIndices, + reduceTensorIndicesType); } cudnnStatus_t CUDNNWINAPI cudnnDestroyReduceTensorDescriptor( - cudnnReduceTensorDescriptor_t reduceTensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyReduceTensorDescriptor"); + cudnnReduceTensorDescriptor_t reduceTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(reduceTensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnGetReductionIndicesSize( - cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - const cudnnTensorDescriptor_t aDesc, - const cudnnTensorDescriptor_t cDesc, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + const cudnnTensorDescriptor_t aDesc, const cudnnTensorDescriptor_t cDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReductionIndicesSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, reduceTensorDesc, aDesc, cDesc, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetReductionWorkspaceSize( - cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - const cudnnTensorDescriptor_t aDesc, - const cudnnTensorDescriptor_t cDesc, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + const cudnnTensorDescriptor_t aDesc, const cudnnTensorDescriptor_t cDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReductionWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, reduceTensorDesc, aDesc, cDesc, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnReduceTensor( - cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - void *indices, - size_t indicesSizeInBytes, - void *workspace, - size_t workspaceSizeInBytes, - const void *alpha, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, void *, size_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + void *indices, size_t indicesSizeInBytes, void *workspace, + size_t workspaceSizeInBytes, const void *alpha, + const cudnnTensorDescriptor_t aDesc, const void *A, const void *beta, + const cudnnTensorDescriptor_t cDesc, void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, void *, size_t, + void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnReduceTensor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, reduceTensorDesc, indices, indicesSizeInBytes, workspace, workspaceSizeInBytes, alpha, aDesc, A, beta, cDesc, C); + return func_ptr(handle, reduceTensorDesc, indices, indicesSizeInBytes, + workspace, workspaceSizeInBytes, alpha, aDesc, A, beta, cDesc, + C); } -cudnnStatus_t CUDNNWINAPI cudnnSetTensor( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t yDesc, - void *y, - const void *valuePtr ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); +cudnnStatus_t CUDNNWINAPI cudnnSetTensor(cudnnHandle_t handle, + const cudnnTensorDescriptor_t yDesc, + void *y, const void *valuePtr) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, yDesc, y, valuePtr); } -cudnnStatus_t CUDNNWINAPI cudnnScaleTensor( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t yDesc, - void *y, - const void *alpha ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); +cudnnStatus_t CUDNNWINAPI cudnnScaleTensor(cudnnHandle_t handle, + const cudnnTensorDescriptor_t yDesc, + void *y, const void *alpha) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnScaleTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, yDesc, y, alpha); } -cudnnStatus_t CUDNNWINAPI cudnnCreateFilterDescriptor( - cudnnFilterDescriptor_t *filterDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateFilterDescriptor(cudnnFilterDescriptor_t *filterDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateFilterDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetFilter4dDescriptor( - cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, /* image data type */ - cudnnTensorFormat_t format, - int k, /* number of output feature maps */ - int c, /* number of input feature maps */ - int h, /* height of each input filter */ - int w ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, cudnnTensorFormat_t, int, int, int, int); + cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t dataType, /* image data type */ + cudnnTensorFormat_t format, int k, /* number of output feature maps */ + int c, /* number of input feature maps */ + int h, /* height of each input filter */ + int w) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, + cudnnTensorFormat_t, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFilter4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, k, c, h, w); } cudnnStatus_t CUDNNWINAPI cudnnGetFilter4dDescriptor( - const cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t *dataType, /* image data type */ - cudnnTensorFormat_t *format, - int *k, /* number of output feature maps */ - int *c, /* number of input feature maps */ - int *h, /* height of each input filter */ - int *w ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, cudnnDataType_t *, cudnnTensorFormat_t *, int *, int *, int *, int *); + const cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t *dataType, /* image data type */ + cudnnTensorFormat_t *format, int *k, /* number of output feature maps */ + int *c, /* number of input feature maps */ + int *h, /* height of each input filter */ + int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnFilterDescriptor_t, cudnnDataType_t *, cudnnTensorFormat_t *, + int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilter4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, k, c, h, w); } cudnnStatus_t CUDNNWINAPI cudnnSetFilterNdDescriptor( - cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, /* image data type */ - cudnnTensorFormat_t format, - int nbDims, - const int filterDimA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, cudnnTensorFormat_t, int, const int []); + cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t dataType, /* image data type */ + cudnnTensorFormat_t format, int nbDims, const int filterDimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, + cudnnTensorFormat_t, int, const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFilterNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, nbDims, filterDimA); } cudnnStatus_t CUDNNWINAPI cudnnGetFilterNdDescriptor( - const cudnnFilterDescriptor_t filterDesc, - int nbDimsRequested, - cudnnDataType_t *dataType, /* image data type */ - cudnnTensorFormat_t *format, - int *nbDims, - int filterDimA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, int, cudnnDataType_t *, cudnnTensorFormat_t *, int *, int []); + const cudnnFilterDescriptor_t filterDesc, int nbDimsRequested, + cudnnDataType_t *dataType, /* image data type */ + cudnnTensorFormat_t *format, int *nbDims, int filterDimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnFilterDescriptor_t, int, cudnnDataType_t *, + cudnnTensorFormat_t *, int *, int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilterNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(filterDesc, nbDimsRequested, dataType, format, nbDims, filterDimA); + return func_ptr(filterDesc, nbDimsRequested, dataType, format, nbDims, + filterDimA); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyFilterDescriptor( - cudnnFilterDescriptor_t filterDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyFilterDescriptor(cudnnFilterDescriptor_t filterDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyFilterDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc); } -cudnnStatus_t CUDNNWINAPI cudnnCreateConvolutionDescriptor( - cudnnConvolutionDescriptor_t *convDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateConvolutionDescriptor"); +cudnnStatus_t CUDNNWINAPI +cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t *convDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateConvolutionDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc); } -cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionMathType( cudnnConvolutionDescriptor_t convDesc, - cudnnMathType_t mathType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnMathType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionMathType( + cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t mathType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, + cudnnMathType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, mathType); } -cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionMathType( cudnnConvolutionDescriptor_t convDesc, - cudnnMathType_t *mathType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnMathType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionMathType( + cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t *mathType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, + cudnnMathType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, mathType); } -cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionGroupCount( cudnnConvolutionDescriptor_t convDesc, - int groupCount ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionGroupCount( + cudnnConvolutionDescriptor_t convDesc, int groupCount) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionGroupCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, groupCount); } -cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionGroupCount( cudnnConvolutionDescriptor_t convDesc, - int *groupCount ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionGroupCount( + cudnnConvolutionDescriptor_t convDesc, int *groupCount) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionGroupCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, groupCount); } -cudnnStatus_t CUDNNWINAPI cudnnSetConvolution2dDescriptor( cudnnConvolutionDescriptor_t convDesc, - int pad_h, /* zero-padding height */ - int pad_w, /* zero-padding width */ - int u, /* vertical filter stride */ - int v, /* horizontal filter stride */ - int dilation_h, /* filter dilation in the vertical dimension */ - int dilation_w, /* filter dilation in the horizontal dimension */ - cudnnConvolutionMode_t mode, - cudnnDataType_t computeType - ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, int, int, int, int, int, cudnnConvolutionMode_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolution2dDescriptor( + cudnnConvolutionDescriptor_t convDesc, int pad_h, /* zero-padding height */ + int pad_w, /* zero-padding width */ + int u, /* vertical filter stride */ + int v, /* horizontal filter stride */ + int dilation_h, /* filter dilation in the vertical dimension */ + int dilation_w, /* filter dilation in the horizontal dimension */ + cudnnConvolutionMode_t mode, cudnnDataType_t computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnConvolutionDescriptor_t, int, int, int, int, int, int, + cudnnConvolutionMode_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolution2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType); + return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, + computeType); } -cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dDescriptor( const cudnnConvolutionDescriptor_t convDesc, - int* pad_h, /* zero-padding height */ - int* pad_w, /* zero-padding width */ - int* u, /* vertical filter stride */ - int* v, /* horizontal filter stride */ - int* dilation_h, /* filter dilation in the vertical dimension */ - int* dilation_w, /* filter dilation in the horizontal dimension */ - cudnnConvolutionMode_t* mode, - cudnnDataType_t *computeType - ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, int *, cudnnConvolutionMode_t *, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dDescriptor( + const cudnnConvolutionDescriptor_t convDesc, + int *pad_h, /* zero-padding height */ + int *pad_w, /* zero-padding width */ + int *u, /* vertical filter stride */ + int *v, /* horizontal filter stride */ + int *dilation_h, /* filter dilation in the vertical dimension */ + int *dilation_w, /* filter dilation in the horizontal dimension */ + cudnnConvolutionMode_t *mode, cudnnDataType_t *computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, + int *, cudnnConvolutionMode_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType); + return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, + computeType); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dForwardOutputDim( - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - const cudnnFilterDescriptor_t filterDesc, - int *n, - int *c, - int *h, - int *w ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, int *, int *, int *, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dForwardOutputDim"); + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + const cudnnFilterDescriptor_t filterDesc, int *n, int *c, int *h, int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, int *, int *, int *, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolution2dForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, inputTensorDesc, filterDesc, n, c, h, w); } cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionNdDescriptor( - cudnnConvolutionDescriptor_t convDesc, - int arrayLength, /* nbDims-2 size */ - const int padA[], - const int filterStrideA[], - const int dilationA[], - cudnnConvolutionMode_t mode, - cudnnDataType_t computeType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, const int [], const int [], const int [], cudnnConvolutionMode_t, cudnnDataType_t); + cudnnConvolutionDescriptor_t convDesc, int arrayLength, /* nbDims-2 size */ + const int padA[], const int filterStrideA[], const int dilationA[], + cudnnConvolutionMode_t mode, cudnnDataType_t computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnConvolutionDescriptor_t, int, const int[], const int[], const int[], + cudnnConvolutionMode_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, arrayLength, padA, filterStrideA, dilationA, mode, computeType); + return func_ptr(convDesc, arrayLength, padA, filterStrideA, dilationA, mode, + computeType); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionNdDescriptor( - const cudnnConvolutionDescriptor_t convDesc, - int arrayLengthRequested, - int *arrayLength, - int padA[], - int strideA[], - int dilationA[], - cudnnConvolutionMode_t *mode, - cudnnDataType_t *computeType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int, int *, int [], int [], int [], cudnnConvolutionMode_t *, cudnnDataType_t *); + const cudnnConvolutionDescriptor_t convDesc, int arrayLengthRequested, + int *arrayLength, int padA[], int strideA[], int dilationA[], + cudnnConvolutionMode_t *mode, cudnnDataType_t *computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, int, int *, int[], int[], int[], + cudnnConvolutionMode_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, arrayLengthRequested, arrayLength, padA, strideA, dilationA, mode, computeType); + return func_ptr(convDesc, arrayLengthRequested, arrayLength, padA, strideA, + dilationA, mode, computeType); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionNdForwardOutputDim( - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - const cudnnFilterDescriptor_t filterDesc, - int nbDims, - int tensorOutputDimA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, int, int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionNdForwardOutputDim"); + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + const cudnnFilterDescriptor_t filterDesc, int nbDims, + int tensorOutputDimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, int, int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionNdForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, inputTensorDesc, filterDesc, nbDims, tensorOutputDimA); + return func_ptr(convDesc, inputTensorDesc, filterDesc, nbDims, + tensorOutputDimA); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyConvolutionDescriptor( - cudnnConvolutionDescriptor_t convDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyConvolutionDescriptor"); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyConvolutionDescriptor(cudnnConvolutionDescriptor_t convDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyConvolutionDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc); } -cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithmMaxCount( cudnnHandle_t handle, - int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI +cudnnGetConvolutionForwardAlgorithmMaxCount(cudnnHandle_t handle, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, count); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionForwardAlgorithm( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionFwdAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithm"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionFwdAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionForwardAlgorithmEx( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - void *y, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, const int, int *, cudnnConvolutionFwdAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithmEx"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, void *y, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults, + void *workSpace, size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, + const int, int *, cudnnConvolutionFwdAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, x, wDesc, w, convDesc, yDesc, y, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, xDesc, x, wDesc, w, convDesc, yDesc, y, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithm( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - cudnnConvolutionFwdPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionFwdAlgo_t *algo ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionFwdPreference_t, size_t, cudnnConvolutionFwdAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, + cudnnConvolutionFwdPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionFwdAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionFwdPreference_t, size_t, + cudnnConvolutionFwdAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, preference, + memoryLimitInBytes, algo); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithm_v7( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t srcDesc, - const cudnnFilterDescriptor_t filterDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t destDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionFwdAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm_v7"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t srcDesc, + const cudnnFilterDescriptor_t filterDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t destDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionFwdAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm_v7"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, srcDesc, filterDesc, convDesc, destDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, srcDesc, filterDesc, convDesc, destDesc, + requestedAlgoCount, returnedAlgoCount, perfResults); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardWorkspaceSize( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - cudnnConvolutionFwdAlgo_t algo, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionFwdAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardWorkspaceSize"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, cudnnConvolutionFwdAlgo_t algo, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionFwdAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, algo, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnConvolutionForward( - cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionFwdAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo, + void *workSpace, size_t workSpaceSizeInBytes, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, yDesc, y); + return func_ptr(handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace, + workSpaceSizeInBytes, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnConvolutionBiasActivationForward( - cudnnHandle_t handle, - const void *alpha1, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionFwdAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *alpha2, - const cudnnTensorDescriptor_t zDesc, - const void *z, - const cudnnTensorDescriptor_t biasDesc, - const void *bias, - const cudnnActivationDescriptor_t activationDesc, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBiasActivationForward"); + cudnnHandle_t handle, const void *alpha1, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo, + void *workSpace, size_t workSpaceSizeInBytes, const void *alpha2, + const cudnnTensorDescriptor_t zDesc, const void *z, + const cudnnTensorDescriptor_t biasDesc, const void *bias, + const cudnnActivationDescriptor_t activationDesc, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnConvolutionBiasActivationForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha1, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, alpha2, zDesc, z, biasDesc, bias, activationDesc, yDesc, y); + return func_ptr(handle, alpha1, xDesc, x, wDesc, w, convDesc, algo, workSpace, + workSpaceSizeInBytes, alpha2, zDesc, z, biasDesc, bias, + activationDesc, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardBias( - cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *beta, - const cudnnTensorDescriptor_t dbDesc, - void *db ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t dyDesc, const void *dy, const void *beta, + const cudnnTensorDescriptor_t dbDesc, void *db) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardBias"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, dyDesc, dy, beta, dbDesc, db); } -cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( cudnnHandle_t handle, - int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( + cudnnHandle_t handle, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, count); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardFilterAlgorithm( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithm"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, const int, int *, + cudnnConvolutionBwdFilterAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardFilterAlgorithmEx( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *y, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, void *, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithmEx"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t dyDesc, const void *y, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, void *dw, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnConvolutionBwdFilterAlgoPerf_t *perfResults, void *workSpace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, void *, + const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, x, dyDesc, y, convDesc, dwDesc, dw, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, xDesc, x, dyDesc, y, convDesc, dwDesc, dw, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithm( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - cudnnConvolutionBwdFilterPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionBwdFilterAlgo_t *algo ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterPreference_t, size_t, cudnnConvolutionBwdFilterAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, + cudnnConvolutionBwdFilterPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionBwdFilterAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterPreference_t, + size_t, cudnnConvolutionBwdFilterAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, preference, + memoryLimitInBytes, algo); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithm_v7( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t srcDesc, - const cudnnTensorDescriptor_t diffDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t gradDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm_v7"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t srcDesc, + const cudnnTensorDescriptor_t diffDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t gradDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, const int, int *, + cudnnConvolutionBwdFilterAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm_v7"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, srcDesc, diffDesc, convDesc, gradDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, srcDesc, diffDesc, convDesc, gradDesc, + requestedAlgoCount, returnedAlgoCount, perfResults); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterWorkspaceSize( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t gradDesc, - cudnnConvolutionBwdFilterAlgo_t algo, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterWorkspaceSize"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t gradDesc, + cudnnConvolutionBwdFilterAlgo_t algo, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, dyDesc, convDesc, gradDesc, algo, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardFilter( - cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionBwdFilterAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnFilterDescriptor_t dwDesc, - void *dw ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, void *, size_t, const void *, const cudnnFilterDescriptor_t, void *); + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + cudnnConvolutionBwdFilterAlgo_t algo, void *workSpace, + size_t workSpaceSizeInBytes, const void *beta, + const cudnnFilterDescriptor_t dwDesc, void *dw) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, + void *, size_t, const void *, const cudnnFilterDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardFilter"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, xDesc, x, dyDesc, dy, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, dwDesc, dw); + return func_ptr(handle, alpha, xDesc, x, dyDesc, dy, convDesc, algo, + workSpace, workSpaceSizeInBytes, beta, dwDesc, dw); } -cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithmMaxCount( cudnnHandle_t handle, - int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithmMaxCount( + cudnnHandle_t handle, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, count); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardDataAlgorithm( - cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithm"); + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionBwdDataAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardDataAlgorithmEx( - cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithmEx"); + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, void *dx, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnConvolutionBwdDataAlgoPerf_t *perfResults, void *workSpace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, + const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithm( - cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - cudnnConvolutionBwdDataPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionBwdDataAlgo_t *algo ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataPreference_t, size_t, cudnnConvolutionBwdDataAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm"); + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, + cudnnConvolutionBwdDataPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionBwdDataAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataPreference_t, + size_t, cudnnConvolutionBwdDataAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, preference, + memoryLimitInBytes, algo); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithm_v7( - cudnnHandle_t handle, - const cudnnFilterDescriptor_t filterDesc, - const cudnnTensorDescriptor_t diffDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t gradDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm_v7"); + cudnnHandle_t handle, const cudnnFilterDescriptor_t filterDesc, + const cudnnTensorDescriptor_t diffDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t gradDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionBwdDataAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm_v7"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, filterDesc, diffDesc, convDesc, gradDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, filterDesc, diffDesc, convDesc, gradDesc, + requestedAlgoCount, returnedAlgoCount, perfResults); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataWorkspaceSize( - cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - cudnnConvolutionBwdDataAlgo_t algo, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataWorkspaceSize"); + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, cudnnConvolutionBwdDataAlgo_t algo, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, algo, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardData( - cudnnHandle_t handle, - const void *alpha, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionBwdDataAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdDataAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const void *alpha, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + cudnnConvolutionBwdDataAlgo_t algo, void *workSpace, + size_t workSpaceSizeInBytes, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdDataAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardData"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, wDesc, w, dyDesc, dy, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, dxDesc, dx); + return func_ptr(handle, alpha, wDesc, w, dyDesc, dy, convDesc, algo, + workSpace, workSpaceSizeInBytes, beta, dxDesc, dx); } -cudnnStatus_t CUDNNWINAPI cudnnIm2Col( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - void *colBuffer ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI +cudnnIm2Col(cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const void *x, const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, void *colBuffer) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, + const void *, const cudnnFilterDescriptor_t, + const cudnnConvolutionDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnIm2Col"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, x, wDesc, convDesc, colBuffer); } cudnnStatus_t CUDNNWINAPI cudnnSoftmaxForward( - cudnnHandle_t handle, - cudnnSoftmaxAlgorithm_t algo, - cudnnSoftmaxMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnSoftmaxAlgorithm_t algo, cudnnSoftmaxMode_t mode, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSoftmaxForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algo, mode, alpha, xDesc, x, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnSoftmaxBackward( - cudnnHandle_t handle, - cudnnSoftmaxAlgorithm_t algo, - cudnnSoftmaxMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnSoftmaxAlgorithm_t algo, cudnnSoftmaxMode_t mode, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSoftmaxBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, algo, mode, alpha, yDesc, y, dyDesc, dy, beta, dxDesc, dx); + return func_ptr(handle, algo, mode, alpha, yDesc, y, dyDesc, dy, beta, dxDesc, + dx); } -cudnnStatus_t CUDNNWINAPI cudnnCreatePoolingDescriptor( - cudnnPoolingDescriptor_t *poolingDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPoolingDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreatePoolingDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetPooling2dDescriptor( - cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t mode, - cudnnNanPropagation_t maxpoolingNanOpt, - int windowHeight, - int windowWidth, - int verticalPadding, - int horizontalPadding, - int verticalStride, - int horizontalStride ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t, cudnnPoolingMode_t, cudnnNanPropagation_t, int, int, int, int, int, int); + cudnnPoolingDescriptor_t poolingDesc, cudnnPoolingMode_t mode, + cudnnNanPropagation_t maxpoolingNanOpt, int windowHeight, int windowWidth, + int verticalPadding, int horizontalPadding, int verticalStride, + int horizontalStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnPoolingDescriptor_t, cudnnPoolingMode_t, cudnnNanPropagation_t, int, + int, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPooling2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, windowWidth, verticalPadding, horizontalPadding, verticalStride, horizontalStride); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, + windowWidth, verticalPadding, horizontalPadding, + verticalStride, horizontalStride); } cudnnStatus_t CUDNNWINAPI cudnnGetPooling2dDescriptor( - const cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t *mode, - cudnnNanPropagation_t *maxpoolingNanOpt, - int *windowHeight, - int *windowWidth, - int *verticalPadding, - int *horizontalPadding, - int *verticalStride, - int *horizontalStride ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, cudnnPoolingMode_t *, cudnnNanPropagation_t *, int *, int *, int *, int *, int *, int *); + const cudnnPoolingDescriptor_t poolingDesc, cudnnPoolingMode_t *mode, + cudnnNanPropagation_t *maxpoolingNanOpt, int *windowHeight, + int *windowWidth, int *verticalPadding, int *horizontalPadding, + int *verticalStride, int *horizontalStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnPoolingDescriptor_t, cudnnPoolingMode_t *, + cudnnNanPropagation_t *, int *, int *, int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPooling2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, windowWidth, verticalPadding, horizontalPadding, verticalStride, horizontalStride); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, + windowWidth, verticalPadding, horizontalPadding, + verticalStride, horizontalStride); } cudnnStatus_t CUDNNWINAPI cudnnSetPoolingNdDescriptor( - cudnnPoolingDescriptor_t poolingDesc, - const cudnnPoolingMode_t mode, - const cudnnNanPropagation_t maxpoolingNanOpt, - int nbDims, - const int windowDimA[], - const int paddingA[], - const int strideA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t, const cudnnPoolingMode_t, const cudnnNanPropagation_t, int, const int [], const int [], const int []); + cudnnPoolingDescriptor_t poolingDesc, const cudnnPoolingMode_t mode, + const cudnnNanPropagation_t maxpoolingNanOpt, int nbDims, + const int windowDimA[], const int paddingA[], const int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnPoolingDescriptor_t, const cudnnPoolingMode_t, + const cudnnNanPropagation_t, int, const int[], const int[], const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPoolingNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, + paddingA, strideA); } cudnnStatus_t CUDNNWINAPI cudnnGetPoolingNdDescriptor( - const cudnnPoolingDescriptor_t poolingDesc, - int nbDimsRequested, - cudnnPoolingMode_t *mode, - cudnnNanPropagation_t *maxpoolingNanOpt, - int *nbDims, - int windowDimA[], - int paddingA[], - int strideA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, int, cudnnPoolingMode_t *, cudnnNanPropagation_t *, int *, int [], int [], int []); + const cudnnPoolingDescriptor_t poolingDesc, int nbDimsRequested, + cudnnPoolingMode_t *mode, cudnnNanPropagation_t *maxpoolingNanOpt, + int *nbDims, int windowDimA[], int paddingA[], int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnPoolingDescriptor_t, int, cudnnPoolingMode_t *, + cudnnNanPropagation_t *, int *, int[], int[], int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPoolingNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, nbDimsRequested, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA); + return func_ptr(poolingDesc, nbDimsRequested, mode, maxpoolingNanOpt, nbDims, + windowDimA, paddingA, strideA); } -cudnnStatus_t CUDNNWINAPI cudnnGetPoolingNdForwardOutputDim( - const cudnnPoolingDescriptor_t poolingDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - int nbDims, - int outputTensorDimA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, const cudnnTensorDescriptor_t, int, int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPoolingNdForwardOutputDim"); +cudnnStatus_t CUDNNWINAPI +cudnnGetPoolingNdForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + int nbDims, int outputTensorDimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, + const cudnnTensorDescriptor_t, int, int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetPoolingNdForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc, inputTensorDesc, nbDims, outputTensorDimA); } -cudnnStatus_t CUDNNWINAPI cudnnGetPooling2dForwardOutputDim( - const cudnnPoolingDescriptor_t poolingDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - int *n, - int *c, - int *h, - int *w ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, const cudnnTensorDescriptor_t, int *, int *, int *, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPooling2dForwardOutputDim"); +cudnnStatus_t CUDNNWINAPI +cudnnGetPooling2dForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + int *n, int *c, int *h, int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, + const cudnnTensorDescriptor_t, + int *, int *, int *, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetPooling2dForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc, inputTensorDesc, n, c, h, w); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyPoolingDescriptor( - cudnnPoolingDescriptor_t poolingDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyPoolingDescriptor(cudnnPoolingDescriptor_t poolingDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPoolingDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyPoolingDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc); } cudnnStatus_t CUDNNWINAPI cudnnPoolingForward( - cudnnHandle_t handle, - const cudnnPoolingDescriptor_t poolingDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnPoolingForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, poolingDesc, alpha, xDesc, x, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnPoolingBackward( - cudnnHandle_t handle, - const cudnnPoolingDescriptor_t poolingDesc, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnPoolingBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, poolingDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, poolingDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, + beta, dxDesc, dx); } -cudnnStatus_t CUDNNWINAPI cudnnCreateActivationDescriptor( - cudnnActivationDescriptor_t *activationDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateActivationDescriptor(cudnnActivationDescriptor_t *activationDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetActivationDescriptor( - cudnnActivationDescriptor_t activationDesc, - cudnnActivationMode_t mode, - cudnnNanPropagation_t reluNanOpt, - double coef ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t, cudnnActivationMode_t, cudnnNanPropagation_t, double); + cudnnActivationDescriptor_t activationDesc, cudnnActivationMode_t mode, + cudnnNanPropagation_t reluNanOpt, double coef) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t, + cudnnActivationMode_t, + cudnnNanPropagation_t, double); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc, mode, reluNanOpt, coef); } -cudnnStatus_t CUDNNWINAPI cudnnGetActivationDescriptor( - const cudnnActivationDescriptor_t activationDesc, - cudnnActivationMode_t *mode, - cudnnNanPropagation_t *reluNanOpt, - double* coef ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnActivationDescriptor_t, cudnnActivationMode_t *, cudnnNanPropagation_t *, double *); +cudnnStatus_t CUDNNWINAPI +cudnnGetActivationDescriptor(const cudnnActivationDescriptor_t activationDesc, + cudnnActivationMode_t *mode, + cudnnNanPropagation_t *reluNanOpt, double *coef) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnActivationDescriptor_t, cudnnActivationMode_t *, + cudnnNanPropagation_t *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc, mode, reluNanOpt, coef); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyActivationDescriptor( - cudnnActivationDescriptor_t activationDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyActivationDescriptor"); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyActivationDescriptor(cudnnActivationDescriptor_t activationDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc); } cudnnStatus_t CUDNNWINAPI cudnnActivationForward( - cudnnHandle_t handle, - cudnnActivationDescriptor_t activationDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnActivationDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnActivationDescriptor_t activationDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnActivationDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnActivationForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, activationDesc, alpha, xDesc, x, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnActivationBackward( - cudnnHandle_t handle, - cudnnActivationDescriptor_t activationDesc, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnActivationDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnActivationDescriptor_t activationDesc, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnActivationDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnActivationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, activationDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, activationDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, + beta, dxDesc, dx); } -cudnnStatus_t CUDNNWINAPI cudnnCreateLRNDescriptor( - cudnnLRNDescriptor_t *normDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateLRNDescriptor(cudnnLRNDescriptor_t *normDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnLRNDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc); } -cudnnStatus_t CUDNNWINAPI cudnnSetLRNDescriptor( - cudnnLRNDescriptor_t normDesc, - unsigned lrnN, - double lrnAlpha, - double lrnBeta, - double lrnK ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t, unsigned int, double, double, double); +cudnnStatus_t CUDNNWINAPI cudnnSetLRNDescriptor(cudnnLRNDescriptor_t normDesc, + unsigned lrnN, double lrnAlpha, + double lrnBeta, double lrnK) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnLRNDescriptor_t, unsigned int, double, double, double); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK); } -cudnnStatus_t CUDNNWINAPI cudnnGetLRNDescriptor( - cudnnLRNDescriptor_t normDesc, - unsigned* lrnN, - double* lrnAlpha, - double* lrnBeta, - double* lrnK ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t, unsigned int *, double *, double *, double *); +cudnnStatus_t CUDNNWINAPI cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, + unsigned *lrnN, + double *lrnAlpha, + double *lrnBeta, double *lrnK) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnLRNDescriptor_t, unsigned int *, double *, double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyLRNDescriptor( cudnnLRNDescriptor_t lrnDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyLRNDescriptor(cudnnLRNDescriptor_t lrnDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnLRNDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(lrnDesc); } cudnnStatus_t CUDNNWINAPI cudnnLRNCrossChannelForward( - cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnLRNMode_t lrnMode, - const void* alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, cudnnLRNMode_t lrnMode, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnLRNCrossChannelForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, normDesc, lrnMode, alpha, xDesc, x, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnLRNCrossChannelBackward( - cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnLRNMode_t lrnMode, - const void* alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, cudnnLRNMode_t lrnMode, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnLRNCrossChannelBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, lrnMode, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, normDesc, lrnMode, alpha, yDesc, y, dyDesc, dy, xDesc, + x, beta, dxDesc, dx); } cudnnStatus_t CUDNNWINAPI cudnnDivisiveNormalizationForward( - cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnDivNormMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */ - const void *x, - const void *means, /* if NULL, means are assumed to be zero */ - void *temp, - void *temp2, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, const void *, const cudnnTensorDescriptor_t, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationForward"); + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, + cudnnDivNormMode_t mode, const void *alpha, + const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */ + const void *x, + const void *means, /* if NULL, means are assumed to be zero */ + void *temp, void *temp2, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, + const void *, const cudnnTensorDescriptor_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, temp, temp2, beta, yDesc, y); + return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, temp, temp2, + beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnDivisiveNormalizationBackward( - cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnDivNormMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, /* same desc for x, means, dy, temp, temp2 */ - const void *x, - const void *means, /* if NULL, means are assumed to be zero */ - const void *dy, - void *temp, - void *temp2, - const void *beta, - const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */ - void *dx, /* output x differential */ - void *dMeans ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, void *, void *, const void *, const cudnnTensorDescriptor_t, void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationBackward"); + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, + cudnnDivNormMode_t mode, const void *alpha, + const cudnnTensorDescriptor_t + xDesc, /* same desc for x, means, dy, temp, temp2 */ + const void *x, + const void *means, /* if NULL, means are assumed to be zero */ + const void *dy, void *temp, void *temp2, const void *beta, + const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */ + void *dx, /* output x differential */ + void *dMeans) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + void *, void *, const void *, const cudnnTensorDescriptor_t, void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, dy, temp, temp2, beta, dXdMeansDesc, dx, dMeans); + return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, dy, temp, + temp2, beta, dXdMeansDesc, dx, dMeans); } cudnnStatus_t CUDNNWINAPI cudnnDeriveBNTensorDescriptor( - cudnnTensorDescriptor_t derivedBnDesc, - const cudnnTensorDescriptor_t xDesc, - cudnnBatchNormMode_t mode ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, cudnnBatchNormMode_t); + cudnnTensorDescriptor_t derivedBnDesc, const cudnnTensorDescriptor_t xDesc, + cudnnBatchNormMode_t mode) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, + cudnnBatchNormMode_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDeriveBNTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(derivedBnDesc, xDesc, mode); } cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationForwardTraining( - cudnnHandle_t handle, - cudnnBatchNormMode_t mode, + cudnnHandle_t handle, cudnnBatchNormMode_t mode, - const void *alpha, /* alpha[0] = result blend factor */ - const void *beta, /* beta[0] = dest layer blend factor */ + const void *alpha, /* alpha[0] = result blend factor */ + const void *beta, /* beta[0] = dest layer blend factor */ - const cudnnTensorDescriptor_t xDesc, - const void *x, /* NxCxHxW */ - const cudnnTensorDescriptor_t yDesc, - void *y, /* NxCxHxW */ + const cudnnTensorDescriptor_t xDesc, const void *x, /* NxCxHxW */ + const cudnnTensorDescriptor_t yDesc, void *y, /* NxCxHxW */ - /* Shared desc for the next 6 tensors in the argument list. - Data type to be set as follows: - type = (typeOf(x) == double) ? double : float - Dimensions for this descriptor depend on normalization mode - - Spatial Normalization : tensors are expected to have dims 1xCx1x1 - (normalization is performed across NxHxW) - - Per-Activation Normalization : tensors are expected to have dims of 1xCxHxW - (normalization is performed across N) */ - const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, + /* Shared desc for the next 6 tensors in the argument list. + Data type to be set as follows: + type = (typeOf(x) == double) ? double : float + Dimensions for this descriptor depend on normalization mode + - Spatial Normalization : tensors are expected to have dims 1xCx1x1 + (normalization is performed across NxHxW) + - Per-Activation Normalization : tensors are expected to have dims of + 1xCxHxW (normalization is performed across N) */ + const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - /* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation */ - const void *bnScale, - const void *bnBias, + /* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation + */ + const void *bnScale, const void *bnBias, - /* MUST use factor=1 in the very first call of a complete training cycle. - Use a factor=1/(1+n) at N-th call to the function to get - Cumulative Moving Average (CMA) behavior - CMA[n] = (x[1]+...+x[n])/n - Since CMA[n+1] = (n*CMA[n]+x[n+1])/(n+1) = - ((n+1)*CMA[n]-CMA[n])/(n+1) + x[n+1]/(n+1) = - CMA[n]*(1-1/(n+1)) + x[n+1]*1/(n+1) */ - double exponentialAverageFactor, + /* MUST use factor=1 in the very first call of a complete training cycle. + Use a factor=1/(1+n) at N-th call to the function to get + Cumulative Moving Average (CMA) behavior + CMA[n] = (x[1]+...+x[n])/n + Since CMA[n+1] = (n*CMA[n]+x[n+1])/(n+1) = + ((n+1)*CMA[n]-CMA[n])/(n+1) + x[n+1]/(n+1) = + CMA[n]*(1-1/(n+1)) + x[n+1]*1/(n+1) */ + double exponentialAverageFactor, - /* Used in Training phase only. - runningMean = newMean*factor + runningMean*(1-factor) */ - void *resultRunningMean, - /* Output in training mode, input in inference. Is the moving average - of variance[x] (factor is applied in the same way as for runningMean) */ - void *resultRunningVariance, + /* Used in Training phase only. + runningMean = newMean*factor + runningMean*(1-factor) */ + void *resultRunningMean, + /* Output in training mode, input in inference. Is the moving average + of variance[x] (factor is applied in the same way as for runningMean) */ + void *resultRunningVariance, - /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */ - double epsilon, + /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and + backward functions. */ + double epsilon, - /* Optionally save intermediate results from the forward pass here - - can be reused to speed up backward pass. NULL if unused */ - void *resultSaveMean, - void *resultSaveInvVariance ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, double, void *, void *, double, void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTraining"); + /* Optionally save intermediate results from the forward pass here + - can be reused to speed up backward pass. NULL if unused */ + void *resultSaveMean, void *resultSaveInvVariance) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, const void *, double, void *, void *, double, void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTraining"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, bnScale, bnBias, exponentialAverageFactor, resultRunningMean, resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance); + return func_ptr( + handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, + bnScale, bnBias, exponentialAverageFactor, resultRunningMean, + resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance); } cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationForwardInference( - cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - const void *alpha, /* alpha[0] = result blend factor */ - const void *beta, /* beta[0] = dest layer blend factor */ - const cudnnTensorDescriptor_t xDesc, - const void *x, /* NxCxHxW */ - const cudnnTensorDescriptor_t yDesc, - void *y, /* NxCxHxW */ - const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - const void *bnScale, - const void *bnBias, - const void *estimatedMean, - const void *estimatedVariance, - double epsilon ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, const void *, double); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardInference"); + cudnnHandle_t handle, cudnnBatchNormMode_t mode, + const void *alpha, /* alpha[0] = result blend factor */ + const void *beta, /* beta[0] = dest layer blend factor */ + const cudnnTensorDescriptor_t xDesc, const void *x, /* NxCxHxW */ + const cudnnTensorDescriptor_t yDesc, void *y, /* NxCxHxW */ + const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, const void *bnScale, + const void *bnBias, const void *estimatedMean, + const void *estimatedVariance, double epsilon) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, const void *, const void *, const void *, double); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardInference"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, bnScale, bnBias, estimatedMean, estimatedVariance, epsilon); + return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, + bnScaleBiasMeanVarDesc, bnScale, bnBias, estimatedMean, + estimatedVariance, epsilon); } cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationBackward( - cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - const void *alphaDataDiff, - const void *betaDataDiff, - const void *alphaParamDiff, - const void *betaParamDiff, - const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */ - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - /* Shared tensor desc for the 4 tensors below */ - const cudnnTensorDescriptor_t dBnScaleBiasDesc, - const void *bnScale, /* bnBias doesn't affect backpropagation */ - /* scale and bias diff are not backpropagated below this layer */ - void *dBnScaleResult, - void *dBnBiasResult, - /* Same epsilon as forward pass */ - double epsilon, + cudnnHandle_t handle, cudnnBatchNormMode_t mode, const void *alphaDataDiff, + const void *betaDataDiff, const void *alphaParamDiff, + const void *betaParamDiff, + const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */ + const void *x, const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t dxDesc, void *dx, + /* Shared tensor desc for the 4 tensors below */ + const cudnnTensorDescriptor_t dBnScaleBiasDesc, + const void *bnScale, /* bnBias doesn't affect backpropagation */ + /* scale and bias diff are not backpropagated below this layer */ + void *dBnScaleResult, void *dBnBiasResult, + /* Same epsilon as forward pass */ + double epsilon, - /* Optionally cached intermediate results from - forward pass */ - const void *savedMean, - const void *savedInvVariance ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, void *, void *, double, const void *, const void *); + /* Optionally cached intermediate results from + forward pass */ + const void *savedMean, const void *savedInvVariance) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const void *, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, void *, void *, double, const void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alphaDataDiff, betaDataDiff, alphaParamDiff, betaParamDiff, xDesc, x, dyDesc, dy, dxDesc, dx, dBnScaleBiasDesc, bnScale, dBnScaleResult, dBnBiasResult, epsilon, savedMean, savedInvVariance); + return func_ptr(handle, mode, alphaDataDiff, betaDataDiff, alphaParamDiff, + betaParamDiff, xDesc, x, dyDesc, dy, dxDesc, dx, + dBnScaleBiasDesc, bnScale, dBnScaleResult, dBnBiasResult, + epsilon, savedMean, savedInvVariance); } cudnnStatus_t CUDNNWINAPI cudnnCreateSpatialTransformerDescriptor( - cudnnSpatialTransformerDescriptor_t *stDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateSpatialTransformerDescriptor"); + cudnnSpatialTransformerDescriptor_t *stDesc) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateSpatialTransformerDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetSpatialTransformerNdDescriptor( - cudnnSpatialTransformerDescriptor_t stDesc, - cudnnSamplerType_t samplerType, - cudnnDataType_t dataType, - const int nbDims, - const int dimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t, cudnnSamplerType_t, cudnnDataType_t, const int, const int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetSpatialTransformerNdDescriptor"); + cudnnSpatialTransformerDescriptor_t stDesc, cudnnSamplerType_t samplerType, + cudnnDataType_t dataType, const int nbDims, const int dimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnSpatialTransformerDescriptor_t, cudnnSamplerType_t, cudnnDataType_t, + const int, const int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSetSpatialTransformerNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc, samplerType, dataType, nbDims, dimA); } cudnnStatus_t CUDNNWINAPI cudnnDestroySpatialTransformerDescriptor( - cudnnSpatialTransformerDescriptor_t stDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroySpatialTransformerDescriptor"); + cudnnSpatialTransformerDescriptor_t stDesc) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroySpatialTransformerDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc); } cudnnStatus_t CUDNNWINAPI cudnnSpatialTfGridGeneratorForward( - cudnnHandle_t handle, - const cudnnSpatialTransformerDescriptor_t stDesc, - const void *theta, - void *grid) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorForward"); + cudnnHandle_t handle, const cudnnSpatialTransformerDescriptor_t stDesc, + const void *theta, void *grid) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, theta, grid); } cudnnStatus_t CUDNNWINAPI cudnnSpatialTfGridGeneratorBackward( - cudnnHandle_t handle, - const cudnnSpatialTransformerDescriptor_t stDesc, - const void *dgrid, - void *dtheta) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorBackward"); + cudnnHandle_t handle, const cudnnSpatialTransformerDescriptor_t stDesc, + const void *dgrid, void *dtheta) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorBackward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, dgrid, dtheta); } cudnnStatus_t CUDNNWINAPI cudnnSpatialTfSamplerForward( - cudnnHandle_t handle, - cudnnSpatialTransformerDescriptor_t stDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *grid, - const void *beta, - cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnSpatialTransformerDescriptor_t stDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *grid, const void *beta, cudnnTensorDescriptor_t yDesc, + void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfSamplerForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, alpha, xDesc, x, grid, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnSpatialTfSamplerBackward( - cudnnHandle_t handle, - cudnnSpatialTransformerDescriptor_t stDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - const void *alphaDgrid, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *grid, - const void *betaDgrid, - void *dgrid) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, void *); + cudnnHandle_t handle, cudnnSpatialTransformerDescriptor_t stDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t dxDesc, void *dx, + const void *alphaDgrid, const cudnnTensorDescriptor_t dyDesc, + const void *dy, const void *grid, const void *betaDgrid, void *dgrid) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfSamplerBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, stDesc, alpha, xDesc, x, beta, dxDesc, dx, alphaDgrid, dyDesc, dy, grid, betaDgrid, dgrid); + return func_ptr(handle, stDesc, alpha, xDesc, x, beta, dxDesc, dx, alphaDgrid, + dyDesc, dy, grid, betaDgrid, dgrid); } -cudnnStatus_t CUDNNWINAPI cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t * dropoutDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t *dropoutDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc); } -cudnnStatus_t CUDNNWINAPI cudnnDropoutGetStatesSize(cudnnHandle_t handle, size_t * sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnDropoutGetStatesSize(cudnnHandle_t handle, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutGetStatesSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnDropoutGetReserveSpaceSize(cudnnTensorDescriptor_t xdesc, size_t * sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnDropoutGetReserveSpaceSize( + cudnnTensorDescriptor_t xdesc, size_t *sizeInBytes) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutGetReserveSpaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(xdesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnSetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float dropout, - void * states, - size_t stateSizeInBytes, - unsigned long long seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float, void *, size_t, unsigned long long); +cudnnStatus_t CUDNNWINAPI cudnnSetDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float dropout, + void *states, size_t stateSizeInBytes, unsigned long long seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float, void *, size_t, unsigned long long); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, stateSizeInBytes, seed); } -cudnnStatus_t CUDNNWINAPI cudnnRestoreDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float dropout, - void * states, - size_t stateSizeInBytes, - unsigned long long seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float, void *, size_t, unsigned long long); +cudnnStatus_t CUDNNWINAPI cudnnRestoreDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float dropout, + void *states, size_t stateSizeInBytes, unsigned long long seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float, void *, size_t, unsigned long long); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRestoreDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, stateSizeInBytes, seed); } -cudnnStatus_t CUDNNWINAPI cudnnGetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float * dropout, - void ** states, - unsigned long long * seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float *, void **, unsigned long long *); +cudnnStatus_t CUDNNWINAPI cudnnGetDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float *dropout, + void **states, unsigned long long *seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float *, void **, unsigned long long *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, seed); } -cudnnStatus_t CUDNNWINAPI cudnnDropoutForward(cudnnHandle_t handle, - const cudnnDropoutDescriptor_t dropoutDesc, - const cudnnTensorDescriptor_t xdesc, - const void * x, - const cudnnTensorDescriptor_t ydesc, - void * y, - void * reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnDropoutDescriptor_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnDropoutForward( + cudnnHandle_t handle, const cudnnDropoutDescriptor_t dropoutDesc, + const cudnnTensorDescriptor_t xdesc, const void *x, + const cudnnTensorDescriptor_t ydesc, void *y, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnDropoutDescriptor_t, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, dropoutDesc, xdesc, x, ydesc, y, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, dropoutDesc, xdesc, x, ydesc, y, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnDropoutBackward(cudnnHandle_t handle, - const cudnnDropoutDescriptor_t dropoutDesc, - const cudnnTensorDescriptor_t dydesc, - const void * dy, - const cudnnTensorDescriptor_t dxdesc, - void * dx, - void * reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnDropoutDescriptor_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnDropoutBackward( + cudnnHandle_t handle, const cudnnDropoutDescriptor_t dropoutDesc, + const cudnnTensorDescriptor_t dydesc, const void *dy, + const cudnnTensorDescriptor_t dxdesc, void *dx, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnDropoutDescriptor_t, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, dropoutDesc, dydesc, dy, dxdesc, dx, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, dropoutDesc, dydesc, dy, dxdesc, dx, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t * rnnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc); } -cudnnStatus_t CUDNNWINAPI cudnnCreatePersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, - const int minibatch, - const cudnnDataType_t dataType, - cudnnPersistentRNNPlan_t * plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, const int, const cudnnDataType_t, cudnnPersistentRNNPlan_t *); +cudnnStatus_t CUDNNWINAPI cudnnCreatePersistentRNNPlan( + cudnnRNNDescriptor_t rnnDesc, const int minibatch, + const cudnnDataType_t dataType, cudnnPersistentRNNPlan_t *plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, const int, + const cudnnDataType_t, + cudnnPersistentRNNPlan_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreatePersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, minibatch, dataType, plan); } -cudnnStatus_t CUDNNWINAPI cudnnSetPersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, - cudnnPersistentRNNPlan_t plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnPersistentRNNPlan_t); +cudnnStatus_t CUDNNWINAPI cudnnSetPersistentRNNPlan( + cudnnRNNDescriptor_t rnnDesc, cudnnPersistentRNNPlan_t plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, + cudnnPersistentRNNPlan_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, plan); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPersistentRNNPlan_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPersistentRNNPlan_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyPersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(plan); } -cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - const int hiddenSize, - const int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, /* Between layers, not between recurrent steps. */ - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnRNNAlgo_t algo, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, const int hiddenSize, + const int numLayers, + cudnnDropoutDescriptor_t + dropoutDesc, /* Between layers, not between recurrent steps. */ + cudnnRNNInputMode_t inputMode, cudnnDirectionMode_t direction, + cudnnRNNMode_t mode, cudnnRNNAlgo_t algo, cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, + cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, + cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, dataType); + return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, dataType); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNDescriptor(cudnnHandle_t cudnnHandle, - cudnnRNNDescriptor_t rnnDesc, - int * hiddenSize, - int * numLayers, - cudnnDropoutDescriptor_t * dropoutDesc, - cudnnRNNInputMode_t * inputMode, - cudnnDirectionMode_t * direction, - cudnnRNNMode_t * mode, - cudnnRNNAlgo_t * algo, - cudnnDataType_t * dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, int *, int *, cudnnDropoutDescriptor_t *, cudnnRNNInputMode_t *, cudnnDirectionMode_t *, cudnnRNNMode_t *, cudnnRNNAlgo_t *, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNDescriptor( + cudnnHandle_t cudnnHandle, cudnnRNNDescriptor_t rnnDesc, int *hiddenSize, + int *numLayers, cudnnDropoutDescriptor_t *dropoutDesc, + cudnnRNNInputMode_t *inputMode, cudnnDirectionMode_t *direction, + cudnnRNNMode_t *mode, cudnnRNNAlgo_t *algo, cudnnDataType_t *dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, int *, int *, + cudnnDropoutDescriptor_t *, cudnnRNNInputMode_t *, cudnnDirectionMode_t *, + cudnnRNNMode_t *, cudnnRNNAlgo_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(cudnnHandle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, dataType); + return func_ptr(cudnnHandle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, dataType); } -cudnnStatus_t CUDNNWINAPI cudnnSetRNNMatrixMathType (cudnnRNNDescriptor_t desc, cudnnMathType_t math) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNMatrixMathType(cudnnRNNDescriptor_t desc, + cudnnMathType_t math) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNMatrixMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(desc, math); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNWorkspaceSize( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNWorkspaceSize( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNTrainingReserveSize( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNTrainingReserveSize( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNTrainingReserveSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNParamsSize( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnTensorDescriptor_t xDesc, - size_t *sizeInBytes, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnTensorDescriptor_t, size_t *, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI +cudnnGetRNNParamsSize(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnTensorDescriptor_t xDesc, size_t *sizeInBytes, + cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnTensorDescriptor_t, + size_t *, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNParamsSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, xDesc, sizeInBytes, dataType); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerMatrixParams( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int layer, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const int linLayerID, - cudnnFilterDescriptor_t linLayerMatDesc, - void ** linLayerMat) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const void *, const int, cudnnFilterDescriptor_t, void **); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerMatrixParams( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, const int layer, + const cudnnTensorDescriptor_t xDesc, const cudnnFilterDescriptor_t wDesc, + const void *w, const int linLayerID, + cudnnFilterDescriptor_t linLayerMatDesc, void **linLayerMat) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, + const void *, const int, cudnnFilterDescriptor_t, void **); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNLinLayerMatrixParams"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, layer, xDesc, wDesc, w, linLayerID, linLayerMatDesc, linLayerMat); + return func_ptr(handle, rnnDesc, layer, xDesc, wDesc, w, linLayerID, + linLayerMatDesc, linLayerMat); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerBiasParams( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int layer, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const int linLayerID, - cudnnFilterDescriptor_t linLayerBiasDesc, - void ** linLayerBias) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const void *, const int, cudnnFilterDescriptor_t, void **); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerBiasParams( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, const int layer, + const cudnnTensorDescriptor_t xDesc, const cudnnFilterDescriptor_t wDesc, + const void *w, const int linLayerID, + cudnnFilterDescriptor_t linLayerBiasDesc, void **linLayerBias) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, + const void *, const int, cudnnFilterDescriptor_t, void **); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNLinLayerBiasParams"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, layer, xDesc, wDesc, w, linLayerID, linLayerBiasDesc, linLayerBias); + return func_ptr(handle, rnnDesc, layer, xDesc, wDesc, w, linLayerID, + linLayerBiasDesc, linLayerBias); } -cudnnStatus_t CUDNNWINAPI cudnnRNNForwardInference( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t * xDesc, - const void * x, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t cxDesc, - const void * cx, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const cudnnTensorDescriptor_t *yDesc, - void * y, - const cudnnTensorDescriptor_t hyDesc, - void * hy, - const cudnnTensorDescriptor_t cyDesc, - void * cy, - void * workspace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardInference( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, void *workspace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardInference"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, workSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnRNNForwardTraining( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void * x, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t cxDesc, - const void * cx, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const cudnnTensorDescriptor_t *yDesc, - void * y, - const cudnnTensorDescriptor_t hyDesc, - void * hy, - const cudnnTensorDescriptor_t cyDesc, - void * cy, - void * workspace, - size_t workSpaceSizeInBytes, - void * reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardTraining( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardTraining"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardData( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t * yDesc, - const void * y, - const cudnnTensorDescriptor_t * dyDesc, - const void * dy, - const cudnnTensorDescriptor_t dhyDesc, - const void * dhy, - const cudnnTensorDescriptor_t dcyDesc, - const void * dcy, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t cxDesc, - const void * cx, - const cudnnTensorDescriptor_t * dxDesc, - void * dx, - const cudnnTensorDescriptor_t dhxDesc, - void * dhx, - const cudnnTensorDescriptor_t dcxDesc, - void * dcx, - void * workspace, - size_t workSpaceSizeInBytes, - void * reserveSpace, - size_t reserveSpaceSizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI +cudnnRNNBackwardData(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *yDesc, + const void *y, const cudnnTensorDescriptor_t *dyDesc, + const void *dy, const cudnnTensorDescriptor_t dhyDesc, + const void *dhy, const cudnnTensorDescriptor_t dcyDesc, + const void *dcy, const cudnnFilterDescriptor_t wDesc, + const void *w, const cudnnTensorDescriptor_t hxDesc, + const void *hx, const cudnnTensorDescriptor_t cxDesc, + const void *cx, const cudnnTensorDescriptor_t *dxDesc, + void *dx, const cudnnTensorDescriptor_t dhxDesc, void *dhx, + const cudnnTensorDescriptor_t dcxDesc, void *dcx, + void *workspace, size_t workSpaceSizeInBytes, + void *reserveSpace, size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardData"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, + dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, + dx, dhxDesc, dhx, dcxDesc, dcx, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeights( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t * xDesc, - const void * x, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t * yDesc, - const void * y, - const void * workspace, - size_t workSpaceSizeInBytes, - const cudnnFilterDescriptor_t dwDesc, - void * dw, - const void * reserveSpace, - size_t reserveSpaceSizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, const void *, const void *, size_t, const cudnnFilterDescriptor_t, void *, const void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeights( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t *yDesc, const void *y, const void *workspace, + size_t workSpaceSizeInBytes, const cudnnFilterDescriptor_t dwDesc, void *dw, + const void *reserveSpace, size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, const void *, const void *, size_t, + const cudnnFilterDescriptor_t, void *, const void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardWeights"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, + workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnCreateCTCLossDescriptor( cudnnCTCLossDescriptor_t* ctcLossDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateCTCLossDescriptor(cudnnCTCLossDescriptor_t *ctcLossDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetCTCLossDescriptor( - cudnnCTCLossDescriptor_t ctcLossDesc, - cudnnDataType_t compType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t); + cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc, compType); } cudnnStatus_t CUDNNWINAPI cudnnGetCTCLossDescriptor( - cudnnCTCLossDescriptor_t ctcLossDesc, - cudnnDataType_t* compType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t *); + cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc, compType); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyCTCLossDescriptor( cudnnCTCLossDescriptor_t ctcLossDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc); } -cudnnStatus_t CUDNNWINAPI cudnnCTCLoss( cudnnHandle_t handle, - const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the mini batch size, A is the alphabet size) */ - const void * probs, /* probabilities after softmax, in GPU memory */ - const int * labels, /* labels, in CPU memory */ - const int * labelLengths, /* the length of each label, in CPU memory */ - const int * inputLengths, /* the lengths of timing steps in each batch, in CPU memory */ - void * costs, /* the returned costs of CTC, in GPU memory */ - const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */ - const void * gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */ - cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ - cudnnCTCLossDescriptor_t ctcLossDesc, - void * workspace, /* pointer to the workspace, in GPU memory */ - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const int *, const int *, const int *, void *, const cudnnTensorDescriptor_t, const void *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnCTCLoss( + cudnnHandle_t handle, + const cudnnTensorDescriptor_t + probsDesc, /* Tensor descriptor for probabilities, the dimensions are + T,N,A (T is the timing steps, N is the mini batch size, A + is the alphabet size) */ + const void *probs, /* probabilities after softmax, in GPU memory */ + const int *labels, /* labels, in CPU memory */ + const int *labelLengths, /* the length of each label, in CPU memory */ + const int *inputLengths, /* the lengths of timing steps in each batch, in + CPU memory */ + void *costs, /* the returned costs of CTC, in GPU memory */ + const cudnnTensorDescriptor_t + gradientsDesc, /* Tensor descriptor for gradients, the dimensions are + T,N,A */ + const void *gradients, /* the returned CTC gradients, in GPU memory, to + compute costs only, set it to NULL */ + cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ + cudnnCTCLossDescriptor_t ctcLossDesc, + void *workspace, /* pointer to the workspace, in GPU memory */ + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const int *, + const int *, const int *, void *, const cudnnTensorDescriptor_t, + const void *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCTCLoss"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, probsDesc, probs, labels, labelLengths, inputLengths, costs, gradientsDesc, gradients, algo, ctcLossDesc, workspace, workSpaceSizeInBytes); + return func_ptr(handle, probsDesc, probs, labels, labelLengths, inputLengths, + costs, gradientsDesc, gradients, algo, ctcLossDesc, workspace, + workSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetCTCLossWorkspaceSize( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the mini batch size, A is the alphabet size) */ - const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A. To compute costs only, set it to NULL */ - const int * labels, /* labels, in CPU memory */ - const int * labelLengths, /* the length of each label, in CPU memory */ - const int * inputLengths, /* the lengths of timing steps in each batch, in CPU memory */ - cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ - cudnnCTCLossDescriptor_t ctcLossDesc, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const int *, const int *, const int *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, size_t *); + cudnnHandle_t handle, + const cudnnTensorDescriptor_t + probsDesc, /* Tensor descriptor for probabilities, the dimensions are + T,N,A (T is the timing steps, N is the mini batch size, A + is the alphabet size) */ + const cudnnTensorDescriptor_t + gradientsDesc, /* Tensor descriptor for gradients, the dimensions are + T,N,A. To compute costs only, set it to NULL */ + const int *labels, /* labels, in CPU memory */ + const int *labelLengths, /* the length of each label, in CPU memory */ + const int *inputLengths, /* the lengths of timing steps in each batch, in + CPU memory */ + cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ + cudnnCTCLossDescriptor_t ctcLossDesc, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const int *, const int *, const int *, + cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCTCLossWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, probsDesc, gradientsDesc, labels, labelLengths, inputLengths, algo, ctcLossDesc, sizeInBytes); + return func_ptr(handle, probsDesc, gradientsDesc, labels, labelLengths, + inputLengths, algo, ctcLossDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v6(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - const int hiddenSize, - const int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, /* Between layers, not between recurrent steps. */ - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnRNNAlgo_t algo, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v6( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, const int hiddenSize, + const int numLayers, + cudnnDropoutDescriptor_t + dropoutDesc, /* Between layers, not between recurrent steps. */ + cudnnRNNInputMode_t inputMode, cudnnDirectionMode_t direction, + cudnnRNNMode_t mode, cudnnRNNAlgo_t algo, cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, + cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, + cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor_v6"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, dataType); + return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, dataType); } -cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v5(cudnnRNNDescriptor_t rnnDesc, - int hiddenSize, - int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, /* Between layers, not between recurrent steps. */ - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, int, int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v5( + cudnnRNNDescriptor_t rnnDesc, int hiddenSize, int numLayers, + cudnnDropoutDescriptor_t + dropoutDesc, /* Between layers, not between recurrent steps. */ + cudnnRNNInputMode_t inputMode, cudnnDirectionMode_t direction, + cudnnRNNMode_t mode, cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnRNNDescriptor_t, int, int, cudnnDropoutDescriptor_t, + cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, + cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor_v5"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, dataType); + return func_ptr(rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, + direction, mode, dataType); } } // extern "C" diff --git a/tensorflow/stream_executor/cuda/cudnn_7_1.inc b/tensorflow/stream_executor/cuda/cudnn_7_1.inc index 9f4b28f3fe3..5330e6d0584 100644 --- a/tensorflow/stream_executor/cuda/cudnn_7_1.inc +++ b/tensorflow/stream_executor/cuda/cudnn_7_1.inc @@ -3,2279 +3,2359 @@ extern "C" { size_t CUDNNWINAPI cudnnGetVersion(void) { - using FuncPtr = size_t (CUDNNWINAPI *)(); + using FuncPtr = size_t(CUDNNWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetVersion"); if (!func_ptr) return 0; return func_ptr(); } size_t CUDNNWINAPI cudnnGetCudartVersion(void) { - using FuncPtr = size_t (CUDNNWINAPI *)(); + using FuncPtr = size_t(CUDNNWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCudartVersion"); if (!func_ptr) return 0; return func_ptr(); } -const char * CUDNNWINAPI cudnnGetErrorString(cudnnStatus_t status) { - using FuncPtr = const char * (CUDNNWINAPI *)(cudnnStatus_t); +const char *CUDNNWINAPI cudnnGetErrorString(cudnnStatus_t status) { + using FuncPtr = const char *(CUDNNWINAPI *)(cudnnStatus_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetErrorString"); if (!func_ptr) return "cudnnGetErrorString symbol not found."; return func_ptr(status); } -cudnnStatus_t CUDNNWINAPI cudnnQueryRuntimeError( - cudnnHandle_t handle, - cudnnStatus_t *rstatus, - cudnnErrQueryMode_t mode, - cudnnRuntimeTag_t *tag ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnStatus_t *, cudnnErrQueryMode_t, cudnnRuntimeTag_t *); +cudnnStatus_t CUDNNWINAPI cudnnQueryRuntimeError(cudnnHandle_t handle, + cudnnStatus_t *rstatus, + cudnnErrQueryMode_t mode, + cudnnRuntimeTag_t *tag) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnStatus_t *, cudnnErrQueryMode_t, cudnnRuntimeTag_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnQueryRuntimeError"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rstatus, mode, tag); } -cudnnStatus_t CUDNNWINAPI cudnnGetProperty(libraryPropertyType type, int *value) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(libraryPropertyType, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetProperty(libraryPropertyType type, + int *value) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(libraryPropertyType, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetProperty"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(type, value); } -cudnnStatus_t CUDNNWINAPI cudnnCreate (cudnnHandle_t *handle) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t *); +cudnnStatus_t CUDNNWINAPI cudnnCreate(cudnnHandle_t *handle) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreate"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cudnnStatus_t CUDNNWINAPI cudnnDestroy (cudnnHandle_t handle) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t); +cudnnStatus_t CUDNNWINAPI cudnnDestroy(cudnnHandle_t handle) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroy"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cudnnStatus_t CUDNNWINAPI cudnnSetStream (cudnnHandle_t handle, cudaStream_t streamId) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t); +cudnnStatus_t CUDNNWINAPI cudnnSetStream(cudnnHandle_t handle, + cudaStream_t streamId) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetStream"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); } -cudnnStatus_t CUDNNWINAPI cudnnGetStream (cudnnHandle_t handle, cudaStream_t *streamId) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetStream(cudnnHandle_t handle, + cudaStream_t *streamId) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetStream"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); } -cudnnStatus_t CUDNNWINAPI cudnnCreateTensorDescriptor( - cudnnTensorDescriptor_t *tensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateTensorDescriptor(cudnnTensorDescriptor_t *tensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetTensor4dDescriptor( - cudnnTensorDescriptor_t tensorDesc, - cudnnTensorFormat_t format, - cudnnDataType_t dataType, /* image data type */ - int n, /* number of inputs (batch size) */ - int c, /* number of input feature maps */ - int h, /* height of input section */ - int w ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, cudnnDataType_t, int, int, int, int); + cudnnTensorDescriptor_t tensorDesc, cudnnTensorFormat_t format, + cudnnDataType_t dataType, /* image data type */ + int n, /* number of inputs (batch size) */ + int c, /* number of input feature maps */ + int h, /* height of input section */ + int w) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, + cudnnDataType_t, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, format, dataType, n, c, h, w); } cudnnStatus_t CUDNNWINAPI cudnnSetTensor4dDescriptorEx( - cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t dataType, /* image data type */ - int n, /* number of inputs (batch size) */ - int c, /* number of input feature maps */ - int h, /* height of input section */ - int w, /* width of input section */ - int nStride, - int cStride, - int hStride, - int wStride ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, int, int, int, int, int, int, int, int); + cudnnTensorDescriptor_t tensorDesc, + cudnnDataType_t dataType, /* image data type */ + int n, /* number of inputs (batch size) */ + int c, /* number of input feature maps */ + int h, /* height of input section */ + int w, /* width of input section */ + int nStride, int cStride, int hStride, int wStride) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, + int, int, int, int, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor4dDescriptorEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, wStride); + return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, + wStride); } cudnnStatus_t CUDNNWINAPI cudnnGetTensor4dDescriptor( - const cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t *dataType, /* image data type */ - int *n, /* number of inputs (batch size) */ - int *c, /* number of input feature maps */ - int *h, /* height of input section */ - int *w, /* width of input section */ - int *nStride, - int *cStride, - int *hStride, - int *wStride ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, cudnnDataType_t *, int *, int *, int *, int *, int *, int *, int *, int *); + const cudnnTensorDescriptor_t tensorDesc, + cudnnDataType_t *dataType, /* image data type */ + int *n, /* number of inputs (batch size) */ + int *c, /* number of input feature maps */ + int *h, /* height of input section */ + int *w, /* width of input section */ + int *nStride, int *cStride, int *hStride, int *wStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnTensorDescriptor_t, cudnnDataType_t *, int *, int *, int *, + int *, int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensor4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, wStride); + return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, + wStride); } cudnnStatus_t CUDNNWINAPI cudnnSetTensorNdDescriptor( - cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t dataType, - int nbDims, - const int dimA[], - const int strideA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, int, const int [], const int []); + cudnnTensorDescriptor_t tensorDesc, cudnnDataType_t dataType, int nbDims, + const int dimA[], const int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnTensorDescriptor_t, cudnnDataType_t, int, const int[], const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, dataType, nbDims, dimA, strideA); } cudnnStatus_t CUDNNWINAPI cudnnSetTensorNdDescriptorEx( - cudnnTensorDescriptor_t tensorDesc, - cudnnTensorFormat_t format, - cudnnDataType_t dataType, - int nbDims, - const int dimA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, cudnnDataType_t, int, const int []); + cudnnTensorDescriptor_t tensorDesc, cudnnTensorFormat_t format, + cudnnDataType_t dataType, int nbDims, const int dimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, + cudnnDataType_t, int, const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorNdDescriptorEx"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, format, dataType, nbDims, dimA); } cudnnStatus_t CUDNNWINAPI cudnnGetTensorNdDescriptor( - const cudnnTensorDescriptor_t tensorDesc, - int nbDimsRequested, - cudnnDataType_t *dataType, - int *nbDims, - int dimA[], - int strideA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, int, cudnnDataType_t *, int *, int [], int []); + const cudnnTensorDescriptor_t tensorDesc, int nbDimsRequested, + cudnnDataType_t *dataType, int *nbDims, int dimA[], int strideA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnTensorDescriptor_t, int, + cudnnDataType_t *, int *, int[], int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, nbDimsRequested, dataType, nbDims, dimA, strideA); } cudnnStatus_t CUDNNWINAPI cudnnGetTensorSizeInBytes( - const cudnnTensorDescriptor_t tensorDesc, - size_t *size) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, size_t *); + const cudnnTensorDescriptor_t tensorDesc, size_t *size) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorSizeInBytes"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, size); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyTensorDescriptor( - cudnnTensorDescriptor_t tensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyTensorDescriptor(cudnnTensorDescriptor_t tensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnTransformTensor( - cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnTransformTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI cudnnAddTensor( - cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnAddTensor(cudnnHandle_t handle, + const void *alpha, + const cudnnTensorDescriptor_t aDesc, + const void *A, const void *beta, + const cudnnTensorDescriptor_t cDesc, + void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnAddTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, aDesc, A, beta, cDesc, C); } -cudnnStatus_t CUDNNWINAPI cudnnCreateOpTensorDescriptor( - cudnnOpTensorDescriptor_t *opTensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateOpTensorDescriptor(cudnnOpTensorDescriptor_t *opTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetOpTensorDescriptor( - cudnnOpTensorDescriptor_t opTensorDesc, - cudnnOpTensorOp_t opTensorOp, - cudnnDataType_t opTensorCompType, - cudnnNanPropagation_t opTensorNanOpt ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t, cudnnDataType_t, cudnnNanPropagation_t); + cudnnOpTensorDescriptor_t opTensorDesc, cudnnOpTensorOp_t opTensorOp, + cudnnDataType_t opTensorCompType, cudnnNanPropagation_t opTensorNanOpt) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t, + cudnnDataType_t, cudnnNanPropagation_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc, opTensorOp, opTensorCompType, opTensorNanOpt); } cudnnStatus_t CUDNNWINAPI cudnnGetOpTensorDescriptor( - const cudnnOpTensorDescriptor_t opTensorDesc, - cudnnOpTensorOp_t *opTensorOp, - cudnnDataType_t *opTensorCompType, - cudnnNanPropagation_t *opTensorNanOpt ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t *, cudnnDataType_t *, cudnnNanPropagation_t *); + const cudnnOpTensorDescriptor_t opTensorDesc, cudnnOpTensorOp_t *opTensorOp, + cudnnDataType_t *opTensorCompType, cudnnNanPropagation_t *opTensorNanOpt) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t *, cudnnDataType_t *, + cudnnNanPropagation_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc, opTensorOp, opTensorCompType, opTensorNanOpt); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyOpTensorDescriptor( - cudnnOpTensorDescriptor_t opTensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnOpTensor( - cudnnHandle_t handle, - const cudnnOpTensorDescriptor_t opTensorDesc, - const void *alpha1, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *alpha2, - const cudnnTensorDescriptor_t bDesc, - const void *B, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnOpTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const cudnnOpTensorDescriptor_t opTensorDesc, + const void *alpha1, const cudnnTensorDescriptor_t aDesc, const void *A, + const void *alpha2, const cudnnTensorDescriptor_t bDesc, const void *B, + const void *beta, const cudnnTensorDescriptor_t cDesc, void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnOpTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnOpTensor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, opTensorDesc, alpha1, aDesc, A, alpha2, bDesc, B, beta, cDesc, C); + return func_ptr(handle, opTensorDesc, alpha1, aDesc, A, alpha2, bDesc, B, + beta, cDesc, C); } cudnnStatus_t CUDNNWINAPI cudnnCreateReduceTensorDescriptor( - cudnnReduceTensorDescriptor_t *reduceTensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateReduceTensorDescriptor"); + cudnnReduceTensorDescriptor_t *reduceTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(reduceTensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetReduceTensorDescriptor( - cudnnReduceTensorDescriptor_t reduceTensorDesc, - cudnnReduceTensorOp_t reduceTensorOp, - cudnnDataType_t reduceTensorCompType, - cudnnNanPropagation_t reduceTensorNanOpt, - cudnnReduceTensorIndices_t reduceTensorIndices, - cudnnIndicesType_t reduceTensorIndicesType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t, cudnnDataType_t, cudnnNanPropagation_t, cudnnReduceTensorIndices_t, cudnnIndicesType_t); + cudnnReduceTensorDescriptor_t reduceTensorDesc, + cudnnReduceTensorOp_t reduceTensorOp, cudnnDataType_t reduceTensorCompType, + cudnnNanPropagation_t reduceTensorNanOpt, + cudnnReduceTensorIndices_t reduceTensorIndices, + cudnnIndicesType_t reduceTensorIndicesType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t, cudnnDataType_t, + cudnnNanPropagation_t, cudnnReduceTensorIndices_t, cudnnIndicesType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, reduceTensorNanOpt, reduceTensorIndices, reduceTensorIndicesType); + return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, + reduceTensorNanOpt, reduceTensorIndices, + reduceTensorIndicesType); } cudnnStatus_t CUDNNWINAPI cudnnGetReduceTensorDescriptor( - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - cudnnReduceTensorOp_t *reduceTensorOp, - cudnnDataType_t *reduceTensorCompType, - cudnnNanPropagation_t *reduceTensorNanOpt, - cudnnReduceTensorIndices_t *reduceTensorIndices, - cudnnIndicesType_t *reduceTensorIndicesType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t *, cudnnDataType_t *, cudnnNanPropagation_t *, cudnnReduceTensorIndices_t *, cudnnIndicesType_t *); + const cudnnReduceTensorDescriptor_t reduceTensorDesc, + cudnnReduceTensorOp_t *reduceTensorOp, + cudnnDataType_t *reduceTensorCompType, + cudnnNanPropagation_t *reduceTensorNanOpt, + cudnnReduceTensorIndices_t *reduceTensorIndices, + cudnnIndicesType_t *reduceTensorIndicesType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t *, + cudnnDataType_t *, cudnnNanPropagation_t *, cudnnReduceTensorIndices_t *, + cudnnIndicesType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, reduceTensorNanOpt, reduceTensorIndices, reduceTensorIndicesType); + return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, + reduceTensorNanOpt, reduceTensorIndices, + reduceTensorIndicesType); } cudnnStatus_t CUDNNWINAPI cudnnDestroyReduceTensorDescriptor( - cudnnReduceTensorDescriptor_t reduceTensorDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyReduceTensorDescriptor"); + cudnnReduceTensorDescriptor_t reduceTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(reduceTensorDesc); } cudnnStatus_t CUDNNWINAPI cudnnGetReductionIndicesSize( - cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - const cudnnTensorDescriptor_t aDesc, - const cudnnTensorDescriptor_t cDesc, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + const cudnnTensorDescriptor_t aDesc, const cudnnTensorDescriptor_t cDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReductionIndicesSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, reduceTensorDesc, aDesc, cDesc, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetReductionWorkspaceSize( - cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - const cudnnTensorDescriptor_t aDesc, - const cudnnTensorDescriptor_t cDesc, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + const cudnnTensorDescriptor_t aDesc, const cudnnTensorDescriptor_t cDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReductionWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, reduceTensorDesc, aDesc, cDesc, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnReduceTensor( - cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - void *indices, - size_t indicesSizeInBytes, - void *workspace, - size_t workspaceSizeInBytes, - const void *alpha, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, void *, size_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + void *indices, size_t indicesSizeInBytes, void *workspace, + size_t workspaceSizeInBytes, const void *alpha, + const cudnnTensorDescriptor_t aDesc, const void *A, const void *beta, + const cudnnTensorDescriptor_t cDesc, void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, void *, size_t, + void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnReduceTensor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, reduceTensorDesc, indices, indicesSizeInBytes, workspace, workspaceSizeInBytes, alpha, aDesc, A, beta, cDesc, C); + return func_ptr(handle, reduceTensorDesc, indices, indicesSizeInBytes, + workspace, workspaceSizeInBytes, alpha, aDesc, A, beta, cDesc, + C); } -cudnnStatus_t CUDNNWINAPI cudnnSetTensor( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t yDesc, - void *y, - const void *valuePtr ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); +cudnnStatus_t CUDNNWINAPI cudnnSetTensor(cudnnHandle_t handle, + const cudnnTensorDescriptor_t yDesc, + void *y, const void *valuePtr) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, yDesc, y, valuePtr); } -cudnnStatus_t CUDNNWINAPI cudnnScaleTensor( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t yDesc, - void *y, - const void *alpha ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); +cudnnStatus_t CUDNNWINAPI cudnnScaleTensor(cudnnHandle_t handle, + const cudnnTensorDescriptor_t yDesc, + void *y, const void *alpha) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnScaleTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, yDesc, y, alpha); } -cudnnStatus_t CUDNNWINAPI cudnnCreateFilterDescriptor( - cudnnFilterDescriptor_t *filterDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateFilterDescriptor(cudnnFilterDescriptor_t *filterDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateFilterDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetFilter4dDescriptor( - cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, /* image data type */ - cudnnTensorFormat_t format, - int k, /* number of output feature maps */ - int c, /* number of input feature maps */ - int h, /* height of each input filter */ - int w ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, cudnnTensorFormat_t, int, int, int, int); + cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t dataType, /* image data type */ + cudnnTensorFormat_t format, int k, /* number of output feature maps */ + int c, /* number of input feature maps */ + int h, /* height of each input filter */ + int w) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, + cudnnTensorFormat_t, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFilter4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, k, c, h, w); } cudnnStatus_t CUDNNWINAPI cudnnGetFilter4dDescriptor( - const cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t *dataType, /* image data type */ - cudnnTensorFormat_t *format, - int *k, /* number of output feature maps */ - int *c, /* number of input feature maps */ - int *h, /* height of each input filter */ - int *w ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, cudnnDataType_t *, cudnnTensorFormat_t *, int *, int *, int *, int *); + const cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t *dataType, /* image data type */ + cudnnTensorFormat_t *format, int *k, /* number of output feature maps */ + int *c, /* number of input feature maps */ + int *h, /* height of each input filter */ + int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnFilterDescriptor_t, cudnnDataType_t *, cudnnTensorFormat_t *, + int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilter4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, k, c, h, w); } cudnnStatus_t CUDNNWINAPI cudnnSetFilterNdDescriptor( - cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, /* image data type */ - cudnnTensorFormat_t format, - int nbDims, - const int filterDimA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, cudnnTensorFormat_t, int, const int []); + cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t dataType, /* image data type */ + cudnnTensorFormat_t format, int nbDims, const int filterDimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, + cudnnTensorFormat_t, int, const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFilterNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, nbDims, filterDimA); } cudnnStatus_t CUDNNWINAPI cudnnGetFilterNdDescriptor( - const cudnnFilterDescriptor_t filterDesc, - int nbDimsRequested, - cudnnDataType_t *dataType, /* image data type */ - cudnnTensorFormat_t *format, - int *nbDims, - int filterDimA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, int, cudnnDataType_t *, cudnnTensorFormat_t *, int *, int []); + const cudnnFilterDescriptor_t filterDesc, int nbDimsRequested, + cudnnDataType_t *dataType, /* image data type */ + cudnnTensorFormat_t *format, int *nbDims, int filterDimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnFilterDescriptor_t, int, cudnnDataType_t *, + cudnnTensorFormat_t *, int *, int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilterNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(filterDesc, nbDimsRequested, dataType, format, nbDims, filterDimA); + return func_ptr(filterDesc, nbDimsRequested, dataType, format, nbDims, + filterDimA); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyFilterDescriptor( - cudnnFilterDescriptor_t filterDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyFilterDescriptor(cudnnFilterDescriptor_t filterDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyFilterDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc); } -cudnnStatus_t CUDNNWINAPI cudnnCreateConvolutionDescriptor( - cudnnConvolutionDescriptor_t *convDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateConvolutionDescriptor"); +cudnnStatus_t CUDNNWINAPI +cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t *convDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateConvolutionDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc); } -cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionMathType( cudnnConvolutionDescriptor_t convDesc, - cudnnMathType_t mathType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnMathType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionMathType( + cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t mathType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, + cudnnMathType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, mathType); } -cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionMathType( cudnnConvolutionDescriptor_t convDesc, - cudnnMathType_t *mathType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnMathType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionMathType( + cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t *mathType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, + cudnnMathType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, mathType); } -cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionGroupCount( cudnnConvolutionDescriptor_t convDesc, - int groupCount ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionGroupCount( + cudnnConvolutionDescriptor_t convDesc, int groupCount) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionGroupCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, groupCount); } -cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionGroupCount( cudnnConvolutionDescriptor_t convDesc, - int *groupCount ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionGroupCount( + cudnnConvolutionDescriptor_t convDesc, int *groupCount) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionGroupCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, groupCount); } -cudnnStatus_t CUDNNWINAPI cudnnSetConvolution2dDescriptor( cudnnConvolutionDescriptor_t convDesc, - int pad_h, /* zero-padding height */ - int pad_w, /* zero-padding width */ - int u, /* vertical filter stride */ - int v, /* horizontal filter stride */ - int dilation_h, /* filter dilation in the vertical dimension */ - int dilation_w, /* filter dilation in the horizontal dimension */ - cudnnConvolutionMode_t mode, - cudnnDataType_t computeType - ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, int, int, int, int, int, cudnnConvolutionMode_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolution2dDescriptor( + cudnnConvolutionDescriptor_t convDesc, int pad_h, /* zero-padding height */ + int pad_w, /* zero-padding width */ + int u, /* vertical filter stride */ + int v, /* horizontal filter stride */ + int dilation_h, /* filter dilation in the vertical dimension */ + int dilation_w, /* filter dilation in the horizontal dimension */ + cudnnConvolutionMode_t mode, cudnnDataType_t computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnConvolutionDescriptor_t, int, int, int, int, int, int, + cudnnConvolutionMode_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolution2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType); + return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, + computeType); } -cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dDescriptor( const cudnnConvolutionDescriptor_t convDesc, - int* pad_h, /* zero-padding height */ - int* pad_w, /* zero-padding width */ - int* u, /* vertical filter stride */ - int* v, /* horizontal filter stride */ - int* dilation_h, /* filter dilation in the vertical dimension */ - int* dilation_w, /* filter dilation in the horizontal dimension */ - cudnnConvolutionMode_t* mode, - cudnnDataType_t *computeType - ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, int *, cudnnConvolutionMode_t *, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dDescriptor( + const cudnnConvolutionDescriptor_t convDesc, + int *pad_h, /* zero-padding height */ + int *pad_w, /* zero-padding width */ + int *u, /* vertical filter stride */ + int *v, /* horizontal filter stride */ + int *dilation_h, /* filter dilation in the vertical dimension */ + int *dilation_w, /* filter dilation in the horizontal dimension */ + cudnnConvolutionMode_t *mode, cudnnDataType_t *computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, + int *, cudnnConvolutionMode_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType); + return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, + computeType); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dForwardOutputDim( - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - const cudnnFilterDescriptor_t filterDesc, - int *n, - int *c, - int *h, - int *w ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, int *, int *, int *, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dForwardOutputDim"); + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + const cudnnFilterDescriptor_t filterDesc, int *n, int *c, int *h, int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, int *, int *, int *, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolution2dForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, inputTensorDesc, filterDesc, n, c, h, w); } cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionNdDescriptor( - cudnnConvolutionDescriptor_t convDesc, - int arrayLength, /* nbDims-2 size */ - const int padA[], - const int filterStrideA[], - const int dilationA[], - cudnnConvolutionMode_t mode, - cudnnDataType_t computeType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, const int [], const int [], const int [], cudnnConvolutionMode_t, cudnnDataType_t); + cudnnConvolutionDescriptor_t convDesc, int arrayLength, /* nbDims-2 size */ + const int padA[], const int filterStrideA[], const int dilationA[], + cudnnConvolutionMode_t mode, cudnnDataType_t computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnConvolutionDescriptor_t, int, const int[], const int[], const int[], + cudnnConvolutionMode_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, arrayLength, padA, filterStrideA, dilationA, mode, computeType); + return func_ptr(convDesc, arrayLength, padA, filterStrideA, dilationA, mode, + computeType); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionNdDescriptor( - const cudnnConvolutionDescriptor_t convDesc, - int arrayLengthRequested, - int *arrayLength, - int padA[], - int strideA[], - int dilationA[], - cudnnConvolutionMode_t *mode, - cudnnDataType_t *computeType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int, int *, int [], int [], int [], cudnnConvolutionMode_t *, cudnnDataType_t *); + const cudnnConvolutionDescriptor_t convDesc, int arrayLengthRequested, + int *arrayLength, int padA[], int strideA[], int dilationA[], + cudnnConvolutionMode_t *mode, cudnnDataType_t *computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, int, int *, int[], int[], int[], + cudnnConvolutionMode_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, arrayLengthRequested, arrayLength, padA, strideA, dilationA, mode, computeType); + return func_ptr(convDesc, arrayLengthRequested, arrayLength, padA, strideA, + dilationA, mode, computeType); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionNdForwardOutputDim( - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - const cudnnFilterDescriptor_t filterDesc, - int nbDims, - int tensorOutputDimA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, int, int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionNdForwardOutputDim"); + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + const cudnnFilterDescriptor_t filterDesc, int nbDims, + int tensorOutputDimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, int, int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionNdForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, inputTensorDesc, filterDesc, nbDims, tensorOutputDimA); + return func_ptr(convDesc, inputTensorDesc, filterDesc, nbDims, + tensorOutputDimA); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyConvolutionDescriptor( - cudnnConvolutionDescriptor_t convDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyConvolutionDescriptor"); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyConvolutionDescriptor(cudnnConvolutionDescriptor_t convDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyConvolutionDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc); } -cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithmMaxCount( cudnnHandle_t handle, - int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI +cudnnGetConvolutionForwardAlgorithmMaxCount(cudnnHandle_t handle, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, count); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionForwardAlgorithm( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionFwdAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithm"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionFwdAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionForwardAlgorithmEx( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - void *y, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, const int, int *, cudnnConvolutionFwdAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithmEx"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, void *y, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults, + void *workSpace, size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, + const int, int *, cudnnConvolutionFwdAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, x, wDesc, w, convDesc, yDesc, y, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, xDesc, x, wDesc, w, convDesc, yDesc, y, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithm( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - cudnnConvolutionFwdPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionFwdAlgo_t *algo ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionFwdPreference_t, size_t, cudnnConvolutionFwdAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, + cudnnConvolutionFwdPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionFwdAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionFwdPreference_t, size_t, + cudnnConvolutionFwdAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, preference, + memoryLimitInBytes, algo); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithm_v7( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t srcDesc, - const cudnnFilterDescriptor_t filterDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t destDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionFwdAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm_v7"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t srcDesc, + const cudnnFilterDescriptor_t filterDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t destDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionFwdAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm_v7"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, srcDesc, filterDesc, convDesc, destDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, srcDesc, filterDesc, convDesc, destDesc, + requestedAlgoCount, returnedAlgoCount, perfResults); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardWorkspaceSize( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - cudnnConvolutionFwdAlgo_t algo, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionFwdAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardWorkspaceSize"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, cudnnConvolutionFwdAlgo_t algo, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionFwdAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, algo, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnConvolutionForward( - cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionFwdAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo, + void *workSpace, size_t workSpaceSizeInBytes, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, yDesc, y); + return func_ptr(handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace, + workSpaceSizeInBytes, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnConvolutionBiasActivationForward( - cudnnHandle_t handle, - const void *alpha1, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionFwdAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *alpha2, - const cudnnTensorDescriptor_t zDesc, - const void *z, - const cudnnTensorDescriptor_t biasDesc, - const void *bias, - const cudnnActivationDescriptor_t activationDesc, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBiasActivationForward"); + cudnnHandle_t handle, const void *alpha1, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo, + void *workSpace, size_t workSpaceSizeInBytes, const void *alpha2, + const cudnnTensorDescriptor_t zDesc, const void *z, + const cudnnTensorDescriptor_t biasDesc, const void *bias, + const cudnnActivationDescriptor_t activationDesc, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnConvolutionBiasActivationForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha1, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, alpha2, zDesc, z, biasDesc, bias, activationDesc, yDesc, y); + return func_ptr(handle, alpha1, xDesc, x, wDesc, w, convDesc, algo, workSpace, + workSpaceSizeInBytes, alpha2, zDesc, z, biasDesc, bias, + activationDesc, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardBias( - cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *beta, - const cudnnTensorDescriptor_t dbDesc, - void *db ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t dyDesc, const void *dy, const void *beta, + const cudnnTensorDescriptor_t dbDesc, void *db) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardBias"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, dyDesc, dy, beta, dbDesc, db); } -cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( cudnnHandle_t handle, - int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( + cudnnHandle_t handle, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, count); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardFilterAlgorithm( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithm"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, const int, int *, + cudnnConvolutionBwdFilterAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardFilterAlgorithmEx( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *y, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, void *, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithmEx"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t dyDesc, const void *y, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, void *dw, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnConvolutionBwdFilterAlgoPerf_t *perfResults, void *workSpace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, void *, + const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, x, dyDesc, y, convDesc, dwDesc, dw, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, xDesc, x, dyDesc, y, convDesc, dwDesc, dw, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithm( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - cudnnConvolutionBwdFilterPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionBwdFilterAlgo_t *algo ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterPreference_t, size_t, cudnnConvolutionBwdFilterAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, + cudnnConvolutionBwdFilterPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionBwdFilterAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterPreference_t, + size_t, cudnnConvolutionBwdFilterAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, preference, + memoryLimitInBytes, algo); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithm_v7( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t srcDesc, - const cudnnTensorDescriptor_t diffDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t gradDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm_v7"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t srcDesc, + const cudnnTensorDescriptor_t diffDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t gradDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, const int, int *, + cudnnConvolutionBwdFilterAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm_v7"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, srcDesc, diffDesc, convDesc, gradDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, srcDesc, diffDesc, convDesc, gradDesc, + requestedAlgoCount, returnedAlgoCount, perfResults); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterWorkspaceSize( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t gradDesc, - cudnnConvolutionBwdFilterAlgo_t algo, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterWorkspaceSize"); + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t gradDesc, + cudnnConvolutionBwdFilterAlgo_t algo, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, dyDesc, convDesc, gradDesc, algo, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardFilter( - cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionBwdFilterAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnFilterDescriptor_t dwDesc, - void *dw ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, void *, size_t, const void *, const cudnnFilterDescriptor_t, void *); + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + cudnnConvolutionBwdFilterAlgo_t algo, void *workSpace, + size_t workSpaceSizeInBytes, const void *beta, + const cudnnFilterDescriptor_t dwDesc, void *dw) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, + void *, size_t, const void *, const cudnnFilterDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardFilter"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, xDesc, x, dyDesc, dy, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, dwDesc, dw); + return func_ptr(handle, alpha, xDesc, x, dyDesc, dy, convDesc, algo, + workSpace, workSpaceSizeInBytes, beta, dwDesc, dw); } -cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithmMaxCount( cudnnHandle_t handle, - int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithmMaxCount( + cudnnHandle_t handle, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, count); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardDataAlgorithm( - cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithm"); + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionBwdDataAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardDataAlgorithmEx( - cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithmEx"); + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, void *dx, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnConvolutionBwdDataAlgoPerf_t *perfResults, void *workSpace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, + const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithm( - cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - cudnnConvolutionBwdDataPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionBwdDataAlgo_t *algo ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataPreference_t, size_t, cudnnConvolutionBwdDataAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm"); + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, + cudnnConvolutionBwdDataPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionBwdDataAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataPreference_t, + size_t, cudnnConvolutionBwdDataAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, preference, + memoryLimitInBytes, algo); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithm_v7( - cudnnHandle_t handle, - const cudnnFilterDescriptor_t filterDesc, - const cudnnTensorDescriptor_t diffDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t gradDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm_v7"); + cudnnHandle_t handle, const cudnnFilterDescriptor_t filterDesc, + const cudnnTensorDescriptor_t diffDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t gradDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionBwdDataAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm_v7"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, filterDesc, diffDesc, convDesc, gradDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, filterDesc, diffDesc, convDesc, gradDesc, + requestedAlgoCount, returnedAlgoCount, perfResults); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataWorkspaceSize( - cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - cudnnConvolutionBwdDataAlgo_t algo, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataWorkspaceSize"); + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, cudnnConvolutionBwdDataAlgo_t algo, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, algo, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardData( - cudnnHandle_t handle, - const void *alpha, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionBwdDataAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdDataAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const void *alpha, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + cudnnConvolutionBwdDataAlgo_t algo, void *workSpace, + size_t workSpaceSizeInBytes, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdDataAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardData"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, wDesc, w, dyDesc, dy, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, dxDesc, dx); + return func_ptr(handle, alpha, wDesc, w, dyDesc, dy, convDesc, algo, + workSpace, workSpaceSizeInBytes, beta, dxDesc, dx); } -cudnnStatus_t CUDNNWINAPI cudnnIm2Col( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - void *colBuffer ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI +cudnnIm2Col(cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const void *x, const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, void *colBuffer) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, + const void *, const cudnnFilterDescriptor_t, + const cudnnConvolutionDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnIm2Col"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, x, wDesc, convDesc, colBuffer); } cudnnStatus_t CUDNNWINAPI cudnnSoftmaxForward( - cudnnHandle_t handle, - cudnnSoftmaxAlgorithm_t algo, - cudnnSoftmaxMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnSoftmaxAlgorithm_t algo, cudnnSoftmaxMode_t mode, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSoftmaxForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algo, mode, alpha, xDesc, x, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnSoftmaxBackward( - cudnnHandle_t handle, - cudnnSoftmaxAlgorithm_t algo, - cudnnSoftmaxMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnSoftmaxAlgorithm_t algo, cudnnSoftmaxMode_t mode, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSoftmaxBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, algo, mode, alpha, yDesc, y, dyDesc, dy, beta, dxDesc, dx); + return func_ptr(handle, algo, mode, alpha, yDesc, y, dyDesc, dy, beta, dxDesc, + dx); } -cudnnStatus_t CUDNNWINAPI cudnnCreatePoolingDescriptor( - cudnnPoolingDescriptor_t *poolingDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPoolingDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreatePoolingDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetPooling2dDescriptor( - cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t mode, - cudnnNanPropagation_t maxpoolingNanOpt, - int windowHeight, - int windowWidth, - int verticalPadding, - int horizontalPadding, - int verticalStride, - int horizontalStride ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t, cudnnPoolingMode_t, cudnnNanPropagation_t, int, int, int, int, int, int); + cudnnPoolingDescriptor_t poolingDesc, cudnnPoolingMode_t mode, + cudnnNanPropagation_t maxpoolingNanOpt, int windowHeight, int windowWidth, + int verticalPadding, int horizontalPadding, int verticalStride, + int horizontalStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnPoolingDescriptor_t, cudnnPoolingMode_t, cudnnNanPropagation_t, int, + int, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPooling2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, windowWidth, verticalPadding, horizontalPadding, verticalStride, horizontalStride); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, + windowWidth, verticalPadding, horizontalPadding, + verticalStride, horizontalStride); } cudnnStatus_t CUDNNWINAPI cudnnGetPooling2dDescriptor( - const cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t *mode, - cudnnNanPropagation_t *maxpoolingNanOpt, - int *windowHeight, - int *windowWidth, - int *verticalPadding, - int *horizontalPadding, - int *verticalStride, - int *horizontalStride ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, cudnnPoolingMode_t *, cudnnNanPropagation_t *, int *, int *, int *, int *, int *, int *); + const cudnnPoolingDescriptor_t poolingDesc, cudnnPoolingMode_t *mode, + cudnnNanPropagation_t *maxpoolingNanOpt, int *windowHeight, + int *windowWidth, int *verticalPadding, int *horizontalPadding, + int *verticalStride, int *horizontalStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnPoolingDescriptor_t, cudnnPoolingMode_t *, + cudnnNanPropagation_t *, int *, int *, int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPooling2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, windowWidth, verticalPadding, horizontalPadding, verticalStride, horizontalStride); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, + windowWidth, verticalPadding, horizontalPadding, + verticalStride, horizontalStride); } cudnnStatus_t CUDNNWINAPI cudnnSetPoolingNdDescriptor( - cudnnPoolingDescriptor_t poolingDesc, - const cudnnPoolingMode_t mode, - const cudnnNanPropagation_t maxpoolingNanOpt, - int nbDims, - const int windowDimA[], - const int paddingA[], - const int strideA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t, const cudnnPoolingMode_t, const cudnnNanPropagation_t, int, const int [], const int [], const int []); + cudnnPoolingDescriptor_t poolingDesc, const cudnnPoolingMode_t mode, + const cudnnNanPropagation_t maxpoolingNanOpt, int nbDims, + const int windowDimA[], const int paddingA[], const int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnPoolingDescriptor_t, const cudnnPoolingMode_t, + const cudnnNanPropagation_t, int, const int[], const int[], const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPoolingNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, + paddingA, strideA); } cudnnStatus_t CUDNNWINAPI cudnnGetPoolingNdDescriptor( - const cudnnPoolingDescriptor_t poolingDesc, - int nbDimsRequested, - cudnnPoolingMode_t *mode, - cudnnNanPropagation_t *maxpoolingNanOpt, - int *nbDims, - int windowDimA[], - int paddingA[], - int strideA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, int, cudnnPoolingMode_t *, cudnnNanPropagation_t *, int *, int [], int [], int []); + const cudnnPoolingDescriptor_t poolingDesc, int nbDimsRequested, + cudnnPoolingMode_t *mode, cudnnNanPropagation_t *maxpoolingNanOpt, + int *nbDims, int windowDimA[], int paddingA[], int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnPoolingDescriptor_t, int, cudnnPoolingMode_t *, + cudnnNanPropagation_t *, int *, int[], int[], int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPoolingNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, nbDimsRequested, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA); + return func_ptr(poolingDesc, nbDimsRequested, mode, maxpoolingNanOpt, nbDims, + windowDimA, paddingA, strideA); } -cudnnStatus_t CUDNNWINAPI cudnnGetPoolingNdForwardOutputDim( - const cudnnPoolingDescriptor_t poolingDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - int nbDims, - int outputTensorDimA[] ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, const cudnnTensorDescriptor_t, int, int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPoolingNdForwardOutputDim"); +cudnnStatus_t CUDNNWINAPI +cudnnGetPoolingNdForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + int nbDims, int outputTensorDimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, + const cudnnTensorDescriptor_t, int, int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetPoolingNdForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc, inputTensorDesc, nbDims, outputTensorDimA); } -cudnnStatus_t CUDNNWINAPI cudnnGetPooling2dForwardOutputDim( - const cudnnPoolingDescriptor_t poolingDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - int *n, - int *c, - int *h, - int *w ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, const cudnnTensorDescriptor_t, int *, int *, int *, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPooling2dForwardOutputDim"); +cudnnStatus_t CUDNNWINAPI +cudnnGetPooling2dForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + int *n, int *c, int *h, int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, + const cudnnTensorDescriptor_t, + int *, int *, int *, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetPooling2dForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc, inputTensorDesc, n, c, h, w); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyPoolingDescriptor( - cudnnPoolingDescriptor_t poolingDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyPoolingDescriptor(cudnnPoolingDescriptor_t poolingDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPoolingDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyPoolingDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc); } cudnnStatus_t CUDNNWINAPI cudnnPoolingForward( - cudnnHandle_t handle, - const cudnnPoolingDescriptor_t poolingDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnPoolingForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, poolingDesc, alpha, xDesc, x, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnPoolingBackward( - cudnnHandle_t handle, - const cudnnPoolingDescriptor_t poolingDesc, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnPoolingBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, poolingDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, poolingDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, + beta, dxDesc, dx); } -cudnnStatus_t CUDNNWINAPI cudnnCreateActivationDescriptor( - cudnnActivationDescriptor_t *activationDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateActivationDescriptor(cudnnActivationDescriptor_t *activationDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetActivationDescriptor( - cudnnActivationDescriptor_t activationDesc, - cudnnActivationMode_t mode, - cudnnNanPropagation_t reluNanOpt, - double coef ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t, cudnnActivationMode_t, cudnnNanPropagation_t, double); + cudnnActivationDescriptor_t activationDesc, cudnnActivationMode_t mode, + cudnnNanPropagation_t reluNanOpt, double coef) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t, + cudnnActivationMode_t, + cudnnNanPropagation_t, double); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc, mode, reluNanOpt, coef); } -cudnnStatus_t CUDNNWINAPI cudnnGetActivationDescriptor( - const cudnnActivationDescriptor_t activationDesc, - cudnnActivationMode_t *mode, - cudnnNanPropagation_t *reluNanOpt, - double* coef ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnActivationDescriptor_t, cudnnActivationMode_t *, cudnnNanPropagation_t *, double *); +cudnnStatus_t CUDNNWINAPI +cudnnGetActivationDescriptor(const cudnnActivationDescriptor_t activationDesc, + cudnnActivationMode_t *mode, + cudnnNanPropagation_t *reluNanOpt, double *coef) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnActivationDescriptor_t, cudnnActivationMode_t *, + cudnnNanPropagation_t *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc, mode, reluNanOpt, coef); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyActivationDescriptor( - cudnnActivationDescriptor_t activationDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyActivationDescriptor"); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyActivationDescriptor(cudnnActivationDescriptor_t activationDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc); } cudnnStatus_t CUDNNWINAPI cudnnActivationForward( - cudnnHandle_t handle, - cudnnActivationDescriptor_t activationDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnActivationDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnActivationDescriptor_t activationDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnActivationDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnActivationForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, activationDesc, alpha, xDesc, x, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnActivationBackward( - cudnnHandle_t handle, - cudnnActivationDescriptor_t activationDesc, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnActivationDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnActivationDescriptor_t activationDesc, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnActivationDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnActivationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, activationDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, activationDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, + beta, dxDesc, dx); } -cudnnStatus_t CUDNNWINAPI cudnnCreateLRNDescriptor( - cudnnLRNDescriptor_t *normDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateLRNDescriptor(cudnnLRNDescriptor_t *normDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnLRNDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc); } -cudnnStatus_t CUDNNWINAPI cudnnSetLRNDescriptor( - cudnnLRNDescriptor_t normDesc, - unsigned lrnN, - double lrnAlpha, - double lrnBeta, - double lrnK ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t, unsigned int, double, double, double); +cudnnStatus_t CUDNNWINAPI cudnnSetLRNDescriptor(cudnnLRNDescriptor_t normDesc, + unsigned lrnN, double lrnAlpha, + double lrnBeta, double lrnK) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnLRNDescriptor_t, unsigned int, double, double, double); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK); } -cudnnStatus_t CUDNNWINAPI cudnnGetLRNDescriptor( - cudnnLRNDescriptor_t normDesc, - unsigned* lrnN, - double* lrnAlpha, - double* lrnBeta, - double* lrnK ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t, unsigned int *, double *, double *, double *); +cudnnStatus_t CUDNNWINAPI cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, + unsigned *lrnN, + double *lrnAlpha, + double *lrnBeta, double *lrnK) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnLRNDescriptor_t, unsigned int *, double *, double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyLRNDescriptor( cudnnLRNDescriptor_t lrnDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyLRNDescriptor(cudnnLRNDescriptor_t lrnDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnLRNDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(lrnDesc); } cudnnStatus_t CUDNNWINAPI cudnnLRNCrossChannelForward( - cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnLRNMode_t lrnMode, - const void* alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, cudnnLRNMode_t lrnMode, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnLRNCrossChannelForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, normDesc, lrnMode, alpha, xDesc, x, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnLRNCrossChannelBackward( - cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnLRNMode_t lrnMode, - const void* alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, cudnnLRNMode_t lrnMode, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnLRNCrossChannelBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, lrnMode, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, normDesc, lrnMode, alpha, yDesc, y, dyDesc, dy, xDesc, + x, beta, dxDesc, dx); } cudnnStatus_t CUDNNWINAPI cudnnDivisiveNormalizationForward( - cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnDivNormMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */ - const void *x, - const void *means, /* if NULL, means are assumed to be zero */ - void *temp, - void *temp2, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, const void *, const cudnnTensorDescriptor_t, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationForward"); + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, + cudnnDivNormMode_t mode, const void *alpha, + const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */ + const void *x, + const void *means, /* if NULL, means are assumed to be zero */ + void *temp, void *temp2, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, + const void *, const cudnnTensorDescriptor_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, temp, temp2, beta, yDesc, y); + return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, temp, temp2, + beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnDivisiveNormalizationBackward( - cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnDivNormMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, /* same desc for x, means, dy, temp, temp2 */ - const void *x, - const void *means, /* if NULL, means are assumed to be zero */ - const void *dy, - void *temp, - void *temp2, - const void *beta, - const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */ - void *dx, /* output x differential */ - void *dMeans ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, void *, void *, const void *, const cudnnTensorDescriptor_t, void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationBackward"); + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, + cudnnDivNormMode_t mode, const void *alpha, + const cudnnTensorDescriptor_t + xDesc, /* same desc for x, means, dy, temp, temp2 */ + const void *x, + const void *means, /* if NULL, means are assumed to be zero */ + const void *dy, void *temp, void *temp2, const void *beta, + const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */ + void *dx, /* output x differential */ + void *dMeans) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + void *, void *, const void *, const cudnnTensorDescriptor_t, void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, dy, temp, temp2, beta, dXdMeansDesc, dx, dMeans); + return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, dy, temp, + temp2, beta, dXdMeansDesc, dx, dMeans); } cudnnStatus_t CUDNNWINAPI cudnnDeriveBNTensorDescriptor( - cudnnTensorDescriptor_t derivedBnDesc, - const cudnnTensorDescriptor_t xDesc, - cudnnBatchNormMode_t mode ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, cudnnBatchNormMode_t); + cudnnTensorDescriptor_t derivedBnDesc, const cudnnTensorDescriptor_t xDesc, + cudnnBatchNormMode_t mode) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, + cudnnBatchNormMode_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDeriveBNTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(derivedBnDesc, xDesc, mode); } cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationForwardTraining( - cudnnHandle_t handle, - cudnnBatchNormMode_t mode, + cudnnHandle_t handle, cudnnBatchNormMode_t mode, - const void *alpha, /* alpha[0] = result blend factor */ - const void *beta, /* beta[0] = dest layer blend factor */ + const void *alpha, /* alpha[0] = result blend factor */ + const void *beta, /* beta[0] = dest layer blend factor */ - const cudnnTensorDescriptor_t xDesc, - const void *x, /* NxCxHxW */ - const cudnnTensorDescriptor_t yDesc, - void *y, /* NxCxHxW */ + const cudnnTensorDescriptor_t xDesc, const void *x, /* NxCxHxW */ + const cudnnTensorDescriptor_t yDesc, void *y, /* NxCxHxW */ - /* Shared desc for the next 6 tensors in the argument list. - Data type to be set as follows: - type = (typeOf(x) == double) ? double : float - Dimensions for this descriptor depend on normalization mode - - Spatial Normalization : tensors are expected to have dims 1xCx1x1 - (normalization is performed across NxHxW) - - Per-Activation Normalization : tensors are expected to have dims of 1xCxHxW - (normalization is performed across N) */ - const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, + /* Shared desc for the next 6 tensors in the argument list. + Data type to be set as follows: + type = (typeOf(x) == double) ? double : float + Dimensions for this descriptor depend on normalization mode + - Spatial Normalization : tensors are expected to have dims 1xCx1x1 + (normalization is performed across NxHxW) + - Per-Activation Normalization : tensors are expected to have dims of + 1xCxHxW (normalization is performed across N) */ + const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - /* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation */ - const void *bnScale, - const void *bnBias, + /* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation + */ + const void *bnScale, const void *bnBias, - /* MUST use factor=1 in the very first call of a complete training cycle. - Use a factor=1/(1+n) at N-th call to the function to get - Cumulative Moving Average (CMA) behavior - CMA[n] = (x[1]+...+x[n])/n - Since CMA[n+1] = (n*CMA[n]+x[n+1])/(n+1) = - ((n+1)*CMA[n]-CMA[n])/(n+1) + x[n+1]/(n+1) = - CMA[n]*(1-1/(n+1)) + x[n+1]*1/(n+1) */ - double exponentialAverageFactor, + /* MUST use factor=1 in the very first call of a complete training cycle. + Use a factor=1/(1+n) at N-th call to the function to get + Cumulative Moving Average (CMA) behavior + CMA[n] = (x[1]+...+x[n])/n + Since CMA[n+1] = (n*CMA[n]+x[n+1])/(n+1) = + ((n+1)*CMA[n]-CMA[n])/(n+1) + x[n+1]/(n+1) = + CMA[n]*(1-1/(n+1)) + x[n+1]*1/(n+1) */ + double exponentialAverageFactor, - /* Used in Training phase only. - runningMean = newMean*factor + runningMean*(1-factor) */ - void *resultRunningMean, - /* Output in training mode, input in inference. Is the moving average - of variance[x] (factor is applied in the same way as for runningMean) */ - void *resultRunningVariance, + /* Used in Training phase only. + runningMean = newMean*factor + runningMean*(1-factor) */ + void *resultRunningMean, + /* Output in training mode, input in inference. Is the moving average + of variance[x] (factor is applied in the same way as for runningMean) */ + void *resultRunningVariance, - /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */ - double epsilon, + /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and + backward functions. */ + double epsilon, - /* Optionally save intermediate results from the forward pass here - - can be reused to speed up backward pass. NULL if unused */ - void *resultSaveMean, - void *resultSaveInvVariance ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, double, void *, void *, double, void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTraining"); + /* Optionally save intermediate results from the forward pass here + - can be reused to speed up backward pass. NULL if unused */ + void *resultSaveMean, void *resultSaveInvVariance) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, const void *, double, void *, void *, double, void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTraining"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, bnScale, bnBias, exponentialAverageFactor, resultRunningMean, resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance); + return func_ptr( + handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, + bnScale, bnBias, exponentialAverageFactor, resultRunningMean, + resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance); } cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationForwardInference( - cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - const void *alpha, /* alpha[0] = result blend factor */ - const void *beta, /* beta[0] = dest layer blend factor */ - const cudnnTensorDescriptor_t xDesc, - const void *x, /* NxCxHxW */ - const cudnnTensorDescriptor_t yDesc, - void *y, /* NxCxHxW */ - const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - const void *bnScale, - const void *bnBias, - const void *estimatedMean, - const void *estimatedVariance, - double epsilon ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, const void *, double); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardInference"); + cudnnHandle_t handle, cudnnBatchNormMode_t mode, + const void *alpha, /* alpha[0] = result blend factor */ + const void *beta, /* beta[0] = dest layer blend factor */ + const cudnnTensorDescriptor_t xDesc, const void *x, /* NxCxHxW */ + const cudnnTensorDescriptor_t yDesc, void *y, /* NxCxHxW */ + const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, const void *bnScale, + const void *bnBias, const void *estimatedMean, + const void *estimatedVariance, double epsilon) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, const void *, const void *, const void *, double); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardInference"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, bnScale, bnBias, estimatedMean, estimatedVariance, epsilon); + return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, + bnScaleBiasMeanVarDesc, bnScale, bnBias, estimatedMean, + estimatedVariance, epsilon); } cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationBackward( - cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - const void *alphaDataDiff, - const void *betaDataDiff, - const void *alphaParamDiff, - const void *betaParamDiff, - const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */ - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - /* Shared tensor desc for the 4 tensors below */ - const cudnnTensorDescriptor_t dBnScaleBiasDesc, - const void *bnScale, /* bnBias doesn't affect backpropagation */ - /* scale and bias diff are not backpropagated below this layer */ - void *dBnScaleResult, - void *dBnBiasResult, - /* Same epsilon as forward pass */ - double epsilon, + cudnnHandle_t handle, cudnnBatchNormMode_t mode, const void *alphaDataDiff, + const void *betaDataDiff, const void *alphaParamDiff, + const void *betaParamDiff, + const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */ + const void *x, const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t dxDesc, void *dx, + /* Shared tensor desc for the 4 tensors below */ + const cudnnTensorDescriptor_t dBnScaleBiasDesc, + const void *bnScale, /* bnBias doesn't affect backpropagation */ + /* scale and bias diff are not backpropagated below this layer */ + void *dBnScaleResult, void *dBnBiasResult, + /* Same epsilon as forward pass */ + double epsilon, - /* Optionally cached intermediate results from - forward pass */ - const void *savedMean, - const void *savedInvVariance ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, void *, void *, double, const void *, const void *); + /* Optionally cached intermediate results from + forward pass */ + const void *savedMean, const void *savedInvVariance) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const void *, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, void *, void *, double, const void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alphaDataDiff, betaDataDiff, alphaParamDiff, betaParamDiff, xDesc, x, dyDesc, dy, dxDesc, dx, dBnScaleBiasDesc, bnScale, dBnScaleResult, dBnBiasResult, epsilon, savedMean, savedInvVariance); + return func_ptr(handle, mode, alphaDataDiff, betaDataDiff, alphaParamDiff, + betaParamDiff, xDesc, x, dyDesc, dy, dxDesc, dx, + dBnScaleBiasDesc, bnScale, dBnScaleResult, dBnBiasResult, + epsilon, savedMean, savedInvVariance); } cudnnStatus_t CUDNNWINAPI cudnnCreateSpatialTransformerDescriptor( - cudnnSpatialTransformerDescriptor_t *stDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateSpatialTransformerDescriptor"); + cudnnSpatialTransformerDescriptor_t *stDesc) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateSpatialTransformerDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetSpatialTransformerNdDescriptor( - cudnnSpatialTransformerDescriptor_t stDesc, - cudnnSamplerType_t samplerType, - cudnnDataType_t dataType, - const int nbDims, - const int dimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t, cudnnSamplerType_t, cudnnDataType_t, const int, const int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetSpatialTransformerNdDescriptor"); + cudnnSpatialTransformerDescriptor_t stDesc, cudnnSamplerType_t samplerType, + cudnnDataType_t dataType, const int nbDims, const int dimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnSpatialTransformerDescriptor_t, cudnnSamplerType_t, cudnnDataType_t, + const int, const int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSetSpatialTransformerNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc, samplerType, dataType, nbDims, dimA); } cudnnStatus_t CUDNNWINAPI cudnnDestroySpatialTransformerDescriptor( - cudnnSpatialTransformerDescriptor_t stDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroySpatialTransformerDescriptor"); + cudnnSpatialTransformerDescriptor_t stDesc) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroySpatialTransformerDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc); } cudnnStatus_t CUDNNWINAPI cudnnSpatialTfGridGeneratorForward( - cudnnHandle_t handle, - const cudnnSpatialTransformerDescriptor_t stDesc, - const void *theta, - void *grid) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorForward"); + cudnnHandle_t handle, const cudnnSpatialTransformerDescriptor_t stDesc, + const void *theta, void *grid) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, theta, grid); } cudnnStatus_t CUDNNWINAPI cudnnSpatialTfGridGeneratorBackward( - cudnnHandle_t handle, - const cudnnSpatialTransformerDescriptor_t stDesc, - const void *dgrid, - void *dtheta) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorBackward"); + cudnnHandle_t handle, const cudnnSpatialTransformerDescriptor_t stDesc, + const void *dgrid, void *dtheta) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorBackward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, dgrid, dtheta); } cudnnStatus_t CUDNNWINAPI cudnnSpatialTfSamplerForward( - cudnnHandle_t handle, - cudnnSpatialTransformerDescriptor_t stDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *grid, - const void *beta, - cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, cudnnTensorDescriptor_t, void *); + cudnnHandle_t handle, cudnnSpatialTransformerDescriptor_t stDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *grid, const void *beta, cudnnTensorDescriptor_t yDesc, + void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfSamplerForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, alpha, xDesc, x, grid, beta, yDesc, y); } cudnnStatus_t CUDNNWINAPI cudnnSpatialTfSamplerBackward( - cudnnHandle_t handle, - cudnnSpatialTransformerDescriptor_t stDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - const void *alphaDgrid, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *grid, - const void *betaDgrid, - void *dgrid) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, void *); + cudnnHandle_t handle, cudnnSpatialTransformerDescriptor_t stDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t dxDesc, void *dx, + const void *alphaDgrid, const cudnnTensorDescriptor_t dyDesc, + const void *dy, const void *grid, const void *betaDgrid, void *dgrid) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfSamplerBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, stDesc, alpha, xDesc, x, beta, dxDesc, dx, alphaDgrid, dyDesc, dy, grid, betaDgrid, dgrid); + return func_ptr(handle, stDesc, alpha, xDesc, x, beta, dxDesc, dx, alphaDgrid, + dyDesc, dy, grid, betaDgrid, dgrid); } -cudnnStatus_t CUDNNWINAPI cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t * dropoutDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t *dropoutDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc); } -cudnnStatus_t CUDNNWINAPI cudnnDropoutGetStatesSize(cudnnHandle_t handle, size_t * sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnDropoutGetStatesSize(cudnnHandle_t handle, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutGetStatesSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnDropoutGetReserveSpaceSize(cudnnTensorDescriptor_t xdesc, size_t * sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnDropoutGetReserveSpaceSize( + cudnnTensorDescriptor_t xdesc, size_t *sizeInBytes) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutGetReserveSpaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(xdesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnSetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float dropout, - void * states, - size_t stateSizeInBytes, - unsigned long long seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float, void *, size_t, unsigned long long); +cudnnStatus_t CUDNNWINAPI cudnnSetDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float dropout, + void *states, size_t stateSizeInBytes, unsigned long long seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float, void *, size_t, unsigned long long); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, stateSizeInBytes, seed); } -cudnnStatus_t CUDNNWINAPI cudnnRestoreDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float dropout, - void * states, - size_t stateSizeInBytes, - unsigned long long seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float, void *, size_t, unsigned long long); +cudnnStatus_t CUDNNWINAPI cudnnRestoreDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float dropout, + void *states, size_t stateSizeInBytes, unsigned long long seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float, void *, size_t, unsigned long long); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRestoreDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, stateSizeInBytes, seed); } -cudnnStatus_t CUDNNWINAPI cudnnGetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float * dropout, - void ** states, - unsigned long long * seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float *, void **, unsigned long long *); +cudnnStatus_t CUDNNWINAPI cudnnGetDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float *dropout, + void **states, unsigned long long *seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float *, void **, unsigned long long *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, seed); } -cudnnStatus_t CUDNNWINAPI cudnnDropoutForward(cudnnHandle_t handle, - const cudnnDropoutDescriptor_t dropoutDesc, - const cudnnTensorDescriptor_t xdesc, - const void * x, - const cudnnTensorDescriptor_t ydesc, - void * y, - void * reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnDropoutDescriptor_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnDropoutForward( + cudnnHandle_t handle, const cudnnDropoutDescriptor_t dropoutDesc, + const cudnnTensorDescriptor_t xdesc, const void *x, + const cudnnTensorDescriptor_t ydesc, void *y, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnDropoutDescriptor_t, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, dropoutDesc, xdesc, x, ydesc, y, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, dropoutDesc, xdesc, x, ydesc, y, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnDropoutBackward(cudnnHandle_t handle, - const cudnnDropoutDescriptor_t dropoutDesc, - const cudnnTensorDescriptor_t dydesc, - const void * dy, - const cudnnTensorDescriptor_t dxdesc, - void * dx, - void * reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnDropoutDescriptor_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnDropoutBackward( + cudnnHandle_t handle, const cudnnDropoutDescriptor_t dropoutDesc, + const cudnnTensorDescriptor_t dydesc, const void *dy, + const cudnnTensorDescriptor_t dxdesc, void *dx, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnDropoutDescriptor_t, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, dropoutDesc, dydesc, dy, dxdesc, dx, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, dropoutDesc, dydesc, dy, dxdesc, dx, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t * rnnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc); } cudnnStatus_t CUDNNWINAPI cudnnGetRNNForwardInferenceAlgorithmMaxCount( - cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNForwardInferenceAlgorithmMaxCount"); + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNForwardInferenceAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI cudnnFindRNNForwardInferenceAlgorithmEx( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t * xDesc, - const void * x, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t cxDesc, - const void * cx, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const cudnnTensorDescriptor_t *yDesc, - void * y, - const cudnnTensorDescriptor_t hyDesc, - void * hy, - const cudnnTensorDescriptor_t cyDesc, - void * cy, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - void * workspace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNForwardInferenceAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNForwardInferenceAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, const float findIntensity, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnAlgorithmPerformance_t *perfResults, void *workspace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNForwardInferenceAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, + requestedAlgoCount, returnedAlgoCount, perfResults, workspace, + workSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetRNNForwardTrainingAlgorithmMaxCount( - cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNForwardTrainingAlgorithmMaxCount"); + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNForwardTrainingAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI cudnnFindRNNForwardTrainingAlgorithmEx( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t * xDesc, - const void * x, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t cxDesc, - const void * cx, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const cudnnTensorDescriptor_t *yDesc, - void * y, - const cudnnTensorDescriptor_t hyDesc, - void * hy, - const cudnnTensorDescriptor_t cyDesc, - void * cy, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - void * workspace, - size_t workSpaceSizeInBytes, - void * reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNForwardTrainingAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNForwardTrainingAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, const float findIntensity, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnAlgorithmPerformance_t *perfResults, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNForwardTrainingAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, + requestedAlgoCount, returnedAlgoCount, perfResults, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetRNNBackwardDataAlgorithmMaxCount( - cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNBackwardDataAlgorithmMaxCount"); + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNBackwardDataAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI cudnnFindRNNBackwardDataAlgorithmEx( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t * yDesc, - const void * y, - const cudnnTensorDescriptor_t * dyDesc, - const void * dy, - const cudnnTensorDescriptor_t dhyDesc, - const void * dhy, - const cudnnTensorDescriptor_t dcyDesc, - const void * dcy, - const cudnnFilterDescriptor_t wDesc, - const void * w, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t cxDesc, - const void * cx, - const cudnnTensorDescriptor_t * dxDesc, - void * dx, - const cudnnTensorDescriptor_t dhxDesc, - void * dhx, - const cudnnTensorDescriptor_t dcxDesc, - void * dcx, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - void * workspace, - size_t workSpaceSizeInBytes, - void * reserveSpace, - size_t reserveSpaceSizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNBackwardDataAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNBackwardDataAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *yDesc, const void *y, + const cudnnTensorDescriptor_t *dyDesc, const void *dy, + const cudnnTensorDescriptor_t dhyDesc, const void *dhy, + const cudnnTensorDescriptor_t dcyDesc, const void *dcy, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnTensorDescriptor_t *dxDesc, void *dx, + const cudnnTensorDescriptor_t dhxDesc, void *dhx, + const cudnnTensorDescriptor_t dcxDesc, void *dcx, const float findIntensity, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnAlgorithmPerformance_t *perfResults, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNBackwardDataAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, + dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, + dx, dhxDesc, dhx, dcxDesc, dcx, findIntensity, + requestedAlgoCount, returnedAlgoCount, perfResults, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetRNNBackwardWeightsAlgorithmMaxCount( - cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNBackwardWeightsAlgorithmMaxCount"); + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNBackwardWeightsAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI cudnnFindRNNBackwardWeightsAlgorithmEx( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t * xDesc, - const void * x, - const cudnnTensorDescriptor_t hxDesc, - const void * hx, - const cudnnTensorDescriptor_t * yDesc, - const void * y, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - const void * workspace, - size_t workSpaceSizeInBytes, - const cudnnFilterDescriptor_t dwDesc, - void * dw, - const void * reserveSpace, - size_t reserveSpaceSizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, const void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, const void *, size_t, const cudnnFilterDescriptor_t, void *, const void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNBackwardWeightsAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNBackwardWeightsAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t *yDesc, const void *y, + const float findIntensity, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnAlgorithmPerformance_t *perfResults, + const void *workspace, size_t workSpaceSizeInBytes, + const cudnnFilterDescriptor_t dwDesc, void *dw, const void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, const void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, const void *, size_t, + const cudnnFilterDescriptor_t, void *, const void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNBackwardWeightsAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, + findIntensity, requestedAlgoCount, returnedAlgoCount, + perfResults, workspace, workSpaceSizeInBytes, dwDesc, dw, + reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnCreatePersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, - const int minibatch, - const cudnnDataType_t dataType, - cudnnPersistentRNNPlan_t * plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, const int, const cudnnDataType_t, cudnnPersistentRNNPlan_t *); +cudnnStatus_t CUDNNWINAPI cudnnCreatePersistentRNNPlan( + cudnnRNNDescriptor_t rnnDesc, const int minibatch, + const cudnnDataType_t dataType, cudnnPersistentRNNPlan_t *plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, const int, + const cudnnDataType_t, + cudnnPersistentRNNPlan_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreatePersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, minibatch, dataType, plan); } -cudnnStatus_t CUDNNWINAPI cudnnSetPersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, - cudnnPersistentRNNPlan_t plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnPersistentRNNPlan_t); +cudnnStatus_t CUDNNWINAPI cudnnSetPersistentRNNPlan( + cudnnRNNDescriptor_t rnnDesc, cudnnPersistentRNNPlan_t plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, + cudnnPersistentRNNPlan_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, plan); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPersistentRNNPlan_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPersistentRNNPlan_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyPersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(plan); } -cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - const int hiddenSize, - const int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, /* Between layers, not between recurrent steps. */ - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnRNNAlgo_t algo, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, const int hiddenSize, + const int numLayers, + cudnnDropoutDescriptor_t + dropoutDesc, /* Between layers, not between recurrent steps. */ + cudnnRNNInputMode_t inputMode, cudnnDirectionMode_t direction, + cudnnRNNMode_t mode, cudnnRNNAlgo_t algo, cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, + cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, + cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, dataType); + return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, dataType); } -cudnnStatus_t CUDNNWINAPI cudnnSetRNNProjectionLayers(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - const int recProjSize, - const int outProjSize) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int); +cudnnStatus_t CUDNNWINAPI +cudnnSetRNNProjectionLayers(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + const int recProjSize, const int outProjSize) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNProjectionLayers"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, recProjSize, outProjSize); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNProjectionLayers(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - int *recProjSize, - int *outProjSize) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNProjectionLayers( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *recProjSize, + int *outProjSize) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNProjectionLayers"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, recProjSize, outProjSize); } -cudnnStatus_t CUDNNWINAPI cudnnSetRNNAlgorithmDescriptor(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - cudnnAlgorithmDescriptor_t algoDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, cudnnAlgorithmDescriptor_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNAlgorithmDescriptor( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + cudnnAlgorithmDescriptor_t algoDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, algoDesc); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNDescriptor(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - int * hiddenSize, - int * numLayers, - cudnnDropoutDescriptor_t * dropoutDesc, - cudnnRNNInputMode_t * inputMode, - cudnnDirectionMode_t * direction, - cudnnRNNMode_t * mode, - cudnnRNNAlgo_t * algo, - cudnnDataType_t * dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, int *, int *, cudnnDropoutDescriptor_t *, cudnnRNNInputMode_t *, cudnnDirectionMode_t *, cudnnRNNMode_t *, cudnnRNNAlgo_t *, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNDescriptor( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, int *hiddenSize, + int *numLayers, cudnnDropoutDescriptor_t *dropoutDesc, + cudnnRNNInputMode_t *inputMode, cudnnDirectionMode_t *direction, + cudnnRNNMode_t *mode, cudnnRNNAlgo_t *algo, cudnnDataType_t *dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, int *, int *, + cudnnDropoutDescriptor_t *, cudnnRNNInputMode_t *, cudnnDirectionMode_t *, + cudnnRNNMode_t *, cudnnRNNAlgo_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, dataType); + return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, dataType); } -cudnnStatus_t CUDNNWINAPI cudnnSetRNNMatrixMathType(cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t mType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t); +cudnnStatus_t CUDNNWINAPI +cudnnSetRNNMatrixMathType(cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t mType) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNMatrixMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, mType); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNMatrixMathType(cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t* mType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNMatrixMathType( + cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t *mType) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNMatrixMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, mType); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNWorkspaceSize( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNWorkspaceSize( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNTrainingReserveSize( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNTrainingReserveSize( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNTrainingReserveSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNParamsSize(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnTensorDescriptor_t xDesc, - size_t *sizeInBytes, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnTensorDescriptor_t, size_t *, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI +cudnnGetRNNParamsSize(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnTensorDescriptor_t xDesc, size_t *sizeInBytes, + cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnTensorDescriptor_t, + size_t *, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNParamsSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, xDesc, sizeInBytes, dataType); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerMatrixParams( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int pseudoLayer, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const int linLayerID, - cudnnFilterDescriptor_t linLayerMatDesc, - void **linLayerMat) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const void *, const int, cudnnFilterDescriptor_t, void **); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerMatrixParams( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int pseudoLayer, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, const void *w, const int linLayerID, + cudnnFilterDescriptor_t linLayerMatDesc, void **linLayerMat) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, + const void *, const int, cudnnFilterDescriptor_t, void **); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNLinLayerMatrixParams"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, linLayerMatDesc, linLayerMat); + return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, + linLayerMatDesc, linLayerMat); } -cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerBiasParams( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int pseudoLayer, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const int linLayerID, - cudnnFilterDescriptor_t linLayerBiasDesc, - void **linLayerBias) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const void *, const int, cudnnFilterDescriptor_t, void **); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerBiasParams( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int pseudoLayer, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, const void *w, const int linLayerID, + cudnnFilterDescriptor_t linLayerBiasDesc, void **linLayerBias) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, + const void *, const int, cudnnFilterDescriptor_t, void **); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNLinLayerBiasParams"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, linLayerBiasDesc, linLayerBias); + return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, + linLayerBiasDesc, linLayerBias); } -cudnnStatus_t CUDNNWINAPI cudnnRNNForwardInference( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t *yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - void *workspace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardInference( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, void *workspace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardInference"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, workSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnRNNForwardTraining( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t *yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - void *workspace, - size_t workSpaceSizeInBytes, - void * reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardTraining( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardTraining"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardData( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *yDesc, - const void *y, - const cudnnTensorDescriptor_t *dyDesc, - const void *dy, - const cudnnTensorDescriptor_t dhyDesc, - const void *dhy, - const cudnnTensorDescriptor_t dcyDesc, - const void *dcy, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnTensorDescriptor_t *dxDesc, - void *dx, - const cudnnTensorDescriptor_t dhxDesc, - void *dhx, - const cudnnTensorDescriptor_t dcxDesc, - void *dcx, - void *workspace, - size_t workSpaceSizeInBytes, - void * reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI +cudnnRNNBackwardData(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *yDesc, + const void *y, const cudnnTensorDescriptor_t *dyDesc, + const void *dy, const cudnnTensorDescriptor_t dhyDesc, + const void *dhy, const cudnnTensorDescriptor_t dcyDesc, + const void *dcy, const cudnnFilterDescriptor_t wDesc, + const void *w, const cudnnTensorDescriptor_t hxDesc, + const void *hx, const cudnnTensorDescriptor_t cxDesc, + const void *cx, const cudnnTensorDescriptor_t *dxDesc, + void *dx, const cudnnTensorDescriptor_t dhxDesc, void *dhx, + const cudnnTensorDescriptor_t dcxDesc, void *dcx, + void *workspace, size_t workSpaceSizeInBytes, + void *reserveSpace, size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardData"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, + dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, + dx, dhxDesc, dhx, dcxDesc, dcx, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeights( cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t *yDesc, - const void *y, - const void *workspace, - size_t workSpaceSizeInBytes, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - const void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, const void *, const void *, size_t, const cudnnFilterDescriptor_t, void *, const void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeights( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t *yDesc, const void *y, const void *workspace, + size_t workSpaceSizeInBytes, const cudnnFilterDescriptor_t dwDesc, void *dw, + const void *reserveSpace, size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, const void *, const void *, size_t, + const cudnnFilterDescriptor_t, void *, const void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardWeights"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, + workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnCreateCTCLossDescriptor( cudnnCTCLossDescriptor_t* ctcLossDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateCTCLossDescriptor(cudnnCTCLossDescriptor_t *ctcLossDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetCTCLossDescriptor( - cudnnCTCLossDescriptor_t ctcLossDesc, - cudnnDataType_t compType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t); + cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc, compType); } cudnnStatus_t CUDNNWINAPI cudnnGetCTCLossDescriptor( - cudnnCTCLossDescriptor_t ctcLossDesc, - cudnnDataType_t* compType ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t *); + cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc, compType); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyCTCLossDescriptor( cudnnCTCLossDescriptor_t ctcLossDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc); } -cudnnStatus_t CUDNNWINAPI cudnnCTCLoss( cudnnHandle_t handle, - const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the mini batch size, A is the alphabet size) */ - const void * probs, /* probabilities after softmax, in GPU memory */ - const int * labels, /* labels, in CPU memory */ - const int * labelLengths, /* the length of each label, in CPU memory */ - const int * inputLengths, /* the lengths of timing steps in each batch, in CPU memory */ - void * costs, /* the returned costs of CTC, in GPU memory */ - const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */ - const void * gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */ - cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ - cudnnCTCLossDescriptor_t ctcLossDesc, - void * workspace, /* pointer to the workspace, in GPU memory */ - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const int *, const int *, const int *, void *, const cudnnTensorDescriptor_t, const void *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnCTCLoss( + cudnnHandle_t handle, + const cudnnTensorDescriptor_t + probsDesc, /* Tensor descriptor for probabilities, the dimensions are + T,N,A (T is the timing steps, N is the mini batch size, A + is the alphabet size) */ + const void *probs, /* probabilities after softmax, in GPU memory */ + const int *labels, /* labels, in CPU memory */ + const int *labelLengths, /* the length of each label, in CPU memory */ + const int *inputLengths, /* the lengths of timing steps in each batch, in + CPU memory */ + void *costs, /* the returned costs of CTC, in GPU memory */ + const cudnnTensorDescriptor_t + gradientsDesc, /* Tensor descriptor for gradients, the dimensions are + T,N,A */ + const void *gradients, /* the returned CTC gradients, in GPU memory, to + compute costs only, set it to NULL */ + cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ + cudnnCTCLossDescriptor_t ctcLossDesc, + void *workspace, /* pointer to the workspace, in GPU memory */ + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const int *, + const int *, const int *, void *, const cudnnTensorDescriptor_t, + const void *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCTCLoss"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, probsDesc, probs, labels, labelLengths, inputLengths, costs, gradientsDesc, gradients, algo, ctcLossDesc, workspace, workSpaceSizeInBytes); + return func_ptr(handle, probsDesc, probs, labels, labelLengths, inputLengths, + costs, gradientsDesc, gradients, algo, ctcLossDesc, workspace, + workSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnGetCTCLossWorkspaceSize( - cudnnHandle_t handle, - const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the mini batch size, A is the alphabet size) */ - const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A. To compute costs only, set it to NULL */ - const int * labels, /* labels, in CPU memory */ - const int * labelLengths, /* the length of each label, in CPU memory */ - const int * inputLengths, /* the lengths of timing steps in each batch, in CPU memory */ - cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ - cudnnCTCLossDescriptor_t ctcLossDesc, - size_t *sizeInBytes ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const int *, const int *, const int *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, size_t *); + cudnnHandle_t handle, + const cudnnTensorDescriptor_t + probsDesc, /* Tensor descriptor for probabilities, the dimensions are + T,N,A (T is the timing steps, N is the mini batch size, A + is the alphabet size) */ + const cudnnTensorDescriptor_t + gradientsDesc, /* Tensor descriptor for gradients, the dimensions are + T,N,A. To compute costs only, set it to NULL */ + const int *labels, /* labels, in CPU memory */ + const int *labelLengths, /* the length of each label, in CPU memory */ + const int *inputLengths, /* the lengths of timing steps in each batch, in + CPU memory */ + cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ + cudnnCTCLossDescriptor_t ctcLossDesc, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const int *, const int *, const int *, + cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCTCLossWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, probsDesc, gradientsDesc, labels, labelLengths, inputLengths, algo, ctcLossDesc, sizeInBytes); + return func_ptr(handle, probsDesc, gradientsDesc, labels, labelLengths, + inputLengths, algo, ctcLossDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnCreateAlgorithmDescriptor( - cudnnAlgorithmDescriptor_t *algoDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t *); +cudnnStatus_t CUDNNWINAPI +cudnnCreateAlgorithmDescriptor(cudnnAlgorithmDescriptor_t *algoDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc); } cudnnStatus_t CUDNNWINAPI cudnnSetAlgorithmDescriptor( - cudnnAlgorithmDescriptor_t algoDesc, - cudnnAlgorithm_t algorithm) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t, cudnnAlgorithm_t); + cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t algorithm) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t, + cudnnAlgorithm_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc, algorithm); } cudnnStatus_t CUDNNWINAPI cudnnGetAlgorithmDescriptor( - const cudnnAlgorithmDescriptor_t algoDesc, - cudnnAlgorithm_t* algorithm) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, cudnnAlgorithm_t *); + const cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t *algorithm) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, + cudnnAlgorithm_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc, algorithm); } cudnnStatus_t CUDNNWINAPI cudnnCopyAlgorithmDescriptor( - const cudnnAlgorithmDescriptor_t src, - cudnnAlgorithmDescriptor_t dest) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, cudnnAlgorithmDescriptor_t); + const cudnnAlgorithmDescriptor_t src, cudnnAlgorithmDescriptor_t dest) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, + cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCopyAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(src, dest); } -cudnnStatus_t CUDNNWINAPI cudnnDestroyAlgorithmDescriptor( - cudnnAlgorithmDescriptor_t algoDesc ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t); +cudnnStatus_t CUDNNWINAPI +cudnnDestroyAlgorithmDescriptor(cudnnAlgorithmDescriptor_t algoDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc); } cudnnStatus_t CUDNNWINAPI cudnnCreateAlgorithmPerformance( - cudnnAlgorithmPerformance_t* algoPerf, - int numberToCreate ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); + cudnnAlgorithmPerformance_t *algoPerf, int numberToCreate) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, numberToCreate); } cudnnStatus_t CUDNNWINAPI cudnnSetAlgorithmPerformance( - cudnnAlgorithmPerformance_t algoPerf, - cudnnAlgorithmDescriptor_t algoDesc, - cudnnStatus_t status, - float time, - size_t memory ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmPerformance_t, cudnnAlgorithmDescriptor_t, cudnnStatus_t, float, size_t); + cudnnAlgorithmPerformance_t algoPerf, cudnnAlgorithmDescriptor_t algoDesc, + cudnnStatus_t status, float time, size_t memory) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmPerformance_t, + cudnnAlgorithmDescriptor_t, + cudnnStatus_t, float, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, algoDesc, status, time, memory); } cudnnStatus_t CUDNNWINAPI cudnnGetAlgorithmPerformance( - const cudnnAlgorithmPerformance_t algoPerf, - cudnnAlgorithmDescriptor_t* algoDesc, - cudnnStatus_t* status, - float* time, - size_t* memory ) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnAlgorithmPerformance_t, cudnnAlgorithmDescriptor_t *, cudnnStatus_t *, float *, size_t *); + const cudnnAlgorithmPerformance_t algoPerf, + cudnnAlgorithmDescriptor_t *algoDesc, cudnnStatus_t *status, float *time, + size_t *memory) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnAlgorithmPerformance_t, cudnnAlgorithmDescriptor_t *, + cudnnStatus_t *, float *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, algoDesc, status, time, memory); } cudnnStatus_t CUDNNWINAPI cudnnDestroyAlgorithmPerformance( - cudnnAlgorithmPerformance_t* algoPerf, - int numberToDestroy) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyAlgorithmPerformance"); + cudnnAlgorithmPerformance_t *algoPerf, int numberToDestroy) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, numberToDestroy); } cudnnStatus_t CUDNNWINAPI cudnnGetAlgorithmSpaceSize( - cudnnHandle_t handle, - cudnnAlgorithmDescriptor_t algoDesc, - size_t* algoSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnAlgorithmDescriptor_t, size_t *); + cudnnHandle_t handle, cudnnAlgorithmDescriptor_t algoDesc, + size_t *algoSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnAlgorithmDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAlgorithmSpaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algoDesc, algoSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI cudnnSaveAlgorithm( - cudnnHandle_t handle, - cudnnAlgorithmDescriptor_t algoDesc, - void* algoSpace, - size_t algoSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnAlgorithmDescriptor_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI +cudnnSaveAlgorithm(cudnnHandle_t handle, cudnnAlgorithmDescriptor_t algoDesc, + void *algoSpace, size_t algoSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnAlgorithmDescriptor_t, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSaveAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algoDesc, algoSpace, algoSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnRestoreAlgorithm( - cudnnHandle_t handle, - void* algoSpace, - size_t algoSpaceSizeInBytes, - cudnnAlgorithmDescriptor_t algoDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, void *, size_t, cudnnAlgorithmDescriptor_t); + cudnnHandle_t handle, void *algoSpace, size_t algoSpaceSizeInBytes, + cudnnAlgorithmDescriptor_t algoDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, void *, size_t, + cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRestoreAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algoSpace, algoSpaceSizeInBytes, algoDesc); } -cudnnStatus_t CUDNNWINAPI cudnnSetCallback( - unsigned mask, - void *udata, - cudnnCallback_t fptr) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(unsigned int, void *, cudnnCallback_t); +cudnnStatus_t CUDNNWINAPI cudnnSetCallback(unsigned mask, void *udata, + cudnnCallback_t fptr) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(unsigned int, void *, cudnnCallback_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetCallback"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(mask, udata, fptr); } -cudnnStatus_t CUDNNWINAPI cudnnGetCallback( - unsigned *mask, - void **udata, - cudnnCallback_t *fptr) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(unsigned int *, void **, cudnnCallback_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetCallback(unsigned *mask, void **udata, + cudnnCallback_t *fptr) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(unsigned int *, void **, cudnnCallback_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCallback"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(mask, udata, fptr); } -cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v6(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - const int hiddenSize, - const int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnRNNAlgo_t algo, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v6( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, const int hiddenSize, + const int numLayers, cudnnDropoutDescriptor_t dropoutDesc, + cudnnRNNInputMode_t inputMode, cudnnDirectionMode_t direction, + cudnnRNNMode_t mode, cudnnRNNAlgo_t algo, cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, + cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, + cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor_v6"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, dataType); + return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, dataType); } -cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v5(cudnnRNNDescriptor_t rnnDesc, - int hiddenSize, - int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, int, int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v5( + cudnnRNNDescriptor_t rnnDesc, int hiddenSize, int numLayers, + cudnnDropoutDescriptor_t dropoutDesc, cudnnRNNInputMode_t inputMode, + cudnnDirectionMode_t direction, cudnnRNNMode_t mode, + cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnRNNDescriptor_t, int, int, cudnnDropoutDescriptor_t, + cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, + cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor_v5"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, dataType); + return func_ptr(rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, + direction, mode, dataType); } } // extern "C" diff --git a/tensorflow/stream_executor/cuda/cudnn_7_3.inc b/tensorflow/stream_executor/cuda/cudnn_7_3.inc index 0ee8e1492d5..f1c25c74d0c 100644 --- a/tensorflow/stream_executor/cuda/cudnn_7_3.inc +++ b/tensorflow/stream_executor/cuda/cudnn_7_3.inc @@ -2,73 +2,71 @@ extern "C" { -size_t CUDNNWINAPI -cudnnGetVersion(void) { - using FuncPtr = size_t (CUDNNWINAPI *)(); +size_t CUDNNWINAPI cudnnGetVersion(void) { + using FuncPtr = size_t(CUDNNWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetVersion"); if (!func_ptr) return 0; return func_ptr(); } -size_t CUDNNWINAPI -cudnnGetCudartVersion(void) { - using FuncPtr = size_t (CUDNNWINAPI *)(); +size_t CUDNNWINAPI cudnnGetCudartVersion(void) { + using FuncPtr = size_t(CUDNNWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCudartVersion"); if (!func_ptr) return 0; return func_ptr(); } -const char *CUDNNWINAPI -cudnnGetErrorString(cudnnStatus_t status) { - using FuncPtr = const char * (CUDNNWINAPI *)(cudnnStatus_t); +const char *CUDNNWINAPI cudnnGetErrorString(cudnnStatus_t status) { + using FuncPtr = const char *(CUDNNWINAPI *)(cudnnStatus_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetErrorString"); if (!func_ptr) return "cudnnGetErrorString symbol not found."; return func_ptr(status); } -cudnnStatus_t CUDNNWINAPI -cudnnQueryRuntimeError(cudnnHandle_t handle, cudnnStatus_t *rstatus, cudnnErrQueryMode_t mode, cudnnRuntimeTag_t *tag) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnStatus_t *, cudnnErrQueryMode_t, cudnnRuntimeTag_t *); +cudnnStatus_t CUDNNWINAPI cudnnQueryRuntimeError(cudnnHandle_t handle, + cudnnStatus_t *rstatus, + cudnnErrQueryMode_t mode, + cudnnRuntimeTag_t *tag) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnStatus_t *, cudnnErrQueryMode_t, cudnnRuntimeTag_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnQueryRuntimeError"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rstatus, mode, tag); } -cudnnStatus_t CUDNNWINAPI -cudnnGetProperty(libraryPropertyType type, int *value) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(libraryPropertyType, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetProperty(libraryPropertyType type, + int *value) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(libraryPropertyType, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetProperty"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(type, value); } -cudnnStatus_t CUDNNWINAPI -cudnnCreate(cudnnHandle_t *handle) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t *); +cudnnStatus_t CUDNNWINAPI cudnnCreate(cudnnHandle_t *handle) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreate"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cudnnStatus_t CUDNNWINAPI -cudnnDestroy(cudnnHandle_t handle) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t); +cudnnStatus_t CUDNNWINAPI cudnnDestroy(cudnnHandle_t handle) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroy"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cudnnStatus_t CUDNNWINAPI -cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t); +cudnnStatus_t CUDNNWINAPI cudnnSetStream(cudnnHandle_t handle, + cudaStream_t streamId) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetStream"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); } -cudnnStatus_t CUDNNWINAPI -cudnnGetStream(cudnnHandle_t handle, cudaStream_t *streamId) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetStream(cudnnHandle_t handle, + cudaStream_t *streamId) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetStream"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); @@ -76,100 +74,97 @@ cudnnGetStream(cudnnHandle_t handle, cudaStream_t *streamId) { cudnnStatus_t CUDNNWINAPI cudnnCreateTensorDescriptor(cudnnTensorDescriptor_t *tensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensor4dDescriptor(cudnnTensorDescriptor_t tensorDesc, - cudnnTensorFormat_t format, - cudnnDataType_t dataType, /* image data type */ - int n, /* number of inputs (batch size) */ - int c, /* number of input feature maps */ - int h, /* height of input section */ - int w) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, cudnnDataType_t, int, int, int, int); +cudnnStatus_t CUDNNWINAPI cudnnSetTensor4dDescriptor( + cudnnTensorDescriptor_t tensorDesc, cudnnTensorFormat_t format, + cudnnDataType_t dataType, /* image data type */ + int n, /* number of inputs (batch size) */ + int c, /* number of input feature maps */ + int h, /* height of input section */ + int w) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, + cudnnDataType_t, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, format, dataType, n, c, h, w); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensor4dDescriptorEx(cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t dataType, /* image data type */ - int n, /* number of inputs (batch size) */ - int c, /* number of input feature maps */ - int h, /* height of input section */ - int w, /* width of input section */ - int nStride, - int cStride, - int hStride, - int wStride) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, int, int, int, int, int, int, int, int); +cudnnStatus_t CUDNNWINAPI cudnnSetTensor4dDescriptorEx( + cudnnTensorDescriptor_t tensorDesc, + cudnnDataType_t dataType, /* image data type */ + int n, /* number of inputs (batch size) */ + int c, /* number of input feature maps */ + int h, /* height of input section */ + int w, /* width of input section */ + int nStride, int cStride, int hStride, int wStride) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, + int, int, int, int, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor4dDescriptorEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, wStride); + return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, + wStride); } -cudnnStatus_t CUDNNWINAPI -cudnnGetTensor4dDescriptor(const cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t *dataType, /* image data type */ - int *n, /* number of inputs (batch size) */ - int *c, /* number of input feature maps */ - int *h, /* height of input section */ - int *w, /* width of input section */ - int *nStride, - int *cStride, - int *hStride, - int *wStride) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, cudnnDataType_t *, int *, int *, int *, int *, int *, int *, int *, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetTensor4dDescriptor( + const cudnnTensorDescriptor_t tensorDesc, + cudnnDataType_t *dataType, /* image data type */ + int *n, /* number of inputs (batch size) */ + int *c, /* number of input feature maps */ + int *h, /* height of input section */ + int *w, /* width of input section */ + int *nStride, int *cStride, int *hStride, int *wStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnTensorDescriptor_t, cudnnDataType_t *, int *, int *, int *, + int *, int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensor4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, wStride); + return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, + wStride); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensorNdDescriptor(cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t dataType, - int nbDims, - const int dimA[], - const int strideA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, int, const int [], const int []); +cudnnStatus_t CUDNNWINAPI cudnnSetTensorNdDescriptor( + cudnnTensorDescriptor_t tensorDesc, cudnnDataType_t dataType, int nbDims, + const int dimA[], const int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnTensorDescriptor_t, cudnnDataType_t, int, const int[], const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, dataType, nbDims, dimA, strideA); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensorNdDescriptorEx(cudnnTensorDescriptor_t tensorDesc, - cudnnTensorFormat_t format, - cudnnDataType_t dataType, - int nbDims, - const int dimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, cudnnDataType_t, int, const int []); +cudnnStatus_t CUDNNWINAPI cudnnSetTensorNdDescriptorEx( + cudnnTensorDescriptor_t tensorDesc, cudnnTensorFormat_t format, + cudnnDataType_t dataType, int nbDims, const int dimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, + cudnnDataType_t, int, const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorNdDescriptorEx"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, format, dataType, nbDims, dimA); } -cudnnStatus_t CUDNNWINAPI -cudnnGetTensorNdDescriptor(const cudnnTensorDescriptor_t tensorDesc, - int nbDimsRequested, - cudnnDataType_t *dataType, - int *nbDims, - int dimA[], - int strideA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, int, cudnnDataType_t *, int *, int [], int []); +cudnnStatus_t CUDNNWINAPI cudnnGetTensorNdDescriptor( + const cudnnTensorDescriptor_t tensorDesc, int nbDimsRequested, + cudnnDataType_t *dataType, int *nbDims, int dimA[], int strideA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnTensorDescriptor_t, int, + cudnnDataType_t *, int *, int[], int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, nbDimsRequested, dataType, nbDims, dimA, strideA); } -cudnnStatus_t CUDNNWINAPI -cudnnGetTensorSizeInBytes(const cudnnTensorDescriptor_t tensorDesc, size_t *size) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetTensorSizeInBytes( + const cudnnTensorDescriptor_t tensorDesc, size_t *size) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorSizeInBytes"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, size); @@ -177,35 +172,33 @@ cudnnGetTensorSizeInBytes(const cudnnTensorDescriptor_t tensorDesc, size_t *size cudnnStatus_t CUDNNWINAPI cudnnDestroyTensorDescriptor(cudnnTensorDescriptor_t tensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnTransformTensor(cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnTransformTensor( + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnTransformTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnAddTensor(cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnAddTensor(cudnnHandle_t handle, + const void *alpha, + const cudnnTensorDescriptor_t aDesc, + const void *A, const void *beta, + const cudnnTensorDescriptor_t cDesc, + void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnAddTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, aDesc, A, beta, cDesc, C); @@ -213,29 +206,29 @@ cudnnAddTensor(cudnnHandle_t handle, cudnnStatus_t CUDNNWINAPI cudnnCreateOpTensorDescriptor(cudnnOpTensorDescriptor_t *opTensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc, - cudnnOpTensorOp_t opTensorOp, - cudnnDataType_t opTensorCompType, - cudnnNanPropagation_t opTensorNanOpt) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t, cudnnDataType_t, cudnnNanPropagation_t); +cudnnStatus_t CUDNNWINAPI cudnnSetOpTensorDescriptor( + cudnnOpTensorDescriptor_t opTensorDesc, cudnnOpTensorOp_t opTensorOp, + cudnnDataType_t opTensorCompType, cudnnNanPropagation_t opTensorNanOpt) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t, + cudnnDataType_t, cudnnNanPropagation_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc, opTensorOp, opTensorCompType, opTensorNanOpt); } -cudnnStatus_t CUDNNWINAPI -cudnnGetOpTensorDescriptor(const cudnnOpTensorDescriptor_t opTensorDesc, - cudnnOpTensorOp_t *opTensorOp, - cudnnDataType_t *opTensorCompType, - cudnnNanPropagation_t *opTensorNanOpt) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t *, cudnnDataType_t *, cudnnNanPropagation_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetOpTensorDescriptor( + const cudnnOpTensorDescriptor_t opTensorDesc, cudnnOpTensorOp_t *opTensorOp, + cudnnDataType_t *opTensorCompType, cudnnNanPropagation_t *opTensorNanOpt) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t *, cudnnDataType_t *, + cudnnNanPropagation_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc, opTensorOp, opTensorCompType, opTensorNanOpt); @@ -243,126 +236,136 @@ cudnnGetOpTensorDescriptor(const cudnnOpTensorDescriptor_t opTensorDesc, cudnnStatus_t CUDNNWINAPI cudnnDestroyOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnOpTensor(cudnnHandle_t handle, - const cudnnOpTensorDescriptor_t opTensorDesc, - const void *alpha1, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *alpha2, - const cudnnTensorDescriptor_t bDesc, - const void *B, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnOpTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnOpTensor( + cudnnHandle_t handle, const cudnnOpTensorDescriptor_t opTensorDesc, + const void *alpha1, const cudnnTensorDescriptor_t aDesc, const void *A, + const void *alpha2, const cudnnTensorDescriptor_t bDesc, const void *B, + const void *beta, const cudnnTensorDescriptor_t cDesc, void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnOpTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnOpTensor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, opTensorDesc, alpha1, aDesc, A, alpha2, bDesc, B, beta, cDesc, C); + return func_ptr(handle, opTensorDesc, alpha1, aDesc, A, alpha2, bDesc, B, + beta, cDesc, C); } -cudnnStatus_t CUDNNWINAPI -cudnnCreateReduceTensorDescriptor(cudnnReduceTensorDescriptor_t *reduceTensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateReduceTensorDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnCreateReduceTensorDescriptor( + cudnnReduceTensorDescriptor_t *reduceTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(reduceTensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc, - cudnnReduceTensorOp_t reduceTensorOp, - cudnnDataType_t reduceTensorCompType, - cudnnNanPropagation_t reduceTensorNanOpt, - cudnnReduceTensorIndices_t reduceTensorIndices, - cudnnIndicesType_t reduceTensorIndicesType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t, cudnnDataType_t, cudnnNanPropagation_t, cudnnReduceTensorIndices_t, cudnnIndicesType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetReduceTensorDescriptor( + cudnnReduceTensorDescriptor_t reduceTensorDesc, + cudnnReduceTensorOp_t reduceTensorOp, cudnnDataType_t reduceTensorCompType, + cudnnNanPropagation_t reduceTensorNanOpt, + cudnnReduceTensorIndices_t reduceTensorIndices, + cudnnIndicesType_t reduceTensorIndicesType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t, cudnnDataType_t, + cudnnNanPropagation_t, cudnnReduceTensorIndices_t, cudnnIndicesType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, reduceTensorNanOpt, reduceTensorIndices, reduceTensorIndicesType); + return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, + reduceTensorNanOpt, reduceTensorIndices, + reduceTensorIndicesType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetReduceTensorDescriptor(const cudnnReduceTensorDescriptor_t reduceTensorDesc, - cudnnReduceTensorOp_t *reduceTensorOp, - cudnnDataType_t *reduceTensorCompType, - cudnnNanPropagation_t *reduceTensorNanOpt, - cudnnReduceTensorIndices_t *reduceTensorIndices, - cudnnIndicesType_t *reduceTensorIndicesType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t *, cudnnDataType_t *, cudnnNanPropagation_t *, cudnnReduceTensorIndices_t *, cudnnIndicesType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetReduceTensorDescriptor( + const cudnnReduceTensorDescriptor_t reduceTensorDesc, + cudnnReduceTensorOp_t *reduceTensorOp, + cudnnDataType_t *reduceTensorCompType, + cudnnNanPropagation_t *reduceTensorNanOpt, + cudnnReduceTensorIndices_t *reduceTensorIndices, + cudnnIndicesType_t *reduceTensorIndicesType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t *, + cudnnDataType_t *, cudnnNanPropagation_t *, cudnnReduceTensorIndices_t *, + cudnnIndicesType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, reduceTensorNanOpt, reduceTensorIndices, reduceTensorIndicesType); + return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, + reduceTensorNanOpt, reduceTensorIndices, + reduceTensorIndicesType); } -cudnnStatus_t CUDNNWINAPI -cudnnDestroyReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyReduceTensorDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnDestroyReduceTensorDescriptor( + cudnnReduceTensorDescriptor_t reduceTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(reduceTensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnGetReductionIndicesSize(cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - const cudnnTensorDescriptor_t aDesc, - const cudnnTensorDescriptor_t cDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetReductionIndicesSize( + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + const cudnnTensorDescriptor_t aDesc, const cudnnTensorDescriptor_t cDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReductionIndicesSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, reduceTensorDesc, aDesc, cDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetReductionWorkspaceSize(cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - const cudnnTensorDescriptor_t aDesc, - const cudnnTensorDescriptor_t cDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetReductionWorkspaceSize( + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + const cudnnTensorDescriptor_t aDesc, const cudnnTensorDescriptor_t cDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReductionWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, reduceTensorDesc, aDesc, cDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnReduceTensor(cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - void *indices, - size_t indicesSizeInBytes, - void *workspace, - size_t workspaceSizeInBytes, - const void *alpha, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, void *, size_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnReduceTensor( + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + void *indices, size_t indicesSizeInBytes, void *workspace, + size_t workspaceSizeInBytes, const void *alpha, + const cudnnTensorDescriptor_t aDesc, const void *A, const void *beta, + const cudnnTensorDescriptor_t cDesc, void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, void *, size_t, + void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnReduceTensor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, reduceTensorDesc, indices, indicesSizeInBytes, workspace, workspaceSizeInBytes, alpha, aDesc, A, beta, cDesc, C); + return func_ptr(handle, reduceTensorDesc, indices, indicesSizeInBytes, + workspace, workspaceSizeInBytes, alpha, aDesc, A, beta, cDesc, + C); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *valuePtr) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); +cudnnStatus_t CUDNNWINAPI cudnnSetTensor(cudnnHandle_t handle, + const cudnnTensorDescriptor_t yDesc, + void *y, const void *valuePtr) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, yDesc, y, valuePtr); } -cudnnStatus_t CUDNNWINAPI -cudnnScaleTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *alpha) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); +cudnnStatus_t CUDNNWINAPI cudnnScaleTensor(cudnnHandle_t handle, + const cudnnTensorDescriptor_t yDesc, + void *y, const void *alpha) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnScaleTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, yDesc, y, alpha); @@ -370,68 +373,70 @@ cudnnScaleTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void cudnnStatus_t CUDNNWINAPI cudnnCreateFilterDescriptor(cudnnFilterDescriptor_t *filterDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateFilterDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetFilter4dDescriptor(cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, /* image data type */ - cudnnTensorFormat_t format, - int k, /* number of output feature maps */ - int c, /* number of input feature maps */ - int h, /* height of each input filter */ - int w) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, cudnnTensorFormat_t, int, int, int, int); +cudnnStatus_t CUDNNWINAPI cudnnSetFilter4dDescriptor( + cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t dataType, /* image data type */ + cudnnTensorFormat_t format, int k, /* number of output feature maps */ + int c, /* number of input feature maps */ + int h, /* height of each input filter */ + int w) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, + cudnnTensorFormat_t, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFilter4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, k, c, h, w); } -cudnnStatus_t CUDNNWINAPI -cudnnGetFilter4dDescriptor(const cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t *dataType, /* image data type */ - cudnnTensorFormat_t *format, - int *k, /* number of output feature maps */ - int *c, /* number of input feature maps */ - int *h, /* height of each input filter */ - int *w) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, cudnnDataType_t *, cudnnTensorFormat_t *, int *, int *, int *, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetFilter4dDescriptor( + const cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t *dataType, /* image data type */ + cudnnTensorFormat_t *format, int *k, /* number of output feature maps */ + int *c, /* number of input feature maps */ + int *h, /* height of each input filter */ + int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnFilterDescriptor_t, cudnnDataType_t *, cudnnTensorFormat_t *, + int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilter4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, k, c, h, w); } -cudnnStatus_t CUDNNWINAPI -cudnnSetFilterNdDescriptor(cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, /* image data type */ - cudnnTensorFormat_t format, - int nbDims, - const int filterDimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, cudnnTensorFormat_t, int, const int []); +cudnnStatus_t CUDNNWINAPI cudnnSetFilterNdDescriptor( + cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t dataType, /* image data type */ + cudnnTensorFormat_t format, int nbDims, const int filterDimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, + cudnnTensorFormat_t, int, const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFilterNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, nbDims, filterDimA); } -cudnnStatus_t CUDNNWINAPI -cudnnGetFilterNdDescriptor(const cudnnFilterDescriptor_t filterDesc, - int nbDimsRequested, - cudnnDataType_t *dataType, /* image data type */ - cudnnTensorFormat_t *format, - int *nbDims, - int filterDimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, int, cudnnDataType_t *, cudnnTensorFormat_t *, int *, int []); +cudnnStatus_t CUDNNWINAPI cudnnGetFilterNdDescriptor( + const cudnnFilterDescriptor_t filterDesc, int nbDimsRequested, + cudnnDataType_t *dataType, /* image data type */ + cudnnTensorFormat_t *format, int *nbDims, int filterDimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnFilterDescriptor_t, int, cudnnDataType_t *, + cudnnTensorFormat_t *, int *, int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilterNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(filterDesc, nbDimsRequested, dataType, format, nbDims, filterDimA); + return func_ptr(filterDesc, nbDimsRequested, dataType, format, nbDims, + filterDimA); } cudnnStatus_t CUDNNWINAPI cudnnDestroyFilterDescriptor(cudnnFilterDescriptor_t filterDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyFilterDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc); @@ -439,622 +444,657 @@ cudnnDestroyFilterDescriptor(cudnnFilterDescriptor_t filterDesc) { cudnnStatus_t CUDNNWINAPI cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t *convDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateConvolutionDescriptor"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateConvolutionDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t mathType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnMathType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionMathType( + cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t mathType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, + cudnnMathType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, mathType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t *mathType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnMathType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionMathType( + cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t *mathType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, + cudnnMathType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, mathType); } -cudnnStatus_t CUDNNWINAPI -cudnnSetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int groupCount) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionGroupCount( + cudnnConvolutionDescriptor_t convDesc, int groupCount) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionGroupCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, groupCount); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int *groupCount) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionGroupCount( + cudnnConvolutionDescriptor_t convDesc, int *groupCount) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionGroupCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, groupCount); } -cudnnStatus_t CUDNNWINAPI -cudnnSetConvolution2dDescriptor(cudnnConvolutionDescriptor_t convDesc, - int pad_h, /* zero-padding height */ - int pad_w, /* zero-padding width */ - int u, /* vertical filter stride */ - int v, /* horizontal filter stride */ - int dilation_h, /* filter dilation in the vertical dimension */ - int dilation_w, /* filter dilation in the horizontal dimension */ - cudnnConvolutionMode_t mode, - cudnnDataType_t computeType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, int, int, int, int, int, cudnnConvolutionMode_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolution2dDescriptor( + cudnnConvolutionDescriptor_t convDesc, int pad_h, /* zero-padding height */ + int pad_w, /* zero-padding width */ + int u, /* vertical filter stride */ + int v, /* horizontal filter stride */ + int dilation_h, /* filter dilation in the vertical dimension */ + int dilation_w, /* filter dilation in the horizontal dimension */ + cudnnConvolutionMode_t mode, cudnnDataType_t computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnConvolutionDescriptor_t, int, int, int, int, int, int, + cudnnConvolutionMode_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolution2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType); + return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, + computeType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolution2dDescriptor(const cudnnConvolutionDescriptor_t convDesc, - int *pad_h, /* zero-padding height */ - int *pad_w, /* zero-padding width */ - int *u, /* vertical filter stride */ - int *v, /* horizontal filter stride */ - int *dilation_h, /* filter dilation in the vertical dimension */ - int *dilation_w, /* filter dilation in the horizontal dimension */ - cudnnConvolutionMode_t *mode, - cudnnDataType_t *computeType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, int *, cudnnConvolutionMode_t *, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dDescriptor( + const cudnnConvolutionDescriptor_t convDesc, + int *pad_h, /* zero-padding height */ + int *pad_w, /* zero-padding width */ + int *u, /* vertical filter stride */ + int *v, /* horizontal filter stride */ + int *dilation_h, /* filter dilation in the vertical dimension */ + int *dilation_w, /* filter dilation in the horizontal dimension */ + cudnnConvolutionMode_t *mode, cudnnDataType_t *computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, + int *, cudnnConvolutionMode_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType); + return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, + computeType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolution2dForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - const cudnnFilterDescriptor_t filterDesc, - int *n, - int *c, - int *h, - int *w) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, int *, int *, int *, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dForwardOutputDim"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dForwardOutputDim( + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + const cudnnFilterDescriptor_t filterDesc, int *n, int *c, int *h, int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, int *, int *, int *, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolution2dForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, inputTensorDesc, filterDesc, n, c, h, w); } -cudnnStatus_t CUDNNWINAPI -cudnnSetConvolutionNdDescriptor(cudnnConvolutionDescriptor_t convDesc, - int arrayLength, /* nbDims-2 size */ - const int padA[], - const int filterStrideA[], - const int dilationA[], - cudnnConvolutionMode_t mode, - cudnnDataType_t computeType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, const int [], const int [], const int [], cudnnConvolutionMode_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionNdDescriptor( + cudnnConvolutionDescriptor_t convDesc, int arrayLength, /* nbDims-2 size */ + const int padA[], const int filterStrideA[], const int dilationA[], + cudnnConvolutionMode_t mode, cudnnDataType_t computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnConvolutionDescriptor_t, int, const int[], const int[], const int[], + cudnnConvolutionMode_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, arrayLength, padA, filterStrideA, dilationA, mode, computeType); + return func_ptr(convDesc, arrayLength, padA, filterStrideA, dilationA, mode, + computeType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionNdDescriptor(const cudnnConvolutionDescriptor_t convDesc, - int arrayLengthRequested, - int *arrayLength, - int padA[], - int strideA[], - int dilationA[], - cudnnConvolutionMode_t *mode, - cudnnDataType_t *computeType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int, int *, int [], int [], int [], cudnnConvolutionMode_t *, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionNdDescriptor( + const cudnnConvolutionDescriptor_t convDesc, int arrayLengthRequested, + int *arrayLength, int padA[], int strideA[], int dilationA[], + cudnnConvolutionMode_t *mode, cudnnDataType_t *computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, int, int *, int[], int[], int[], + cudnnConvolutionMode_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, arrayLengthRequested, arrayLength, padA, strideA, dilationA, mode, computeType); + return func_ptr(convDesc, arrayLengthRequested, arrayLength, padA, strideA, + dilationA, mode, computeType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionNdForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - const cudnnFilterDescriptor_t filterDesc, - int nbDims, - int tensorOutputDimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, int, int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionNdForwardOutputDim"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionNdForwardOutputDim( + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + const cudnnFilterDescriptor_t filterDesc, int nbDims, + int tensorOutputDimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, int, int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionNdForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, inputTensorDesc, filterDesc, nbDims, tensorOutputDimA); + return func_ptr(convDesc, inputTensorDesc, filterDesc, nbDims, + tensorOutputDimA); } cudnnStatus_t CUDNNWINAPI cudnnDestroyConvolutionDescriptor(cudnnConvolutionDescriptor_t convDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyConvolutionDescriptor"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyConvolutionDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithmMaxCount(cudnnHandle_t handle, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithmMaxCount"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionForwardAlgorithm(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionFwdAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionForwardAlgorithm( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionFwdAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionForwardAlgorithmEx(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - void *y, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, const int, int *, cudnnConvolutionFwdAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionForwardAlgorithmEx( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, void *y, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults, + void *workSpace, size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, + const int, int *, cudnnConvolutionFwdAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, x, wDesc, w, convDesc, yDesc, y, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, xDesc, x, wDesc, w, convDesc, yDesc, y, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionForwardAlgorithm(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - cudnnConvolutionFwdPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionFwdAlgo_t *algo) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionFwdPreference_t, size_t, cudnnConvolutionFwdAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithm( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, + cudnnConvolutionFwdPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionFwdAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionFwdPreference_t, size_t, + cudnnConvolutionFwdAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, preference, + memoryLimitInBytes, algo); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionForwardAlgorithm_v7(cudnnHandle_t handle, - const cudnnTensorDescriptor_t srcDesc, - const cudnnFilterDescriptor_t filterDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t destDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionFwdAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm_v7"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithm_v7( + cudnnHandle_t handle, const cudnnTensorDescriptor_t srcDesc, + const cudnnFilterDescriptor_t filterDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t destDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionFwdAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm_v7"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, srcDesc, filterDesc, convDesc, destDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, srcDesc, filterDesc, convDesc, destDesc, + requestedAlgoCount, returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - cudnnConvolutionFwdAlgo_t algo, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionFwdAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardWorkspaceSize"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardWorkspaceSize( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, cudnnConvolutionFwdAlgo_t algo, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionFwdAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, algo, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnConvolutionForward(cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionFwdAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnConvolutionForward( + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo, + void *workSpace, size_t workSpaceSizeInBytes, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, yDesc, y); + return func_ptr(handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace, + workSpaceSizeInBytes, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnConvolutionBiasActivationForward(cudnnHandle_t handle, - const void *alpha1, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionFwdAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *alpha2, - const cudnnTensorDescriptor_t zDesc, - const void *z, - const cudnnTensorDescriptor_t biasDesc, - const void *bias, - const cudnnActivationDescriptor_t activationDesc, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBiasActivationForward"); +cudnnStatus_t CUDNNWINAPI cudnnConvolutionBiasActivationForward( + cudnnHandle_t handle, const void *alpha1, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo, + void *workSpace, size_t workSpaceSizeInBytes, const void *alpha2, + const cudnnTensorDescriptor_t zDesc, const void *z, + const cudnnTensorDescriptor_t biasDesc, const void *bias, + const cudnnActivationDescriptor_t activationDesc, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnConvolutionBiasActivationForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha1, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, alpha2, zDesc, z, biasDesc, bias, activationDesc, yDesc, y); + return func_ptr(handle, alpha1, xDesc, x, wDesc, w, convDesc, algo, workSpace, + workSpaceSizeInBytes, alpha2, zDesc, z, biasDesc, bias, + activationDesc, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnConvolutionBackwardBias(cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *beta, - const cudnnTensorDescriptor_t dbDesc, - void *db) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardBias( + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t dyDesc, const void *dy, const void *beta, + const cudnnTensorDescriptor_t dbDesc, void *db) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardBias"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, dyDesc, dy, beta, dbDesc, db); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnnHandle_t handle, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( + cudnnHandle_t handle, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardFilterAlgorithm( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, const int, int *, + cudnnConvolutionBwdFilterAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionBackwardFilterAlgorithmEx(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *y, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, void *, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardFilterAlgorithmEx( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t dyDesc, const void *y, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, void *dw, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnConvolutionBwdFilterAlgoPerf_t *perfResults, void *workSpace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, void *, + const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, x, dyDesc, y, convDesc, dwDesc, dw, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, xDesc, x, dyDesc, y, convDesc, dwDesc, dw, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - cudnnConvolutionBwdFilterPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionBwdFilterAlgo_t *algo) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterPreference_t, size_t, cudnnConvolutionBwdFilterAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithm( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, + cudnnConvolutionBwdFilterPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionBwdFilterAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterPreference_t, + size_t, cudnnConvolutionBwdFilterAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, preference, + memoryLimitInBytes, algo); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnnHandle_t handle, - const cudnnTensorDescriptor_t srcDesc, - const cudnnTensorDescriptor_t diffDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t gradDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm_v7"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithm_v7( + cudnnHandle_t handle, const cudnnTensorDescriptor_t srcDesc, + const cudnnTensorDescriptor_t diffDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t gradDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, const int, int *, + cudnnConvolutionBwdFilterAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm_v7"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, srcDesc, diffDesc, convDesc, gradDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, srcDesc, diffDesc, convDesc, gradDesc, + requestedAlgoCount, returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t gradDesc, - cudnnConvolutionBwdFilterAlgo_t algo, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterWorkspaceSize"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterWorkspaceSize( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t gradDesc, + cudnnConvolutionBwdFilterAlgo_t algo, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, dyDesc, convDesc, gradDesc, algo, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnConvolutionBackwardFilter(cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionBwdFilterAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnFilterDescriptor_t dwDesc, - void *dw) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, void *, size_t, const void *, const cudnnFilterDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardFilter( + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + cudnnConvolutionBwdFilterAlgo_t algo, void *workSpace, + size_t workSpaceSizeInBytes, const void *beta, + const cudnnFilterDescriptor_t dwDesc, void *dw) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, + void *, size_t, const void *, const cudnnFilterDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardFilter"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, xDesc, x, dyDesc, dy, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, dwDesc, dw); + return func_ptr(handle, alpha, xDesc, x, dyDesc, dy, convDesc, algo, + workSpace, workSpaceSizeInBytes, beta, dwDesc, dw); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnnHandle_t handle, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithmMaxCount( + cudnnHandle_t handle, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionBackwardDataAlgorithm(cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardDataAlgorithm( + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionBwdDataAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionBackwardDataAlgorithmEx(cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardDataAlgorithmEx( + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, void *dx, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnConvolutionBwdDataAlgoPerf_t *perfResults, void *workSpace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, + const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardDataAlgorithm(cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - cudnnConvolutionBwdDataPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionBwdDataAlgo_t *algo) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataPreference_t, size_t, cudnnConvolutionBwdDataAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithm( + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, + cudnnConvolutionBwdDataPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionBwdDataAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataPreference_t, + size_t, cudnnConvolutionBwdDataAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, preference, + memoryLimitInBytes, algo); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnnHandle_t handle, - const cudnnFilterDescriptor_t filterDesc, - const cudnnTensorDescriptor_t diffDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t gradDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm_v7"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithm_v7( + cudnnHandle_t handle, const cudnnFilterDescriptor_t filterDesc, + const cudnnTensorDescriptor_t diffDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t gradDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionBwdDataAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm_v7"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, filterDesc, diffDesc, convDesc, gradDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, filterDesc, diffDesc, convDesc, gradDesc, + requestedAlgoCount, returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - cudnnConvolutionBwdDataAlgo_t algo, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataWorkspaceSize"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataWorkspaceSize( + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, cudnnConvolutionBwdDataAlgo_t algo, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, algo, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnConvolutionBackwardData(cudnnHandle_t handle, - const void *alpha, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionBwdDataAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdDataAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardData( + cudnnHandle_t handle, const void *alpha, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + cudnnConvolutionBwdDataAlgo_t algo, void *workSpace, + size_t workSpaceSizeInBytes, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdDataAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardData"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, wDesc, w, dyDesc, dy, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, dxDesc, dx); + return func_ptr(handle, alpha, wDesc, w, dyDesc, dy, convDesc, algo, + workSpace, workSpaceSizeInBytes, beta, dxDesc, dx); } cudnnStatus_t CUDNNWINAPI -cudnnIm2Col(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - void *colBuffer) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, void *); +cudnnIm2Col(cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const void *x, const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, void *colBuffer) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, + const void *, const cudnnFilterDescriptor_t, + const cudnnConvolutionDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnIm2Col"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, x, wDesc, convDesc, colBuffer); } -cudnnStatus_t CUDNNWINAPI -cudnnSoftmaxForward(cudnnHandle_t handle, - cudnnSoftmaxAlgorithm_t algo, - cudnnSoftmaxMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnSoftmaxForward( + cudnnHandle_t handle, cudnnSoftmaxAlgorithm_t algo, cudnnSoftmaxMode_t mode, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSoftmaxForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algo, mode, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnSoftmaxBackward(cudnnHandle_t handle, - cudnnSoftmaxAlgorithm_t algo, - cudnnSoftmaxMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnSoftmaxBackward( + cudnnHandle_t handle, cudnnSoftmaxAlgorithm_t algo, cudnnSoftmaxMode_t mode, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSoftmaxBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, algo, mode, alpha, yDesc, y, dyDesc, dy, beta, dxDesc, dx); + return func_ptr(handle, algo, mode, alpha, yDesc, y, dyDesc, dy, beta, dxDesc, + dx); } cudnnStatus_t CUDNNWINAPI cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPoolingDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreatePoolingDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetPooling2dDescriptor(cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t mode, - cudnnNanPropagation_t maxpoolingNanOpt, - int windowHeight, - int windowWidth, - int verticalPadding, - int horizontalPadding, - int verticalStride, - int horizontalStride) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t, cudnnPoolingMode_t, cudnnNanPropagation_t, int, int, int, int, int, int); +cudnnStatus_t CUDNNWINAPI cudnnSetPooling2dDescriptor( + cudnnPoolingDescriptor_t poolingDesc, cudnnPoolingMode_t mode, + cudnnNanPropagation_t maxpoolingNanOpt, int windowHeight, int windowWidth, + int verticalPadding, int horizontalPadding, int verticalStride, + int horizontalStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnPoolingDescriptor_t, cudnnPoolingMode_t, cudnnNanPropagation_t, int, + int, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPooling2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, windowWidth, verticalPadding, horizontalPadding, verticalStride, horizontalStride); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, + windowWidth, verticalPadding, horizontalPadding, + verticalStride, horizontalStride); } -cudnnStatus_t CUDNNWINAPI -cudnnGetPooling2dDescriptor(const cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t *mode, - cudnnNanPropagation_t *maxpoolingNanOpt, - int *windowHeight, - int *windowWidth, - int *verticalPadding, - int *horizontalPadding, - int *verticalStride, - int *horizontalStride) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, cudnnPoolingMode_t *, cudnnNanPropagation_t *, int *, int *, int *, int *, int *, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetPooling2dDescriptor( + const cudnnPoolingDescriptor_t poolingDesc, cudnnPoolingMode_t *mode, + cudnnNanPropagation_t *maxpoolingNanOpt, int *windowHeight, + int *windowWidth, int *verticalPadding, int *horizontalPadding, + int *verticalStride, int *horizontalStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnPoolingDescriptor_t, cudnnPoolingMode_t *, + cudnnNanPropagation_t *, int *, int *, int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPooling2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, windowWidth, verticalPadding, horizontalPadding, verticalStride, horizontalStride); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, + windowWidth, verticalPadding, horizontalPadding, + verticalStride, horizontalStride); } -cudnnStatus_t CUDNNWINAPI -cudnnSetPoolingNdDescriptor(cudnnPoolingDescriptor_t poolingDesc, - const cudnnPoolingMode_t mode, - const cudnnNanPropagation_t maxpoolingNanOpt, - int nbDims, - const int windowDimA[], - const int paddingA[], - const int strideA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t, const cudnnPoolingMode_t, const cudnnNanPropagation_t, int, const int [], const int [], const int []); +cudnnStatus_t CUDNNWINAPI cudnnSetPoolingNdDescriptor( + cudnnPoolingDescriptor_t poolingDesc, const cudnnPoolingMode_t mode, + const cudnnNanPropagation_t maxpoolingNanOpt, int nbDims, + const int windowDimA[], const int paddingA[], const int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnPoolingDescriptor_t, const cudnnPoolingMode_t, + const cudnnNanPropagation_t, int, const int[], const int[], const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPoolingNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, + paddingA, strideA); } -cudnnStatus_t CUDNNWINAPI -cudnnGetPoolingNdDescriptor(const cudnnPoolingDescriptor_t poolingDesc, - int nbDimsRequested, - cudnnPoolingMode_t *mode, - cudnnNanPropagation_t *maxpoolingNanOpt, - int *nbDims, - int windowDimA[], - int paddingA[], - int strideA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, int, cudnnPoolingMode_t *, cudnnNanPropagation_t *, int *, int [], int [], int []); +cudnnStatus_t CUDNNWINAPI cudnnGetPoolingNdDescriptor( + const cudnnPoolingDescriptor_t poolingDesc, int nbDimsRequested, + cudnnPoolingMode_t *mode, cudnnNanPropagation_t *maxpoolingNanOpt, + int *nbDims, int windowDimA[], int paddingA[], int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnPoolingDescriptor_t, int, cudnnPoolingMode_t *, + cudnnNanPropagation_t *, int *, int[], int[], int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPoolingNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, nbDimsRequested, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA); + return func_ptr(poolingDesc, nbDimsRequested, mode, maxpoolingNanOpt, nbDims, + windowDimA, paddingA, strideA); } cudnnStatus_t CUDNNWINAPI cudnnGetPoolingNdForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc, const cudnnTensorDescriptor_t inputTensorDesc, - int nbDims, - int outputTensorDimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, const cudnnTensorDescriptor_t, int, int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPoolingNdForwardOutputDim"); + int nbDims, int outputTensorDimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, + const cudnnTensorDescriptor_t, int, int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetPoolingNdForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc, inputTensorDesc, nbDims, outputTensorDimA); } @@ -1062,72 +1102,69 @@ cudnnGetPoolingNdForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc, cudnnStatus_t CUDNNWINAPI cudnnGetPooling2dForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc, const cudnnTensorDescriptor_t inputTensorDesc, - int *n, - int *c, - int *h, - int *w) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, const cudnnTensorDescriptor_t, int *, int *, int *, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPooling2dForwardOutputDim"); + int *n, int *c, int *h, int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, + const cudnnTensorDescriptor_t, + int *, int *, int *, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetPooling2dForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc, inputTensorDesc, n, c, h, w); } cudnnStatus_t CUDNNWINAPI cudnnDestroyPoolingDescriptor(cudnnPoolingDescriptor_t poolingDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPoolingDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyPoolingDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnPoolingForward(cudnnHandle_t handle, - const cudnnPoolingDescriptor_t poolingDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnPoolingForward( + cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnPoolingForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, poolingDesc, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnPoolingBackward(cudnnHandle_t handle, - const cudnnPoolingDescriptor_t poolingDesc, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnPoolingBackward( + cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnPoolingBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, poolingDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, poolingDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, + beta, dxDesc, dx); } cudnnStatus_t CUDNNWINAPI cudnnCreateActivationDescriptor(cudnnActivationDescriptor_t *activationDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetActivationDescriptor(cudnnActivationDescriptor_t activationDesc, - cudnnActivationMode_t mode, - cudnnNanPropagation_t reluNanOpt, - double coef) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t, cudnnActivationMode_t, cudnnNanPropagation_t, double); +cudnnStatus_t CUDNNWINAPI cudnnSetActivationDescriptor( + cudnnActivationDescriptor_t activationDesc, cudnnActivationMode_t mode, + cudnnNanPropagation_t reluNanOpt, double coef) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t, + cudnnActivationMode_t, + cudnnNanPropagation_t, double); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc, mode, reluNanOpt, coef); @@ -1136,9 +1173,10 @@ cudnnSetActivationDescriptor(cudnnActivationDescriptor_t activationDesc, cudnnStatus_t CUDNNWINAPI cudnnGetActivationDescriptor(const cudnnActivationDescriptor_t activationDesc, cudnnActivationMode_t *mode, - cudnnNanPropagation_t *reluNanOpt, - double *coef) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnActivationDescriptor_t, cudnnActivationMode_t *, cudnnNanPropagation_t *, double *); + cudnnNanPropagation_t *reluNanOpt, double *coef) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnActivationDescriptor_t, cudnnActivationMode_t *, + cudnnNanPropagation_t *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc, mode, reluNanOpt, coef); @@ -1146,65 +1184,68 @@ cudnnGetActivationDescriptor(const cudnnActivationDescriptor_t activationDesc, cudnnStatus_t CUDNNWINAPI cudnnDestroyActivationDescriptor(cudnnActivationDescriptor_t activationDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyActivationDescriptor"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnActivationForward(cudnnHandle_t handle, - cudnnActivationDescriptor_t activationDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnActivationDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnActivationForward( + cudnnHandle_t handle, cudnnActivationDescriptor_t activationDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnActivationDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnActivationForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, activationDesc, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnActivationBackward(cudnnHandle_t handle, - cudnnActivationDescriptor_t activationDesc, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnActivationDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnActivationBackward( + cudnnHandle_t handle, cudnnActivationDescriptor_t activationDesc, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnActivationDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnActivationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, activationDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, activationDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, + beta, dxDesc, dx); } cudnnStatus_t CUDNNWINAPI cudnnCreateLRNDescriptor(cudnnLRNDescriptor_t *normDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnLRNDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned lrnN, double lrnAlpha, double lrnBeta, double lrnK) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t, unsigned int, double, double, double); +cudnnStatus_t CUDNNWINAPI cudnnSetLRNDescriptor(cudnnLRNDescriptor_t normDesc, + unsigned lrnN, double lrnAlpha, + double lrnBeta, double lrnK) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnLRNDescriptor_t, unsigned int, double, double, double); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK); } -cudnnStatus_t CUDNNWINAPI -cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned *lrnN, double *lrnAlpha, double *lrnBeta, double *lrnK) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t, unsigned int *, double *, double *, double *); +cudnnStatus_t CUDNNWINAPI cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, + unsigned *lrnN, + double *lrnAlpha, + double *lrnBeta, double *lrnK) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnLRNDescriptor_t, unsigned int *, double *, double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK); @@ -1212,110 +1253,104 @@ cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned *lrnN, double *lrn cudnnStatus_t CUDNNWINAPI cudnnDestroyLRNDescriptor(cudnnLRNDescriptor_t lrnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnLRNDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(lrnDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnLRNCrossChannelForward(cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnLRNMode_t lrnMode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnLRNCrossChannelForward( + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, cudnnLRNMode_t lrnMode, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnLRNCrossChannelForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, normDesc, lrnMode, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnLRNCrossChannelBackward(cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnLRNMode_t lrnMode, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnLRNCrossChannelBackward( + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, cudnnLRNMode_t lrnMode, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnLRNCrossChannelBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, lrnMode, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, normDesc, lrnMode, alpha, yDesc, y, dyDesc, dy, xDesc, + x, beta, dxDesc, dx); } -cudnnStatus_t CUDNNWINAPI -cudnnDivisiveNormalizationForward(cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnDivNormMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */ - const void *x, - const void *means, /* if NULL, means are assumed to be zero */ - void *temp, - void *temp2, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, const void *, const cudnnTensorDescriptor_t, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationForward"); +cudnnStatus_t CUDNNWINAPI cudnnDivisiveNormalizationForward( + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, + cudnnDivNormMode_t mode, const void *alpha, + const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */ + const void *x, + const void *means, /* if NULL, means are assumed to be zero */ + void *temp, void *temp2, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, + const void *, const cudnnTensorDescriptor_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, temp, temp2, beta, yDesc, y); + return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, temp, temp2, + beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnDivisiveNormalizationBackward(cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnDivNormMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, /* same desc for x, means, dy, temp, temp2 */ - const void *x, - const void *means, /* if NULL, means are assumed to be zero */ - const void *dy, - void *temp, - void *temp2, - const void *beta, - const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */ - void *dx, /* output x differential */ - void *dMeans) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, void *, void *, const void *, const cudnnTensorDescriptor_t, void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationBackward"); +cudnnStatus_t CUDNNWINAPI cudnnDivisiveNormalizationBackward( + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, + cudnnDivNormMode_t mode, const void *alpha, + const cudnnTensorDescriptor_t + xDesc, /* same desc for x, means, dy, temp, temp2 */ + const void *x, + const void *means, /* if NULL, means are assumed to be zero */ + const void *dy, void *temp, void *temp2, const void *beta, + const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */ + void *dx, /* output x differential */ + void *dMeans) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + void *, void *, const void *, const cudnnTensorDescriptor_t, void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, dy, temp, temp2, beta, dXdMeansDesc, dx, dMeans); + return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, dy, temp, + temp2, beta, dXdMeansDesc, dx, dMeans); } -cudnnStatus_t CUDNNWINAPI -cudnnDeriveBNTensorDescriptor(cudnnTensorDescriptor_t derivedBnDesc, - const cudnnTensorDescriptor_t xDesc, - cudnnBatchNormMode_t mode) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, cudnnBatchNormMode_t); +cudnnStatus_t CUDNNWINAPI cudnnDeriveBNTensorDescriptor( + cudnnTensorDescriptor_t derivedBnDesc, const cudnnTensorDescriptor_t xDesc, + cudnnBatchNormMode_t mode) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, + cudnnBatchNormMode_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDeriveBNTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(derivedBnDesc, xDesc, mode); } -cudnnStatus_t CUDNNWINAPI -cudnnBatchNormalizationForwardTraining( - cudnnHandle_t handle, - cudnnBatchNormMode_t mode, +cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationForwardTraining( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, const void *alpha, /* alpha[0] = result blend factor */ const void *beta, /* beta[0] = dest layer blend factor */ - const cudnnTensorDescriptor_t xDesc, - const void *x, /* NxCxHxW */ - const cudnnTensorDescriptor_t yDesc, - void *y, /* NxCxHxW */ + const cudnnTensorDescriptor_t xDesc, const void *x, /* NxCxHxW */ + const cudnnTensorDescriptor_t yDesc, void *y, /* NxCxHxW */ /* Shared desc for the next 6 tensors in the argument list. Data type to be set as follows: @@ -1323,13 +1358,13 @@ cudnnBatchNormalizationForwardTraining( Dimensions for this descriptor depend on normalization mode - Spatial Normalization : tensors are expected to have dims 1xCx1x1 (normalization is performed across NxHxW) - - Per-Activation Normalization : tensors are expected to have dims of 1xCxHxW - (normalization is performed across N) */ + - Per-Activation Normalization : tensors are expected to have dims of + 1xCxHxW (normalization is performed across N) */ const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - /* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation */ - const void *bnScale, - const void *bnBias, + /* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation + */ + const void *bnScale, const void *bnBias, /* MUST use factor=1 in the very first call of a complete training cycle. Use a factor=1/(1+n) at N-th call to the function to get @@ -1347,162 +1382,173 @@ cudnnBatchNormalizationForwardTraining( of variance[x] (factor is applied in the same way as for runningMean) */ void *resultRunningVariance, - /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */ + /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and + backward functions. */ double epsilon, /* Optionally save intermediate results from the forward pass here - can be reused to speed up backward pass. NULL if unused */ - void *resultSaveMean, - void *resultSaveInvVariance) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, double, void *, void *, double, void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTraining"); + void *resultSaveMean, void *resultSaveInvVariance) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, const void *, double, void *, void *, double, void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTraining"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, bnScale, bnBias, exponentialAverageFactor, resultRunningMean, resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance); + return func_ptr( + handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, + bnScale, bnBias, exponentialAverageFactor, resultRunningMean, + resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance); } -cudnnStatus_t CUDNNWINAPI -cudnnBatchNormalizationForwardInference(cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - const void *alpha, /* alpha[0] = result blend factor */ - const void *beta, /* beta[0] = dest layer blend factor */ - const cudnnTensorDescriptor_t xDesc, - const void *x, /* NxCxHxW */ - const cudnnTensorDescriptor_t yDesc, - void *y, /* NxCxHxW */ - const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - const void *bnScale, - const void *bnBias, - const void *estimatedMean, - const void *estimatedVariance, - double epsilon) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, const void *, double); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardInference"); +cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationForwardInference( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, + const void *alpha, /* alpha[0] = result blend factor */ + const void *beta, /* beta[0] = dest layer blend factor */ + const cudnnTensorDescriptor_t xDesc, const void *x, /* NxCxHxW */ + const cudnnTensorDescriptor_t yDesc, void *y, /* NxCxHxW */ + const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, const void *bnScale, + const void *bnBias, const void *estimatedMean, + const void *estimatedVariance, double epsilon) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, const void *, const void *, const void *, double); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardInference"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, bnScale, bnBias, estimatedMean, estimatedVariance, epsilon); + return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, + bnScaleBiasMeanVarDesc, bnScale, bnBias, estimatedMean, + estimatedVariance, epsilon); } -cudnnStatus_t CUDNNWINAPI -cudnnBatchNormalizationBackward(cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - const void *alphaDataDiff, - const void *betaDataDiff, - const void *alphaParamDiff, - const void *betaParamDiff, - const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */ - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - /* Shared tensor desc for the 4 tensors below */ - const cudnnTensorDescriptor_t dBnScaleBiasDesc, - const void *bnScale, /* bnBias doesn't affect backpropagation */ - /* scale and bias diff are not backpropagated below this layer */ - void *dBnScaleResult, - void *dBnBiasResult, - /* Same epsilon as forward pass */ - double epsilon, +cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationBackward( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, const void *alphaDataDiff, + const void *betaDataDiff, const void *alphaParamDiff, + const void *betaParamDiff, + const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */ + const void *x, const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t dxDesc, void *dx, + /* Shared tensor desc for the 4 tensors below */ + const cudnnTensorDescriptor_t dBnScaleBiasDesc, + const void *bnScale, /* bnBias doesn't affect backpropagation */ + /* scale and bias diff are not backpropagated below this layer */ + void *dBnScaleResult, void *dBnBiasResult, + /* Same epsilon as forward pass */ + double epsilon, - /* Optionally cached intermediate results from - forward pass */ - const void *savedMean, - const void *savedInvVariance) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, void *, void *, double, const void *, const void *); + /* Optionally cached intermediate results from + forward pass */ + const void *savedMean, const void *savedInvVariance) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const void *, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, void *, void *, double, const void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alphaDataDiff, betaDataDiff, alphaParamDiff, betaParamDiff, xDesc, x, dyDesc, dy, dxDesc, dx, dBnScaleBiasDesc, bnScale, dBnScaleResult, dBnBiasResult, epsilon, savedMean, savedInvVariance); + return func_ptr(handle, mode, alphaDataDiff, betaDataDiff, alphaParamDiff, + betaParamDiff, xDesc, x, dyDesc, dy, dxDesc, dx, + dBnScaleBiasDesc, bnScale, dBnScaleResult, dBnBiasResult, + epsilon, savedMean, savedInvVariance); } -cudnnStatus_t CUDNNWINAPI -cudnnCreateSpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t *stDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateSpatialTransformerDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnCreateSpatialTransformerDescriptor( + cudnnSpatialTransformerDescriptor_t *stDesc) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateSpatialTransformerDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetSpatialTransformerNdDescriptor(cudnnSpatialTransformerDescriptor_t stDesc, - cudnnSamplerType_t samplerType, - cudnnDataType_t dataType, - const int nbDims, - const int dimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t, cudnnSamplerType_t, cudnnDataType_t, const int, const int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetSpatialTransformerNdDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnSetSpatialTransformerNdDescriptor( + cudnnSpatialTransformerDescriptor_t stDesc, cudnnSamplerType_t samplerType, + cudnnDataType_t dataType, const int nbDims, const int dimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnSpatialTransformerDescriptor_t, cudnnSamplerType_t, cudnnDataType_t, + const int, const int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSetSpatialTransformerNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc, samplerType, dataType, nbDims, dimA); } -cudnnStatus_t CUDNNWINAPI -cudnnDestroySpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t stDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroySpatialTransformerDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnDestroySpatialTransformerDescriptor( + cudnnSpatialTransformerDescriptor_t stDesc) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroySpatialTransformerDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSpatialTfGridGeneratorForward(cudnnHandle_t handle, - const cudnnSpatialTransformerDescriptor_t stDesc, - const void *theta, - void *grid) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorForward"); +cudnnStatus_t CUDNNWINAPI cudnnSpatialTfGridGeneratorForward( + cudnnHandle_t handle, const cudnnSpatialTransformerDescriptor_t stDesc, + const void *theta, void *grid) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, theta, grid); } -cudnnStatus_t CUDNNWINAPI -cudnnSpatialTfGridGeneratorBackward(cudnnHandle_t handle, - const cudnnSpatialTransformerDescriptor_t stDesc, - const void *dgrid, - void *dtheta) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorBackward"); +cudnnStatus_t CUDNNWINAPI cudnnSpatialTfGridGeneratorBackward( + cudnnHandle_t handle, const cudnnSpatialTransformerDescriptor_t stDesc, + const void *dgrid, void *dtheta) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorBackward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, dgrid, dtheta); } -cudnnStatus_t CUDNNWINAPI -cudnnSpatialTfSamplerForward(cudnnHandle_t handle, - cudnnSpatialTransformerDescriptor_t stDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *grid, - const void *beta, - cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnSpatialTfSamplerForward( + cudnnHandle_t handle, cudnnSpatialTransformerDescriptor_t stDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *grid, const void *beta, cudnnTensorDescriptor_t yDesc, + void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfSamplerForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, alpha, xDesc, x, grid, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnSpatialTfSamplerBackward(cudnnHandle_t handle, - cudnnSpatialTransformerDescriptor_t stDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - const void *alphaDgrid, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *grid, - const void *betaDgrid, - void *dgrid) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, void *); +cudnnStatus_t CUDNNWINAPI cudnnSpatialTfSamplerBackward( + cudnnHandle_t handle, cudnnSpatialTransformerDescriptor_t stDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t dxDesc, void *dx, + const void *alphaDgrid, const cudnnTensorDescriptor_t dyDesc, + const void *dy, const void *grid, const void *betaDgrid, void *dgrid) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfSamplerBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, stDesc, alpha, xDesc, x, beta, dxDesc, dx, alphaDgrid, dyDesc, dy, grid, betaDgrid, dgrid); + return func_ptr(handle, stDesc, alpha, xDesc, x, beta, dxDesc, dx, alphaDgrid, + dyDesc, dy, grid, betaDgrid, dgrid); } cudnnStatus_t CUDNNWINAPI cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t *dropoutDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc); @@ -1510,99 +1556,95 @@ cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t *dropoutDesc) { cudnnStatus_t CUDNNWINAPI cudnnDestroyDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnDropoutGetStatesSize(cudnnHandle_t handle, size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnDropoutGetStatesSize(cudnnHandle_t handle, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutGetStatesSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnDropoutGetReserveSpaceSize(cudnnTensorDescriptor_t xdesc, size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnDropoutGetReserveSpaceSize( + cudnnTensorDescriptor_t xdesc, size_t *sizeInBytes) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutGetReserveSpaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(xdesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnSetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float dropout, - void *states, - size_t stateSizeInBytes, - unsigned long long seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float, void *, size_t, unsigned long long); +cudnnStatus_t CUDNNWINAPI cudnnSetDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float dropout, + void *states, size_t stateSizeInBytes, unsigned long long seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float, void *, size_t, unsigned long long); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, stateSizeInBytes, seed); } -cudnnStatus_t CUDNNWINAPI -cudnnRestoreDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float dropout, - void *states, - size_t stateSizeInBytes, - unsigned long long seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float, void *, size_t, unsigned long long); +cudnnStatus_t CUDNNWINAPI cudnnRestoreDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float dropout, + void *states, size_t stateSizeInBytes, unsigned long long seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float, void *, size_t, unsigned long long); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRestoreDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, stateSizeInBytes, seed); } -cudnnStatus_t CUDNNWINAPI -cudnnGetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float *dropout, - void **states, - unsigned long long *seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float *, void **, unsigned long long *); +cudnnStatus_t CUDNNWINAPI cudnnGetDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float *dropout, + void **states, unsigned long long *seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float *, void **, unsigned long long *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, seed); } -cudnnStatus_t CUDNNWINAPI -cudnnDropoutForward(cudnnHandle_t handle, - const cudnnDropoutDescriptor_t dropoutDesc, - const cudnnTensorDescriptor_t xdesc, - const void *x, - const cudnnTensorDescriptor_t ydesc, - void *y, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnDropoutDescriptor_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnDropoutForward( + cudnnHandle_t handle, const cudnnDropoutDescriptor_t dropoutDesc, + const cudnnTensorDescriptor_t xdesc, const void *x, + const cudnnTensorDescriptor_t ydesc, void *y, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnDropoutDescriptor_t, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, dropoutDesc, xdesc, x, ydesc, y, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, dropoutDesc, xdesc, x, ydesc, y, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnDropoutBackward(cudnnHandle_t handle, - const cudnnDropoutDescriptor_t dropoutDesc, - const cudnnTensorDescriptor_t dydesc, - const void *dy, - const cudnnTensorDescriptor_t dxdesc, - void *dx, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnDropoutDescriptor_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnDropoutBackward( + cudnnHandle_t handle, const cudnnDropoutDescriptor_t dropoutDesc, + const cudnnTensorDescriptor_t dydesc, const void *dy, + const cudnnTensorDescriptor_t dxdesc, void *dx, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnDropoutDescriptor_t, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, dropoutDesc, dydesc, dy, dxdesc, dx, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, dropoutDesc, dydesc, dy, dxdesc, dx, reserveSpace, + reserveSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc); @@ -1610,184 +1652,192 @@ cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc) { cudnnStatus_t CUDNNWINAPI cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNForwardInferenceAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNForwardInferenceAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNForwardInferenceAlgorithmMaxCount( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNForwardInferenceAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindRNNForwardInferenceAlgorithmEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t *yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - void *workspace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNForwardInferenceAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNForwardInferenceAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, const float findIntensity, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnAlgorithmPerformance_t *perfResults, void *workspace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNForwardInferenceAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, + requestedAlgoCount, returnedAlgoCount, perfResults, workspace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNForwardTrainingAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNForwardTrainingAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNForwardTrainingAlgorithmMaxCount( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNForwardTrainingAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindRNNForwardTrainingAlgorithmEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t *yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - void *workspace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNForwardTrainingAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNForwardTrainingAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, const float findIntensity, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnAlgorithmPerformance_t *perfResults, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNForwardTrainingAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, + requestedAlgoCount, returnedAlgoCount, perfResults, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNBackwardDataAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNBackwardDataAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNBackwardDataAlgorithmMaxCount( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNBackwardDataAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindRNNBackwardDataAlgorithmEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *yDesc, - const void *y, - const cudnnTensorDescriptor_t *dyDesc, - const void *dy, - const cudnnTensorDescriptor_t dhyDesc, - const void *dhy, - const cudnnTensorDescriptor_t dcyDesc, - const void *dcy, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnTensorDescriptor_t *dxDesc, - void *dx, - const cudnnTensorDescriptor_t dhxDesc, - void *dhx, - const cudnnTensorDescriptor_t dcxDesc, - void *dcx, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - void *workspace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNBackwardDataAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNBackwardDataAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *yDesc, const void *y, + const cudnnTensorDescriptor_t *dyDesc, const void *dy, + const cudnnTensorDescriptor_t dhyDesc, const void *dhy, + const cudnnTensorDescriptor_t dcyDesc, const void *dcy, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnTensorDescriptor_t *dxDesc, void *dx, + const cudnnTensorDescriptor_t dhxDesc, void *dhx, + const cudnnTensorDescriptor_t dcxDesc, void *dcx, const float findIntensity, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnAlgorithmPerformance_t *perfResults, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNBackwardDataAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, + dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, + dx, dhxDesc, dhx, dcxDesc, dcx, findIntensity, + requestedAlgoCount, returnedAlgoCount, perfResults, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNBackwardWeightsAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNBackwardWeightsAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNBackwardWeightsAlgorithmMaxCount( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNBackwardWeightsAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindRNNBackwardWeightsAlgorithmEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t *yDesc, - const void *y, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - const void *workspace, - size_t workSpaceSizeInBytes, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - const void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, const void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, const void *, size_t, const cudnnFilterDescriptor_t, void *, const void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNBackwardWeightsAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNBackwardWeightsAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t *yDesc, const void *y, + const float findIntensity, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnAlgorithmPerformance_t *perfResults, + const void *workspace, size_t workSpaceSizeInBytes, + const cudnnFilterDescriptor_t dwDesc, void *dw, const void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, const void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, const void *, size_t, + const cudnnFilterDescriptor_t, void *, const void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNBackwardWeightsAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, + findIntensity, requestedAlgoCount, returnedAlgoCount, + perfResults, workspace, workSpaceSizeInBytes, dwDesc, dw, + reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnCreatePersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, - const int minibatch, - const cudnnDataType_t dataType, - cudnnPersistentRNNPlan_t *plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, const int, const cudnnDataType_t, cudnnPersistentRNNPlan_t *); +cudnnStatus_t CUDNNWINAPI cudnnCreatePersistentRNNPlan( + cudnnRNNDescriptor_t rnnDesc, const int minibatch, + const cudnnDataType_t dataType, cudnnPersistentRNNPlan_t *plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, const int, + const cudnnDataType_t, + cudnnPersistentRNNPlan_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreatePersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, minibatch, dataType, plan); } -cudnnStatus_t CUDNNWINAPI -cudnnSetPersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, cudnnPersistentRNNPlan_t plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnPersistentRNNPlan_t); +cudnnStatus_t CUDNNWINAPI cudnnSetPersistentRNNPlan( + cudnnRNNDescriptor_t rnnDesc, cudnnPersistentRNNPlan_t plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, + cudnnPersistentRNNPlan_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, plan); @@ -1795,289 +1845,285 @@ cudnnSetPersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, cudnnPersistentRNNPlan_t cudnnStatus_t CUDNNWINAPI cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPersistentRNNPlan_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPersistentRNNPlan_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyPersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(plan); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNDescriptor(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - const int hiddenSize, - const int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, /* Between layers, not between recurrent steps. */ - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnRNNAlgo_t algo, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, const int hiddenSize, + const int numLayers, + cudnnDropoutDescriptor_t + dropoutDesc, /* Between layers, not between recurrent steps. */ + cudnnRNNInputMode_t inputMode, cudnnDirectionMode_t direction, + cudnnRNNMode_t mode, cudnnRNNAlgo_t algo, cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, + cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, + cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, dataType); + return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, dataType); } cudnnStatus_t CUDNNWINAPI -cudnnSetRNNProjectionLayers(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - const int recProjSize, - const int outProjSize) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int); +cudnnSetRNNProjectionLayers(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + const int recProjSize, const int outProjSize) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNProjectionLayers"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, recProjSize, outProjSize); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNProjectionLayers(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - int *recProjSize, - int *outProjSize) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNProjectionLayers( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *recProjSize, + int *outProjSize) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNProjectionLayers"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, recProjSize, outProjSize); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNAlgorithmDescriptor(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, cudnnAlgorithmDescriptor_t algoDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, cudnnAlgorithmDescriptor_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNAlgorithmDescriptor( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + cudnnAlgorithmDescriptor_t algoDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, algoDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNDescriptor(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - int *hiddenSize, - int *numLayers, - cudnnDropoutDescriptor_t *dropoutDesc, - cudnnRNNInputMode_t *inputMode, - cudnnDirectionMode_t *direction, - cudnnRNNMode_t *mode, - cudnnRNNAlgo_t *algo, - cudnnDataType_t *dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, int *, int *, cudnnDropoutDescriptor_t *, cudnnRNNInputMode_t *, cudnnDirectionMode_t *, cudnnRNNMode_t *, cudnnRNNAlgo_t *, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNDescriptor( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, int *hiddenSize, + int *numLayers, cudnnDropoutDescriptor_t *dropoutDesc, + cudnnRNNInputMode_t *inputMode, cudnnDirectionMode_t *direction, + cudnnRNNMode_t *mode, cudnnRNNAlgo_t *algo, cudnnDataType_t *dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, int *, int *, + cudnnDropoutDescriptor_t *, cudnnRNNInputMode_t *, cudnnDirectionMode_t *, + cudnnRNNMode_t *, cudnnRNNAlgo_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, dataType); + return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, dataType); } cudnnStatus_t CUDNNWINAPI cudnnSetRNNMatrixMathType(cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t mType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t); + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNMatrixMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, mType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNMatrixMathType(cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t *mType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNMatrixMathType( + cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t *mType) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNMatrixMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, mType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNWorkspaceSize(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNWorkspaceSize( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNTrainingReserveSize(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNTrainingReserveSize( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNTrainingReserveSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes); } cudnnStatus_t CUDNNWINAPI -cudnnGetRNNParamsSize(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnTensorDescriptor_t xDesc, - size_t *sizeInBytes, +cudnnGetRNNParamsSize(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnTensorDescriptor_t xDesc, size_t *sizeInBytes, cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnTensorDescriptor_t, size_t *, cudnnDataType_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnTensorDescriptor_t, + size_t *, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNParamsSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, xDesc, sizeInBytes, dataType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNLinLayerMatrixParams(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int pseudoLayer, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const int linLayerID, - cudnnFilterDescriptor_t linLayerMatDesc, - void **linLayerMat) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const void *, const int, cudnnFilterDescriptor_t, void **); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerMatrixParams( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int pseudoLayer, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, const void *w, const int linLayerID, + cudnnFilterDescriptor_t linLayerMatDesc, void **linLayerMat) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, + const void *, const int, cudnnFilterDescriptor_t, void **); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNLinLayerMatrixParams"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, linLayerMatDesc, linLayerMat); + return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, + linLayerMatDesc, linLayerMat); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNLinLayerBiasParams(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int pseudoLayer, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const int linLayerID, - cudnnFilterDescriptor_t linLayerBiasDesc, - void **linLayerBias) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const void *, const int, cudnnFilterDescriptor_t, void **); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerBiasParams( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int pseudoLayer, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, const void *w, const int linLayerID, + cudnnFilterDescriptor_t linLayerBiasDesc, void **linLayerBias) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, + const void *, const int, cudnnFilterDescriptor_t, void **); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNLinLayerBiasParams"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, linLayerBiasDesc, linLayerBias); + return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, + linLayerBiasDesc, linLayerBias); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNForwardInference(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t *yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - void *workspace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardInference( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, void *workspace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardInference"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, workSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNForwardTraining(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t *yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - void *workspace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardTraining( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardTraining"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI -cudnnRNNBackwardData(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *yDesc, - const void *y, - const cudnnTensorDescriptor_t *dyDesc, - const void *dy, - const cudnnTensorDescriptor_t dhyDesc, - const void *dhy, - const cudnnTensorDescriptor_t dcyDesc, - const void *dcy, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnTensorDescriptor_t *dxDesc, - void *dx, - const cudnnTensorDescriptor_t dhxDesc, - void *dhx, - const cudnnTensorDescriptor_t dcxDesc, - void *dcx, - void *workspace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnRNNBackwardData(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *yDesc, + const void *y, const cudnnTensorDescriptor_t *dyDesc, + const void *dy, const cudnnTensorDescriptor_t dhyDesc, + const void *dhy, const cudnnTensorDescriptor_t dcyDesc, + const void *dcy, const cudnnFilterDescriptor_t wDesc, + const void *w, const cudnnTensorDescriptor_t hxDesc, + const void *hx, const cudnnTensorDescriptor_t cxDesc, + const void *cx, const cudnnTensorDescriptor_t *dxDesc, + void *dx, const cudnnTensorDescriptor_t dhxDesc, void *dhx, + const cudnnTensorDescriptor_t dcxDesc, void *dcx, + void *workspace, size_t workSpaceSizeInBytes, + void *reserveSpace, size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardData"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, + dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, + dx, dhxDesc, dhx, dcxDesc, dcx, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNBackwardWeights(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t *yDesc, - const void *y, - const void *workspace, - size_t workSpaceSizeInBytes, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - const void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, const void *, const void *, size_t, const cudnnFilterDescriptor_t, void *, const void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeights( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t *yDesc, const void *y, const void *workspace, + size_t workSpaceSizeInBytes, const cudnnFilterDescriptor_t dwDesc, void *dw, + const void *reserveSpace, size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, const void *, const void *, size_t, + const cudnnFilterDescriptor_t, void *, const void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardWeights"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, + workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, + reserveSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnCreateCTCLossDescriptor(cudnnCTCLossDescriptor_t *ctcLossDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetCTCLossDescriptor( + cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc, compType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetCTCLossDescriptor( + cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc, compType); @@ -2085,82 +2131,102 @@ cudnnGetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t cudnnStatus_t CUDNNWINAPI cudnnDestroyCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnCTCLoss( +cudnnStatus_t CUDNNWINAPI cudnnCTCLoss( cudnnHandle_t handle, const cudnnTensorDescriptor_t - probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the - mini batch size, A is the alphabet size) */ - const void *probs, /* probabilities after softmax, in GPU memory */ - const int *labels, /* labels, in CPU memory */ - const int *labelLengths, /* the length of each label, in CPU memory */ - const int *inputLengths, /* the lengths of timing steps in each batch, in CPU memory */ - void *costs, /* the returned costs of CTC, in GPU memory */ - const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */ - const void *gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */ + probsDesc, /* Tensor descriptor for probabilities, the dimensions are + T,N,A (T is the timing steps, N is the + mini batch size, A is the alphabet size) */ + const void *probs, /* probabilities after softmax, in GPU memory */ + const int *labels, /* labels, in CPU memory */ + const int *labelLengths, /* the length of each label, in CPU memory */ + const int *inputLengths, /* the lengths of timing steps in each batch, in + CPU memory */ + void *costs, /* the returned costs of CTC, in GPU memory */ + const cudnnTensorDescriptor_t + gradientsDesc, /* Tensor descriptor for gradients, the dimensions are + T,N,A */ + const void *gradients, /* the returned CTC gradients, in GPU memory, to + compute costs only, set it to NULL */ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ cudnnCTCLossDescriptor_t ctcLossDesc, - void *workspace, /* pointer to the workspace, in GPU memory */ + void *workspace, /* pointer to the workspace, in GPU memory */ size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const int *, const int *, const int *, void *, const cudnnTensorDescriptor_t, const void *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, void *, size_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const int *, + const int *, const int *, void *, const cudnnTensorDescriptor_t, + const void *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCTCLoss"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, probsDesc, probs, labels, labelLengths, inputLengths, costs, gradientsDesc, gradients, algo, ctcLossDesc, workspace, workSpaceSizeInBytes); + return func_ptr(handle, probsDesc, probs, labels, labelLengths, inputLengths, + costs, gradientsDesc, gradients, algo, ctcLossDesc, workspace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetCTCLossWorkspaceSize( +cudnnStatus_t CUDNNWINAPI cudnnGetCTCLossWorkspaceSize( cudnnHandle_t handle, - const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the - timing steps, N is the mini batch size, A is the alphabet size) */ - const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the - dimensions are T,N,A. To compute costs - only, set it to NULL */ - const int *labels, /* labels, in CPU memory */ - const int *labelLengths, /* the length of each label, in CPU memory */ - const int *inputLengths, /* the lengths of timing steps in each batch, in CPU memory */ - cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ - cudnnCTCLossDescriptor_t ctcLossDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const int *, const int *, const int *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, size_t *); + const cudnnTensorDescriptor_t + probsDesc, /* Tensor descriptor for probabilities, the dimensions are + T,N,A (T is the + timing steps, N is the mini batch size, A is the alphabet + size) */ + const cudnnTensorDescriptor_t + gradientsDesc, /* Tensor descriptor for gradients, the + dimensions are T,N,A. To compute costs + only, set it to NULL */ + const int *labels, /* labels, in CPU memory */ + const int *labelLengths, /* the length of each label, in CPU memory */ + const int *inputLengths, /* the lengths of timing steps in each batch, in + CPU memory */ + cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ + cudnnCTCLossDescriptor_t ctcLossDesc, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const int *, const int *, const int *, + cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCTCLossWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, probsDesc, gradientsDesc, labels, labelLengths, inputLengths, algo, ctcLossDesc, sizeInBytes); + return func_ptr(handle, probsDesc, gradientsDesc, labels, labelLengths, + inputLengths, algo, ctcLossDesc, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnCreateAlgorithmDescriptor(cudnnAlgorithmDescriptor_t *algoDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetAlgorithmDescriptor(cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t algorithm) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t, cudnnAlgorithm_t); +cudnnStatus_t CUDNNWINAPI cudnnSetAlgorithmDescriptor( + cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t algorithm) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t, + cudnnAlgorithm_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc, algorithm); } -cudnnStatus_t CUDNNWINAPI -cudnnGetAlgorithmDescriptor(const cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t *algorithm) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, cudnnAlgorithm_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetAlgorithmDescriptor( + const cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t *algorithm) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, + cudnnAlgorithm_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc, algorithm); } -cudnnStatus_t CUDNNWINAPI -cudnnCopyAlgorithmDescriptor(const cudnnAlgorithmDescriptor_t src, cudnnAlgorithmDescriptor_t dest) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, cudnnAlgorithmDescriptor_t); +cudnnStatus_t CUDNNWINAPI cudnnCopyAlgorithmDescriptor( + const cudnnAlgorithmDescriptor_t src, cudnnAlgorithmDescriptor_t dest) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, + cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCopyAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(src, dest); @@ -2168,135 +2234,141 @@ cudnnCopyAlgorithmDescriptor(const cudnnAlgorithmDescriptor_t src, cudnnAlgorith cudnnStatus_t CUDNNWINAPI cudnnDestroyAlgorithmDescriptor(cudnnAlgorithmDescriptor_t algoDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnCreateAlgorithmPerformance(cudnnAlgorithmPerformance_t *algoPerf, int numberToCreate) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); +cudnnStatus_t CUDNNWINAPI cudnnCreateAlgorithmPerformance( + cudnnAlgorithmPerformance_t *algoPerf, int numberToCreate) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, numberToCreate); } -cudnnStatus_t CUDNNWINAPI -cudnnSetAlgorithmPerformance(cudnnAlgorithmPerformance_t algoPerf, - cudnnAlgorithmDescriptor_t algoDesc, - cudnnStatus_t status, - float time, - size_t memory) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmPerformance_t, cudnnAlgorithmDescriptor_t, cudnnStatus_t, float, size_t); +cudnnStatus_t CUDNNWINAPI cudnnSetAlgorithmPerformance( + cudnnAlgorithmPerformance_t algoPerf, cudnnAlgorithmDescriptor_t algoDesc, + cudnnStatus_t status, float time, size_t memory) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmPerformance_t, + cudnnAlgorithmDescriptor_t, + cudnnStatus_t, float, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, algoDesc, status, time, memory); } -cudnnStatus_t CUDNNWINAPI -cudnnGetAlgorithmPerformance(const cudnnAlgorithmPerformance_t algoPerf, - cudnnAlgorithmDescriptor_t *algoDesc, - cudnnStatus_t *status, - float *time, - size_t *memory) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnAlgorithmPerformance_t, cudnnAlgorithmDescriptor_t *, cudnnStatus_t *, float *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetAlgorithmPerformance( + const cudnnAlgorithmPerformance_t algoPerf, + cudnnAlgorithmDescriptor_t *algoDesc, cudnnStatus_t *status, float *time, + size_t *memory) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnAlgorithmPerformance_t, cudnnAlgorithmDescriptor_t *, + cudnnStatus_t *, float *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, algoDesc, status, time, memory); } -cudnnStatus_t CUDNNWINAPI -cudnnDestroyAlgorithmPerformance(cudnnAlgorithmPerformance_t *algoPerf, int numberToDestroy) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyAlgorithmPerformance"); +cudnnStatus_t CUDNNWINAPI cudnnDestroyAlgorithmPerformance( + cudnnAlgorithmPerformance_t *algoPerf, int numberToDestroy) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, numberToDestroy); } -cudnnStatus_t CUDNNWINAPI -cudnnGetAlgorithmSpaceSize(cudnnHandle_t handle, cudnnAlgorithmDescriptor_t algoDesc, size_t *algoSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnAlgorithmDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetAlgorithmSpaceSize( + cudnnHandle_t handle, cudnnAlgorithmDescriptor_t algoDesc, + size_t *algoSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnAlgorithmDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAlgorithmSpaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algoDesc, algoSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI -cudnnSaveAlgorithm(cudnnHandle_t handle, - cudnnAlgorithmDescriptor_t algoDesc, - void *algoSpace, - size_t algoSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnAlgorithmDescriptor_t, void *, size_t); +cudnnSaveAlgorithm(cudnnHandle_t handle, cudnnAlgorithmDescriptor_t algoDesc, + void *algoSpace, size_t algoSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnAlgorithmDescriptor_t, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSaveAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algoDesc, algoSpace, algoSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRestoreAlgorithm(cudnnHandle_t handle, - void *algoSpace, - size_t algoSpaceSizeInBytes, - cudnnAlgorithmDescriptor_t algoDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, void *, size_t, cudnnAlgorithmDescriptor_t); +cudnnStatus_t CUDNNWINAPI cudnnRestoreAlgorithm( + cudnnHandle_t handle, void *algoSpace, size_t algoSpaceSizeInBytes, + cudnnAlgorithmDescriptor_t algoDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, void *, size_t, + cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRestoreAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algoSpace, algoSpaceSizeInBytes, algoDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNSetClip(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - cudnnRNNClipMode_t clipMode, - cudnnNanPropagation_t clipNanOpt, - double lclip, - double rclip) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, cudnnRNNClipMode_t, cudnnNanPropagation_t, double, double); +cudnnStatus_t CUDNNWINAPI cudnnRNNSetClip(cudnnHandle_t handle, + cudnnRNNDescriptor_t rnnDesc, + cudnnRNNClipMode_t clipMode, + cudnnNanPropagation_t clipNanOpt, + double lclip, double rclip) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnRNNClipMode_t, + cudnnNanPropagation_t, double, double); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNSetClip"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, clipMode, clipNanOpt, lclip, rclip); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNGetClip(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - cudnnRNNClipMode_t *clipMode, - cudnnNanPropagation_t *clipNanOpt, - double *lclip, - double *rclip) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, cudnnRNNClipMode_t *, cudnnNanPropagation_t *, double *, double *); +cudnnStatus_t CUDNNWINAPI cudnnRNNGetClip(cudnnHandle_t handle, + cudnnRNNDescriptor_t rnnDesc, + cudnnRNNClipMode_t *clipMode, + cudnnNanPropagation_t *clipNanOpt, + double *lclip, double *rclip) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnRNNClipMode_t *, + cudnnNanPropagation_t *, double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNGetClip"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, clipMode, clipNanOpt, lclip, rclip); } -cudnnStatus_t CUDNNWINAPI -cudnnSetCallback(unsigned mask, void *udata, cudnnCallback_t fptr) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(unsigned int, void *, cudnnCallback_t); +cudnnStatus_t CUDNNWINAPI cudnnSetCallback(unsigned mask, void *udata, + cudnnCallback_t fptr) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(unsigned int, void *, cudnnCallback_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetCallback"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(mask, udata, fptr); } -cudnnStatus_t CUDNNWINAPI -cudnnGetCallback(unsigned *mask, void **udata, cudnnCallback_t *fptr) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(unsigned int *, void **, cudnnCallback_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetCallback(unsigned *mask, void **udata, + cudnnCallback_t *fptr) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(unsigned int *, void **, cudnnCallback_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCallback"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(mask, udata, fptr); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNPaddingMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t paddingMode) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNPaddingMode_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNPaddingMode( + cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t paddingMode) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNPaddingMode_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNPaddingMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, paddingMode); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNPaddingMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t *paddingMode) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNPaddingMode_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNPaddingMode( + cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t *paddingMode) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, + cudnnRNNPaddingMode_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNPaddingMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, paddingMode); @@ -2304,7 +2376,7 @@ cudnnGetRNNPaddingMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t *padd cudnnStatus_t CUDNNWINAPI cudnnCreateRNNDataDescriptor(cudnnRNNDataDescriptor_t *RNNDataDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDataDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateRNNDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(RNNDataDesc); @@ -2312,199 +2384,202 @@ cudnnCreateRNNDataDescriptor(cudnnRNNDataDescriptor_t *RNNDataDesc) { cudnnStatus_t CUDNNWINAPI cudnnDestroyRNNDataDescriptor(cudnnRNNDataDescriptor_t RNNDataDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDataDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyRNNDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(RNNDataDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNDataDescriptor(cudnnRNNDataDescriptor_t RNNDataDesc, - cudnnDataType_t dataType, - cudnnRNNDataLayout_t layout, - int maxSeqLength, - int batchSize, - int vectorSize, - const int seqLengthArray[], /* length of each sequence in the batch */ - void *paddingFill) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t, cudnnDataType_t, cudnnRNNDataLayout_t, int, int, int, const int [], void *); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDataDescriptor( + cudnnRNNDataDescriptor_t RNNDataDesc, cudnnDataType_t dataType, + cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, + int vectorSize, + const int seqLengthArray[], /* length of each sequence in the batch */ + void *paddingFill) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnRNNDataDescriptor_t, cudnnDataType_t, cudnnRNNDataLayout_t, int, int, + int, const int[], void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(RNNDataDesc, dataType, layout, maxSeqLength, batchSize, vectorSize, seqLengthArray, paddingFill); + return func_ptr(RNNDataDesc, dataType, layout, maxSeqLength, batchSize, + vectorSize, seqLengthArray, paddingFill); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNDataDescriptor(cudnnRNNDataDescriptor_t RNNDataDesc, - cudnnDataType_t *dataType, - cudnnRNNDataLayout_t *layout, - int *maxSeqLength, - int *batchSize, - int *vectorSize, - int arrayLengthRequested, - int seqLengthArray[], - void *paddingFill) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t, cudnnDataType_t *, cudnnRNNDataLayout_t *, int *, int *, int *, int, int [], void *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNDataDescriptor( + cudnnRNNDataDescriptor_t RNNDataDesc, cudnnDataType_t *dataType, + cudnnRNNDataLayout_t *layout, int *maxSeqLength, int *batchSize, + int *vectorSize, int arrayLengthRequested, int seqLengthArray[], + void *paddingFill) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnRNNDataDescriptor_t, cudnnDataType_t *, cudnnRNNDataLayout_t *, + int *, int *, int *, int, int[], void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(RNNDataDesc, dataType, layout, maxSeqLength, batchSize, vectorSize, arrayLengthRequested, seqLengthArray, paddingFill); + return func_ptr(RNNDataDesc, dataType, layout, maxSeqLength, batchSize, + vectorSize, arrayLengthRequested, seqLengthArray, + paddingFill); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNForwardTrainingEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnRNNDataDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnRNNDataDescriptor_t yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */ - const void *keys, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */ - void *cAttn, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */ - void *iAttn, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */ - void *queries, /* reserved, should pass NULL */ - void *workSpace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardTrainingEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDataDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnRNNDataDescriptor_t yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, + const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */ + const void *keys, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */ + void *cAttn, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */ + void *iAttn, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */ + void *queries, /* reserved, should pass NULL */ + void *workSpace, size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, + const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, + const cudnnRNNDataDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, + void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardTrainingEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, kDesc, keys, cDesc, cAttn, iDesc, iAttn, qDesc, queries, workSpace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, + yDesc, y, hyDesc, hy, cyDesc, cy, kDesc, keys, cDesc, cAttn, + iDesc, iAttn, qDesc, queries, workSpace, workSpaceSizeInBytes, + reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNForwardInferenceEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnRNNDataDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnRNNDataDescriptor_t yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */ - const void *keys, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */ - void *cAttn, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */ - void *iAttn, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */ - void *queries, /* reserved, should pass NULL */ - void *workSpace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardInferenceEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDataDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnRNNDataDescriptor_t yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, + const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */ + const void *keys, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */ + void *cAttn, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */ + void *iAttn, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */ + void *queries, /* reserved, should pass NULL */ + void *workSpace, size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, + const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, + const cudnnRNNDataDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, + void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardInferenceEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, kDesc, keys, cDesc, cAttn, iDesc, iAttn, qDesc, queries, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, + yDesc, y, hyDesc, hy, cyDesc, cy, kDesc, keys, cDesc, cAttn, + iDesc, iAttn, qDesc, queries, workSpace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNBackwardDataEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnRNNDataDescriptor_t yDesc, - const void *y, - const cudnnRNNDataDescriptor_t dyDesc, - const void *dy, - const cudnnRNNDataDescriptor_t dcDesc, /* reserved, should pass NULL */ - const void *dcAttn, /* reserved, should pass NULL */ - const cudnnTensorDescriptor_t dhyDesc, - const void *dhy, - const cudnnTensorDescriptor_t dcyDesc, - const void *dcy, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnRNNDataDescriptor_t dxDesc, - void *dx, - const cudnnTensorDescriptor_t dhxDesc, - void *dhx, - const cudnnTensorDescriptor_t dcxDesc, - void *dcx, - const cudnnRNNDataDescriptor_t dkDesc, /* reserved, should pass NULL */ - void *dkeys, /* reserved, should pass NULL */ - void *workSpace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardDataEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDataDescriptor_t yDesc, const void *y, + const cudnnRNNDataDescriptor_t dyDesc, const void *dy, + const cudnnRNNDataDescriptor_t dcDesc, /* reserved, should pass NULL */ + const void *dcAttn, /* reserved, should pass NULL */ + const cudnnTensorDescriptor_t dhyDesc, const void *dhy, + const cudnnTensorDescriptor_t dcyDesc, const void *dcy, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnRNNDataDescriptor_t dxDesc, void *dx, + const cudnnTensorDescriptor_t dhxDesc, void *dhx, + const cudnnTensorDescriptor_t dcxDesc, void *dcx, + const cudnnRNNDataDescriptor_t dkDesc, /* reserved, should pass NULL */ + void *dkeys, /* reserved, should pass NULL */ + void *workSpace, size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, + const void *, const cudnnRNNDataDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, + const cudnnRNNDataDescriptor_t, void *, void *, size_t, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardDataEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, yDesc, y, dyDesc, dy, dcDesc, dcAttn, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, dkDesc, dkeys, workSpace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, yDesc, y, dyDesc, dy, dcDesc, dcAttn, + dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, + dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, dkDesc, dkeys, + workSpace, workSpaceSizeInBytes, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNBackwardWeightsEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnRNNDataDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnRNNDataDescriptor_t yDesc, - const void *y, - void *workSpace, - size_t workSpaceSizeInBytes, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, const void *, void *, size_t, const cudnnFilterDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeightsEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDataDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnRNNDataDescriptor_t yDesc, const void *y, void *workSpace, + size_t workSpaceSizeInBytes, const cudnnFilterDescriptor_t dwDesc, void *dw, + void *reserveSpace, size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, + const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, const void *, void *, size_t, + const cudnnFilterDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardWeightsEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, yDesc, y, workSpace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, yDesc, y, workSpace, + workSpaceSizeInBytes, dwDesc, dw, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNDescriptor_v6(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - const int hiddenSize, - const int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnRNNAlgo_t algo, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v6( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, const int hiddenSize, + const int numLayers, cudnnDropoutDescriptor_t dropoutDesc, + cudnnRNNInputMode_t inputMode, cudnnDirectionMode_t direction, + cudnnRNNMode_t mode, cudnnRNNAlgo_t algo, cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, + cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, + cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor_v6"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, dataType); + return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, dataType); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNDescriptor_v5(cudnnRNNDescriptor_t rnnDesc, - int hiddenSize, - int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, int, int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v5( + cudnnRNNDescriptor_t rnnDesc, int hiddenSize, int numLayers, + cudnnDropoutDescriptor_t dropoutDesc, cudnnRNNInputMode_t inputMode, + cudnnDirectionMode_t direction, cudnnRNNMode_t mode, + cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnRNNDescriptor_t, int, int, cudnnDropoutDescriptor_t, + cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, + cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor_v5"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, dataType); + return func_ptr(rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, + direction, mode, dataType); } } // extern "C" diff --git a/tensorflow/stream_executor/cuda/cudnn_7_4.inc b/tensorflow/stream_executor/cuda/cudnn_7_4.inc index bd9f49f9780..883c8ba8812 100644 --- a/tensorflow/stream_executor/cuda/cudnn_7_4.inc +++ b/tensorflow/stream_executor/cuda/cudnn_7_4.inc @@ -2,73 +2,71 @@ extern "C" { -size_t CUDNNWINAPI -cudnnGetVersion(void) { - using FuncPtr = size_t (CUDNNWINAPI *)(); +size_t CUDNNWINAPI cudnnGetVersion(void) { + using FuncPtr = size_t(CUDNNWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetVersion"); if (!func_ptr) return 0; return func_ptr(); } -size_t CUDNNWINAPI -cudnnGetCudartVersion(void) { - using FuncPtr = size_t (CUDNNWINAPI *)(); +size_t CUDNNWINAPI cudnnGetCudartVersion(void) { + using FuncPtr = size_t(CUDNNWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCudartVersion"); if (!func_ptr) return 0; return func_ptr(); } -const char *CUDNNWINAPI -cudnnGetErrorString(cudnnStatus_t status) { - using FuncPtr = const char * (CUDNNWINAPI *)(cudnnStatus_t); +const char *CUDNNWINAPI cudnnGetErrorString(cudnnStatus_t status) { + using FuncPtr = const char *(CUDNNWINAPI *)(cudnnStatus_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetErrorString"); if (!func_ptr) return "cudnnGetErrorString symbol not found."; return func_ptr(status); } -cudnnStatus_t CUDNNWINAPI -cudnnQueryRuntimeError(cudnnHandle_t handle, cudnnStatus_t *rstatus, cudnnErrQueryMode_t mode, cudnnRuntimeTag_t *tag) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnStatus_t *, cudnnErrQueryMode_t, cudnnRuntimeTag_t *); +cudnnStatus_t CUDNNWINAPI cudnnQueryRuntimeError(cudnnHandle_t handle, + cudnnStatus_t *rstatus, + cudnnErrQueryMode_t mode, + cudnnRuntimeTag_t *tag) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnStatus_t *, cudnnErrQueryMode_t, cudnnRuntimeTag_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnQueryRuntimeError"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rstatus, mode, tag); } -cudnnStatus_t CUDNNWINAPI -cudnnGetProperty(libraryPropertyType type, int *value) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(libraryPropertyType, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetProperty(libraryPropertyType type, + int *value) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(libraryPropertyType, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetProperty"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(type, value); } -cudnnStatus_t CUDNNWINAPI -cudnnCreate(cudnnHandle_t *handle) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t *); +cudnnStatus_t CUDNNWINAPI cudnnCreate(cudnnHandle_t *handle) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreate"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cudnnStatus_t CUDNNWINAPI -cudnnDestroy(cudnnHandle_t handle) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t); +cudnnStatus_t CUDNNWINAPI cudnnDestroy(cudnnHandle_t handle) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroy"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cudnnStatus_t CUDNNWINAPI -cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t); +cudnnStatus_t CUDNNWINAPI cudnnSetStream(cudnnHandle_t handle, + cudaStream_t streamId) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetStream"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); } -cudnnStatus_t CUDNNWINAPI -cudnnGetStream(cudnnHandle_t handle, cudaStream_t *streamId) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetStream(cudnnHandle_t handle, + cudaStream_t *streamId) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetStream"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); @@ -76,100 +74,97 @@ cudnnGetStream(cudnnHandle_t handle, cudaStream_t *streamId) { cudnnStatus_t CUDNNWINAPI cudnnCreateTensorDescriptor(cudnnTensorDescriptor_t *tensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensor4dDescriptor(cudnnTensorDescriptor_t tensorDesc, - cudnnTensorFormat_t format, - cudnnDataType_t dataType, /* image data type */ - int n, /* number of inputs (batch size) */ - int c, /* number of input feature maps */ - int h, /* height of input section */ - int w) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, cudnnDataType_t, int, int, int, int); +cudnnStatus_t CUDNNWINAPI cudnnSetTensor4dDescriptor( + cudnnTensorDescriptor_t tensorDesc, cudnnTensorFormat_t format, + cudnnDataType_t dataType, /* image data type */ + int n, /* number of inputs (batch size) */ + int c, /* number of input feature maps */ + int h, /* height of input section */ + int w) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, + cudnnDataType_t, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, format, dataType, n, c, h, w); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensor4dDescriptorEx(cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t dataType, /* image data type */ - int n, /* number of inputs (batch size) */ - int c, /* number of input feature maps */ - int h, /* height of input section */ - int w, /* width of input section */ - int nStride, - int cStride, - int hStride, - int wStride) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, int, int, int, int, int, int, int, int); +cudnnStatus_t CUDNNWINAPI cudnnSetTensor4dDescriptorEx( + cudnnTensorDescriptor_t tensorDesc, + cudnnDataType_t dataType, /* image data type */ + int n, /* number of inputs (batch size) */ + int c, /* number of input feature maps */ + int h, /* height of input section */ + int w, /* width of input section */ + int nStride, int cStride, int hStride, int wStride) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, + int, int, int, int, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor4dDescriptorEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, wStride); + return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, + wStride); } -cudnnStatus_t CUDNNWINAPI -cudnnGetTensor4dDescriptor(const cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t *dataType, /* image data type */ - int *n, /* number of inputs (batch size) */ - int *c, /* number of input feature maps */ - int *h, /* height of input section */ - int *w, /* width of input section */ - int *nStride, - int *cStride, - int *hStride, - int *wStride) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, cudnnDataType_t *, int *, int *, int *, int *, int *, int *, int *, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetTensor4dDescriptor( + const cudnnTensorDescriptor_t tensorDesc, + cudnnDataType_t *dataType, /* image data type */ + int *n, /* number of inputs (batch size) */ + int *c, /* number of input feature maps */ + int *h, /* height of input section */ + int *w, /* width of input section */ + int *nStride, int *cStride, int *hStride, int *wStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnTensorDescriptor_t, cudnnDataType_t *, int *, int *, int *, + int *, int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensor4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, wStride); + return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, + wStride); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensorNdDescriptor(cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t dataType, - int nbDims, - const int dimA[], - const int strideA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, int, const int [], const int []); +cudnnStatus_t CUDNNWINAPI cudnnSetTensorNdDescriptor( + cudnnTensorDescriptor_t tensorDesc, cudnnDataType_t dataType, int nbDims, + const int dimA[], const int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnTensorDescriptor_t, cudnnDataType_t, int, const int[], const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, dataType, nbDims, dimA, strideA); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensorNdDescriptorEx(cudnnTensorDescriptor_t tensorDesc, - cudnnTensorFormat_t format, - cudnnDataType_t dataType, - int nbDims, - const int dimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, cudnnDataType_t, int, const int []); +cudnnStatus_t CUDNNWINAPI cudnnSetTensorNdDescriptorEx( + cudnnTensorDescriptor_t tensorDesc, cudnnTensorFormat_t format, + cudnnDataType_t dataType, int nbDims, const int dimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, + cudnnDataType_t, int, const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorNdDescriptorEx"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, format, dataType, nbDims, dimA); } -cudnnStatus_t CUDNNWINAPI -cudnnGetTensorNdDescriptor(const cudnnTensorDescriptor_t tensorDesc, - int nbDimsRequested, - cudnnDataType_t *dataType, - int *nbDims, - int dimA[], - int strideA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, int, cudnnDataType_t *, int *, int [], int []); +cudnnStatus_t CUDNNWINAPI cudnnGetTensorNdDescriptor( + const cudnnTensorDescriptor_t tensorDesc, int nbDimsRequested, + cudnnDataType_t *dataType, int *nbDims, int dimA[], int strideA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnTensorDescriptor_t, int, + cudnnDataType_t *, int *, int[], int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, nbDimsRequested, dataType, nbDims, dimA, strideA); } -cudnnStatus_t CUDNNWINAPI -cudnnGetTensorSizeInBytes(const cudnnTensorDescriptor_t tensorDesc, size_t *size) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetTensorSizeInBytes( + const cudnnTensorDescriptor_t tensorDesc, size_t *size) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorSizeInBytes"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, size); @@ -177,35 +172,33 @@ cudnnGetTensorSizeInBytes(const cudnnTensorDescriptor_t tensorDesc, size_t *size cudnnStatus_t CUDNNWINAPI cudnnDestroyTensorDescriptor(cudnnTensorDescriptor_t tensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnTransformTensor(cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnTransformTensor( + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnTransformTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnAddTensor(cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnAddTensor(cudnnHandle_t handle, + const void *alpha, + const cudnnTensorDescriptor_t aDesc, + const void *A, const void *beta, + const cudnnTensorDescriptor_t cDesc, + void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnAddTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, aDesc, A, beta, cDesc, C); @@ -213,29 +206,29 @@ cudnnAddTensor(cudnnHandle_t handle, cudnnStatus_t CUDNNWINAPI cudnnCreateOpTensorDescriptor(cudnnOpTensorDescriptor_t *opTensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc, - cudnnOpTensorOp_t opTensorOp, - cudnnDataType_t opTensorCompType, - cudnnNanPropagation_t opTensorNanOpt) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t, cudnnDataType_t, cudnnNanPropagation_t); +cudnnStatus_t CUDNNWINAPI cudnnSetOpTensorDescriptor( + cudnnOpTensorDescriptor_t opTensorDesc, cudnnOpTensorOp_t opTensorOp, + cudnnDataType_t opTensorCompType, cudnnNanPropagation_t opTensorNanOpt) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t, + cudnnDataType_t, cudnnNanPropagation_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc, opTensorOp, opTensorCompType, opTensorNanOpt); } -cudnnStatus_t CUDNNWINAPI -cudnnGetOpTensorDescriptor(const cudnnOpTensorDescriptor_t opTensorDesc, - cudnnOpTensorOp_t *opTensorOp, - cudnnDataType_t *opTensorCompType, - cudnnNanPropagation_t *opTensorNanOpt) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t *, cudnnDataType_t *, cudnnNanPropagation_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetOpTensorDescriptor( + const cudnnOpTensorDescriptor_t opTensorDesc, cudnnOpTensorOp_t *opTensorOp, + cudnnDataType_t *opTensorCompType, cudnnNanPropagation_t *opTensorNanOpt) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t *, cudnnDataType_t *, + cudnnNanPropagation_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc, opTensorOp, opTensorCompType, opTensorNanOpt); @@ -243,126 +236,136 @@ cudnnGetOpTensorDescriptor(const cudnnOpTensorDescriptor_t opTensorDesc, cudnnStatus_t CUDNNWINAPI cudnnDestroyOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnOpTensor(cudnnHandle_t handle, - const cudnnOpTensorDescriptor_t opTensorDesc, - const void *alpha1, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *alpha2, - const cudnnTensorDescriptor_t bDesc, - const void *B, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnOpTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnOpTensor( + cudnnHandle_t handle, const cudnnOpTensorDescriptor_t opTensorDesc, + const void *alpha1, const cudnnTensorDescriptor_t aDesc, const void *A, + const void *alpha2, const cudnnTensorDescriptor_t bDesc, const void *B, + const void *beta, const cudnnTensorDescriptor_t cDesc, void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnOpTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnOpTensor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, opTensorDesc, alpha1, aDesc, A, alpha2, bDesc, B, beta, cDesc, C); + return func_ptr(handle, opTensorDesc, alpha1, aDesc, A, alpha2, bDesc, B, + beta, cDesc, C); } -cudnnStatus_t CUDNNWINAPI -cudnnCreateReduceTensorDescriptor(cudnnReduceTensorDescriptor_t *reduceTensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateReduceTensorDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnCreateReduceTensorDescriptor( + cudnnReduceTensorDescriptor_t *reduceTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(reduceTensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc, - cudnnReduceTensorOp_t reduceTensorOp, - cudnnDataType_t reduceTensorCompType, - cudnnNanPropagation_t reduceTensorNanOpt, - cudnnReduceTensorIndices_t reduceTensorIndices, - cudnnIndicesType_t reduceTensorIndicesType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t, cudnnDataType_t, cudnnNanPropagation_t, cudnnReduceTensorIndices_t, cudnnIndicesType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetReduceTensorDescriptor( + cudnnReduceTensorDescriptor_t reduceTensorDesc, + cudnnReduceTensorOp_t reduceTensorOp, cudnnDataType_t reduceTensorCompType, + cudnnNanPropagation_t reduceTensorNanOpt, + cudnnReduceTensorIndices_t reduceTensorIndices, + cudnnIndicesType_t reduceTensorIndicesType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t, cudnnDataType_t, + cudnnNanPropagation_t, cudnnReduceTensorIndices_t, cudnnIndicesType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, reduceTensorNanOpt, reduceTensorIndices, reduceTensorIndicesType); + return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, + reduceTensorNanOpt, reduceTensorIndices, + reduceTensorIndicesType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetReduceTensorDescriptor(const cudnnReduceTensorDescriptor_t reduceTensorDesc, - cudnnReduceTensorOp_t *reduceTensorOp, - cudnnDataType_t *reduceTensorCompType, - cudnnNanPropagation_t *reduceTensorNanOpt, - cudnnReduceTensorIndices_t *reduceTensorIndices, - cudnnIndicesType_t *reduceTensorIndicesType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t *, cudnnDataType_t *, cudnnNanPropagation_t *, cudnnReduceTensorIndices_t *, cudnnIndicesType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetReduceTensorDescriptor( + const cudnnReduceTensorDescriptor_t reduceTensorDesc, + cudnnReduceTensorOp_t *reduceTensorOp, + cudnnDataType_t *reduceTensorCompType, + cudnnNanPropagation_t *reduceTensorNanOpt, + cudnnReduceTensorIndices_t *reduceTensorIndices, + cudnnIndicesType_t *reduceTensorIndicesType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t *, + cudnnDataType_t *, cudnnNanPropagation_t *, cudnnReduceTensorIndices_t *, + cudnnIndicesType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, reduceTensorNanOpt, reduceTensorIndices, reduceTensorIndicesType); + return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, + reduceTensorNanOpt, reduceTensorIndices, + reduceTensorIndicesType); } -cudnnStatus_t CUDNNWINAPI -cudnnDestroyReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyReduceTensorDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnDestroyReduceTensorDescriptor( + cudnnReduceTensorDescriptor_t reduceTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(reduceTensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnGetReductionIndicesSize(cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - const cudnnTensorDescriptor_t aDesc, - const cudnnTensorDescriptor_t cDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetReductionIndicesSize( + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + const cudnnTensorDescriptor_t aDesc, const cudnnTensorDescriptor_t cDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReductionIndicesSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, reduceTensorDesc, aDesc, cDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetReductionWorkspaceSize(cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - const cudnnTensorDescriptor_t aDesc, - const cudnnTensorDescriptor_t cDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetReductionWorkspaceSize( + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + const cudnnTensorDescriptor_t aDesc, const cudnnTensorDescriptor_t cDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReductionWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, reduceTensorDesc, aDesc, cDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnReduceTensor(cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - void *indices, - size_t indicesSizeInBytes, - void *workspace, - size_t workspaceSizeInBytes, - const void *alpha, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, void *, size_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnReduceTensor( + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + void *indices, size_t indicesSizeInBytes, void *workspace, + size_t workspaceSizeInBytes, const void *alpha, + const cudnnTensorDescriptor_t aDesc, const void *A, const void *beta, + const cudnnTensorDescriptor_t cDesc, void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, void *, size_t, + void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnReduceTensor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, reduceTensorDesc, indices, indicesSizeInBytes, workspace, workspaceSizeInBytes, alpha, aDesc, A, beta, cDesc, C); + return func_ptr(handle, reduceTensorDesc, indices, indicesSizeInBytes, + workspace, workspaceSizeInBytes, alpha, aDesc, A, beta, cDesc, + C); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *valuePtr) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); +cudnnStatus_t CUDNNWINAPI cudnnSetTensor(cudnnHandle_t handle, + const cudnnTensorDescriptor_t yDesc, + void *y, const void *valuePtr) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, yDesc, y, valuePtr); } -cudnnStatus_t CUDNNWINAPI -cudnnScaleTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *alpha) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); +cudnnStatus_t CUDNNWINAPI cudnnScaleTensor(cudnnHandle_t handle, + const cudnnTensorDescriptor_t yDesc, + void *y, const void *alpha) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnScaleTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, yDesc, y, alpha); @@ -370,68 +373,70 @@ cudnnScaleTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void cudnnStatus_t CUDNNWINAPI cudnnCreateFilterDescriptor(cudnnFilterDescriptor_t *filterDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateFilterDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetFilter4dDescriptor(cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, /* image data type */ - cudnnTensorFormat_t format, - int k, /* number of output feature maps */ - int c, /* number of input feature maps */ - int h, /* height of each input filter */ - int w) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, cudnnTensorFormat_t, int, int, int, int); +cudnnStatus_t CUDNNWINAPI cudnnSetFilter4dDescriptor( + cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t dataType, /* image data type */ + cudnnTensorFormat_t format, int k, /* number of output feature maps */ + int c, /* number of input feature maps */ + int h, /* height of each input filter */ + int w) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, + cudnnTensorFormat_t, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFilter4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, k, c, h, w); } -cudnnStatus_t CUDNNWINAPI -cudnnGetFilter4dDescriptor(const cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t *dataType, /* image data type */ - cudnnTensorFormat_t *format, - int *k, /* number of output feature maps */ - int *c, /* number of input feature maps */ - int *h, /* height of each input filter */ - int *w) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, cudnnDataType_t *, cudnnTensorFormat_t *, int *, int *, int *, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetFilter4dDescriptor( + const cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t *dataType, /* image data type */ + cudnnTensorFormat_t *format, int *k, /* number of output feature maps */ + int *c, /* number of input feature maps */ + int *h, /* height of each input filter */ + int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnFilterDescriptor_t, cudnnDataType_t *, cudnnTensorFormat_t *, + int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilter4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, k, c, h, w); } -cudnnStatus_t CUDNNWINAPI -cudnnSetFilterNdDescriptor(cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, /* image data type */ - cudnnTensorFormat_t format, - int nbDims, - const int filterDimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, cudnnTensorFormat_t, int, const int []); +cudnnStatus_t CUDNNWINAPI cudnnSetFilterNdDescriptor( + cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t dataType, /* image data type */ + cudnnTensorFormat_t format, int nbDims, const int filterDimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, + cudnnTensorFormat_t, int, const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFilterNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, nbDims, filterDimA); } -cudnnStatus_t CUDNNWINAPI -cudnnGetFilterNdDescriptor(const cudnnFilterDescriptor_t filterDesc, - int nbDimsRequested, - cudnnDataType_t *dataType, /* image data type */ - cudnnTensorFormat_t *format, - int *nbDims, - int filterDimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, int, cudnnDataType_t *, cudnnTensorFormat_t *, int *, int []); +cudnnStatus_t CUDNNWINAPI cudnnGetFilterNdDescriptor( + const cudnnFilterDescriptor_t filterDesc, int nbDimsRequested, + cudnnDataType_t *dataType, /* image data type */ + cudnnTensorFormat_t *format, int *nbDims, int filterDimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnFilterDescriptor_t, int, cudnnDataType_t *, + cudnnTensorFormat_t *, int *, int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilterNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(filterDesc, nbDimsRequested, dataType, format, nbDims, filterDimA); + return func_ptr(filterDesc, nbDimsRequested, dataType, format, nbDims, + filterDimA); } cudnnStatus_t CUDNNWINAPI cudnnDestroyFilterDescriptor(cudnnFilterDescriptor_t filterDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyFilterDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc); @@ -439,622 +444,657 @@ cudnnDestroyFilterDescriptor(cudnnFilterDescriptor_t filterDesc) { cudnnStatus_t CUDNNWINAPI cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t *convDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateConvolutionDescriptor"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateConvolutionDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t mathType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnMathType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionMathType( + cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t mathType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, + cudnnMathType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, mathType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t *mathType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnMathType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionMathType( + cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t *mathType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, + cudnnMathType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, mathType); } -cudnnStatus_t CUDNNWINAPI -cudnnSetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int groupCount) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionGroupCount( + cudnnConvolutionDescriptor_t convDesc, int groupCount) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionGroupCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, groupCount); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int *groupCount) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionGroupCount( + cudnnConvolutionDescriptor_t convDesc, int *groupCount) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionGroupCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, groupCount); } -cudnnStatus_t CUDNNWINAPI -cudnnSetConvolution2dDescriptor(cudnnConvolutionDescriptor_t convDesc, - int pad_h, /* zero-padding height */ - int pad_w, /* zero-padding width */ - int u, /* vertical filter stride */ - int v, /* horizontal filter stride */ - int dilation_h, /* filter dilation in the vertical dimension */ - int dilation_w, /* filter dilation in the horizontal dimension */ - cudnnConvolutionMode_t mode, - cudnnDataType_t computeType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, int, int, int, int, int, cudnnConvolutionMode_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolution2dDescriptor( + cudnnConvolutionDescriptor_t convDesc, int pad_h, /* zero-padding height */ + int pad_w, /* zero-padding width */ + int u, /* vertical filter stride */ + int v, /* horizontal filter stride */ + int dilation_h, /* filter dilation in the vertical dimension */ + int dilation_w, /* filter dilation in the horizontal dimension */ + cudnnConvolutionMode_t mode, cudnnDataType_t computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnConvolutionDescriptor_t, int, int, int, int, int, int, + cudnnConvolutionMode_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolution2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType); + return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, + computeType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolution2dDescriptor(const cudnnConvolutionDescriptor_t convDesc, - int *pad_h, /* zero-padding height */ - int *pad_w, /* zero-padding width */ - int *u, /* vertical filter stride */ - int *v, /* horizontal filter stride */ - int *dilation_h, /* filter dilation in the vertical dimension */ - int *dilation_w, /* filter dilation in the horizontal dimension */ - cudnnConvolutionMode_t *mode, - cudnnDataType_t *computeType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, int *, cudnnConvolutionMode_t *, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dDescriptor( + const cudnnConvolutionDescriptor_t convDesc, + int *pad_h, /* zero-padding height */ + int *pad_w, /* zero-padding width */ + int *u, /* vertical filter stride */ + int *v, /* horizontal filter stride */ + int *dilation_h, /* filter dilation in the vertical dimension */ + int *dilation_w, /* filter dilation in the horizontal dimension */ + cudnnConvolutionMode_t *mode, cudnnDataType_t *computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, + int *, cudnnConvolutionMode_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType); + return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, + computeType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolution2dForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - const cudnnFilterDescriptor_t filterDesc, - int *n, - int *c, - int *h, - int *w) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, int *, int *, int *, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dForwardOutputDim"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dForwardOutputDim( + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + const cudnnFilterDescriptor_t filterDesc, int *n, int *c, int *h, int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, int *, int *, int *, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolution2dForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, inputTensorDesc, filterDesc, n, c, h, w); } -cudnnStatus_t CUDNNWINAPI -cudnnSetConvolutionNdDescriptor(cudnnConvolutionDescriptor_t convDesc, - int arrayLength, /* nbDims-2 size */ - const int padA[], - const int filterStrideA[], - const int dilationA[], - cudnnConvolutionMode_t mode, - cudnnDataType_t computeType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, const int [], const int [], const int [], cudnnConvolutionMode_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionNdDescriptor( + cudnnConvolutionDescriptor_t convDesc, int arrayLength, /* nbDims-2 size */ + const int padA[], const int filterStrideA[], const int dilationA[], + cudnnConvolutionMode_t mode, cudnnDataType_t computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnConvolutionDescriptor_t, int, const int[], const int[], const int[], + cudnnConvolutionMode_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, arrayLength, padA, filterStrideA, dilationA, mode, computeType); + return func_ptr(convDesc, arrayLength, padA, filterStrideA, dilationA, mode, + computeType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionNdDescriptor(const cudnnConvolutionDescriptor_t convDesc, - int arrayLengthRequested, - int *arrayLength, - int padA[], - int strideA[], - int dilationA[], - cudnnConvolutionMode_t *mode, - cudnnDataType_t *computeType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int, int *, int [], int [], int [], cudnnConvolutionMode_t *, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionNdDescriptor( + const cudnnConvolutionDescriptor_t convDesc, int arrayLengthRequested, + int *arrayLength, int padA[], int strideA[], int dilationA[], + cudnnConvolutionMode_t *mode, cudnnDataType_t *computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, int, int *, int[], int[], int[], + cudnnConvolutionMode_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, arrayLengthRequested, arrayLength, padA, strideA, dilationA, mode, computeType); + return func_ptr(convDesc, arrayLengthRequested, arrayLength, padA, strideA, + dilationA, mode, computeType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionNdForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - const cudnnFilterDescriptor_t filterDesc, - int nbDims, - int tensorOutputDimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, int, int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionNdForwardOutputDim"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionNdForwardOutputDim( + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + const cudnnFilterDescriptor_t filterDesc, int nbDims, + int tensorOutputDimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, int, int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionNdForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, inputTensorDesc, filterDesc, nbDims, tensorOutputDimA); + return func_ptr(convDesc, inputTensorDesc, filterDesc, nbDims, + tensorOutputDimA); } cudnnStatus_t CUDNNWINAPI cudnnDestroyConvolutionDescriptor(cudnnConvolutionDescriptor_t convDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyConvolutionDescriptor"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyConvolutionDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithmMaxCount(cudnnHandle_t handle, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithmMaxCount"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionForwardAlgorithm(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionFwdAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionForwardAlgorithm( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionFwdAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionForwardAlgorithmEx(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - void *y, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, const int, int *, cudnnConvolutionFwdAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionForwardAlgorithmEx( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, void *y, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults, + void *workSpace, size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, + const int, int *, cudnnConvolutionFwdAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, x, wDesc, w, convDesc, yDesc, y, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, xDesc, x, wDesc, w, convDesc, yDesc, y, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionForwardAlgorithm(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - cudnnConvolutionFwdPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionFwdAlgo_t *algo) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionFwdPreference_t, size_t, cudnnConvolutionFwdAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithm( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, + cudnnConvolutionFwdPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionFwdAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionFwdPreference_t, size_t, + cudnnConvolutionFwdAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, preference, + memoryLimitInBytes, algo); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionForwardAlgorithm_v7(cudnnHandle_t handle, - const cudnnTensorDescriptor_t srcDesc, - const cudnnFilterDescriptor_t filterDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t destDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionFwdAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm_v7"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithm_v7( + cudnnHandle_t handle, const cudnnTensorDescriptor_t srcDesc, + const cudnnFilterDescriptor_t filterDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t destDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionFwdAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm_v7"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, srcDesc, filterDesc, convDesc, destDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, srcDesc, filterDesc, convDesc, destDesc, + requestedAlgoCount, returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - cudnnConvolutionFwdAlgo_t algo, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionFwdAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardWorkspaceSize"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardWorkspaceSize( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, cudnnConvolutionFwdAlgo_t algo, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionFwdAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, algo, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnConvolutionForward(cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionFwdAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnConvolutionForward( + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo, + void *workSpace, size_t workSpaceSizeInBytes, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, yDesc, y); + return func_ptr(handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace, + workSpaceSizeInBytes, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnConvolutionBiasActivationForward(cudnnHandle_t handle, - const void *alpha1, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionFwdAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *alpha2, - const cudnnTensorDescriptor_t zDesc, - const void *z, - const cudnnTensorDescriptor_t biasDesc, - const void *bias, - const cudnnActivationDescriptor_t activationDesc, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBiasActivationForward"); +cudnnStatus_t CUDNNWINAPI cudnnConvolutionBiasActivationForward( + cudnnHandle_t handle, const void *alpha1, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo, + void *workSpace, size_t workSpaceSizeInBytes, const void *alpha2, + const cudnnTensorDescriptor_t zDesc, const void *z, + const cudnnTensorDescriptor_t biasDesc, const void *bias, + const cudnnActivationDescriptor_t activationDesc, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnConvolutionBiasActivationForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha1, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, alpha2, zDesc, z, biasDesc, bias, activationDesc, yDesc, y); + return func_ptr(handle, alpha1, xDesc, x, wDesc, w, convDesc, algo, workSpace, + workSpaceSizeInBytes, alpha2, zDesc, z, biasDesc, bias, + activationDesc, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnConvolutionBackwardBias(cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *beta, - const cudnnTensorDescriptor_t dbDesc, - void *db) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardBias( + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t dyDesc, const void *dy, const void *beta, + const cudnnTensorDescriptor_t dbDesc, void *db) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardBias"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, dyDesc, dy, beta, dbDesc, db); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnnHandle_t handle, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( + cudnnHandle_t handle, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardFilterAlgorithm( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, const int, int *, + cudnnConvolutionBwdFilterAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionBackwardFilterAlgorithmEx(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *y, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, void *, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardFilterAlgorithmEx( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t dyDesc, const void *y, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, void *dw, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnConvolutionBwdFilterAlgoPerf_t *perfResults, void *workSpace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, void *, + const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, x, dyDesc, y, convDesc, dwDesc, dw, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, xDesc, x, dyDesc, y, convDesc, dwDesc, dw, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - cudnnConvolutionBwdFilterPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionBwdFilterAlgo_t *algo) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterPreference_t, size_t, cudnnConvolutionBwdFilterAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithm( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, + cudnnConvolutionBwdFilterPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionBwdFilterAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterPreference_t, + size_t, cudnnConvolutionBwdFilterAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, preference, + memoryLimitInBytes, algo); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnnHandle_t handle, - const cudnnTensorDescriptor_t srcDesc, - const cudnnTensorDescriptor_t diffDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t gradDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm_v7"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithm_v7( + cudnnHandle_t handle, const cudnnTensorDescriptor_t srcDesc, + const cudnnTensorDescriptor_t diffDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t gradDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, const int, int *, + cudnnConvolutionBwdFilterAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm_v7"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, srcDesc, diffDesc, convDesc, gradDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, srcDesc, diffDesc, convDesc, gradDesc, + requestedAlgoCount, returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t gradDesc, - cudnnConvolutionBwdFilterAlgo_t algo, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterWorkspaceSize"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterWorkspaceSize( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t gradDesc, + cudnnConvolutionBwdFilterAlgo_t algo, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, dyDesc, convDesc, gradDesc, algo, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnConvolutionBackwardFilter(cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionBwdFilterAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnFilterDescriptor_t dwDesc, - void *dw) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, void *, size_t, const void *, const cudnnFilterDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardFilter( + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + cudnnConvolutionBwdFilterAlgo_t algo, void *workSpace, + size_t workSpaceSizeInBytes, const void *beta, + const cudnnFilterDescriptor_t dwDesc, void *dw) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, + void *, size_t, const void *, const cudnnFilterDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardFilter"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, xDesc, x, dyDesc, dy, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, dwDesc, dw); + return func_ptr(handle, alpha, xDesc, x, dyDesc, dy, convDesc, algo, + workSpace, workSpaceSizeInBytes, beta, dwDesc, dw); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnnHandle_t handle, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithmMaxCount( + cudnnHandle_t handle, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionBackwardDataAlgorithm(cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardDataAlgorithm( + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionBwdDataAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionBackwardDataAlgorithmEx(cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardDataAlgorithmEx( + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, void *dx, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnConvolutionBwdDataAlgoPerf_t *perfResults, void *workSpace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, + const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardDataAlgorithm(cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - cudnnConvolutionBwdDataPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionBwdDataAlgo_t *algo) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataPreference_t, size_t, cudnnConvolutionBwdDataAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithm( + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, + cudnnConvolutionBwdDataPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionBwdDataAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataPreference_t, + size_t, cudnnConvolutionBwdDataAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, preference, + memoryLimitInBytes, algo); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnnHandle_t handle, - const cudnnFilterDescriptor_t filterDesc, - const cudnnTensorDescriptor_t diffDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t gradDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm_v7"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithm_v7( + cudnnHandle_t handle, const cudnnFilterDescriptor_t filterDesc, + const cudnnTensorDescriptor_t diffDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t gradDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionBwdDataAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm_v7"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, filterDesc, diffDesc, convDesc, gradDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, filterDesc, diffDesc, convDesc, gradDesc, + requestedAlgoCount, returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - cudnnConvolutionBwdDataAlgo_t algo, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataWorkspaceSize"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataWorkspaceSize( + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, cudnnConvolutionBwdDataAlgo_t algo, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, algo, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnConvolutionBackwardData(cudnnHandle_t handle, - const void *alpha, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionBwdDataAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdDataAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardData( + cudnnHandle_t handle, const void *alpha, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + cudnnConvolutionBwdDataAlgo_t algo, void *workSpace, + size_t workSpaceSizeInBytes, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdDataAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardData"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, wDesc, w, dyDesc, dy, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, dxDesc, dx); + return func_ptr(handle, alpha, wDesc, w, dyDesc, dy, convDesc, algo, + workSpace, workSpaceSizeInBytes, beta, dxDesc, dx); } cudnnStatus_t CUDNNWINAPI -cudnnIm2Col(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - void *colBuffer) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, void *); +cudnnIm2Col(cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const void *x, const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, void *colBuffer) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, + const void *, const cudnnFilterDescriptor_t, + const cudnnConvolutionDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnIm2Col"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, x, wDesc, convDesc, colBuffer); } -cudnnStatus_t CUDNNWINAPI -cudnnSoftmaxForward(cudnnHandle_t handle, - cudnnSoftmaxAlgorithm_t algo, - cudnnSoftmaxMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnSoftmaxForward( + cudnnHandle_t handle, cudnnSoftmaxAlgorithm_t algo, cudnnSoftmaxMode_t mode, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSoftmaxForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algo, mode, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnSoftmaxBackward(cudnnHandle_t handle, - cudnnSoftmaxAlgorithm_t algo, - cudnnSoftmaxMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnSoftmaxBackward( + cudnnHandle_t handle, cudnnSoftmaxAlgorithm_t algo, cudnnSoftmaxMode_t mode, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSoftmaxBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, algo, mode, alpha, yDesc, y, dyDesc, dy, beta, dxDesc, dx); + return func_ptr(handle, algo, mode, alpha, yDesc, y, dyDesc, dy, beta, dxDesc, + dx); } cudnnStatus_t CUDNNWINAPI cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPoolingDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreatePoolingDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetPooling2dDescriptor(cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t mode, - cudnnNanPropagation_t maxpoolingNanOpt, - int windowHeight, - int windowWidth, - int verticalPadding, - int horizontalPadding, - int verticalStride, - int horizontalStride) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t, cudnnPoolingMode_t, cudnnNanPropagation_t, int, int, int, int, int, int); +cudnnStatus_t CUDNNWINAPI cudnnSetPooling2dDescriptor( + cudnnPoolingDescriptor_t poolingDesc, cudnnPoolingMode_t mode, + cudnnNanPropagation_t maxpoolingNanOpt, int windowHeight, int windowWidth, + int verticalPadding, int horizontalPadding, int verticalStride, + int horizontalStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnPoolingDescriptor_t, cudnnPoolingMode_t, cudnnNanPropagation_t, int, + int, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPooling2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, windowWidth, verticalPadding, horizontalPadding, verticalStride, horizontalStride); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, + windowWidth, verticalPadding, horizontalPadding, + verticalStride, horizontalStride); } -cudnnStatus_t CUDNNWINAPI -cudnnGetPooling2dDescriptor(const cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t *mode, - cudnnNanPropagation_t *maxpoolingNanOpt, - int *windowHeight, - int *windowWidth, - int *verticalPadding, - int *horizontalPadding, - int *verticalStride, - int *horizontalStride) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, cudnnPoolingMode_t *, cudnnNanPropagation_t *, int *, int *, int *, int *, int *, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetPooling2dDescriptor( + const cudnnPoolingDescriptor_t poolingDesc, cudnnPoolingMode_t *mode, + cudnnNanPropagation_t *maxpoolingNanOpt, int *windowHeight, + int *windowWidth, int *verticalPadding, int *horizontalPadding, + int *verticalStride, int *horizontalStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnPoolingDescriptor_t, cudnnPoolingMode_t *, + cudnnNanPropagation_t *, int *, int *, int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPooling2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, windowWidth, verticalPadding, horizontalPadding, verticalStride, horizontalStride); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, + windowWidth, verticalPadding, horizontalPadding, + verticalStride, horizontalStride); } -cudnnStatus_t CUDNNWINAPI -cudnnSetPoolingNdDescriptor(cudnnPoolingDescriptor_t poolingDesc, - const cudnnPoolingMode_t mode, - const cudnnNanPropagation_t maxpoolingNanOpt, - int nbDims, - const int windowDimA[], - const int paddingA[], - const int strideA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t, const cudnnPoolingMode_t, const cudnnNanPropagation_t, int, const int [], const int [], const int []); +cudnnStatus_t CUDNNWINAPI cudnnSetPoolingNdDescriptor( + cudnnPoolingDescriptor_t poolingDesc, const cudnnPoolingMode_t mode, + const cudnnNanPropagation_t maxpoolingNanOpt, int nbDims, + const int windowDimA[], const int paddingA[], const int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnPoolingDescriptor_t, const cudnnPoolingMode_t, + const cudnnNanPropagation_t, int, const int[], const int[], const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPoolingNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, + paddingA, strideA); } -cudnnStatus_t CUDNNWINAPI -cudnnGetPoolingNdDescriptor(const cudnnPoolingDescriptor_t poolingDesc, - int nbDimsRequested, - cudnnPoolingMode_t *mode, - cudnnNanPropagation_t *maxpoolingNanOpt, - int *nbDims, - int windowDimA[], - int paddingA[], - int strideA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, int, cudnnPoolingMode_t *, cudnnNanPropagation_t *, int *, int [], int [], int []); +cudnnStatus_t CUDNNWINAPI cudnnGetPoolingNdDescriptor( + const cudnnPoolingDescriptor_t poolingDesc, int nbDimsRequested, + cudnnPoolingMode_t *mode, cudnnNanPropagation_t *maxpoolingNanOpt, + int *nbDims, int windowDimA[], int paddingA[], int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnPoolingDescriptor_t, int, cudnnPoolingMode_t *, + cudnnNanPropagation_t *, int *, int[], int[], int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPoolingNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, nbDimsRequested, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA); + return func_ptr(poolingDesc, nbDimsRequested, mode, maxpoolingNanOpt, nbDims, + windowDimA, paddingA, strideA); } cudnnStatus_t CUDNNWINAPI cudnnGetPoolingNdForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc, const cudnnTensorDescriptor_t inputTensorDesc, - int nbDims, - int outputTensorDimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, const cudnnTensorDescriptor_t, int, int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPoolingNdForwardOutputDim"); + int nbDims, int outputTensorDimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, + const cudnnTensorDescriptor_t, int, int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetPoolingNdForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc, inputTensorDesc, nbDims, outputTensorDimA); } @@ -1062,72 +1102,69 @@ cudnnGetPoolingNdForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc, cudnnStatus_t CUDNNWINAPI cudnnGetPooling2dForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc, const cudnnTensorDescriptor_t inputTensorDesc, - int *n, - int *c, - int *h, - int *w) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, const cudnnTensorDescriptor_t, int *, int *, int *, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPooling2dForwardOutputDim"); + int *n, int *c, int *h, int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, + const cudnnTensorDescriptor_t, + int *, int *, int *, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetPooling2dForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc, inputTensorDesc, n, c, h, w); } cudnnStatus_t CUDNNWINAPI cudnnDestroyPoolingDescriptor(cudnnPoolingDescriptor_t poolingDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPoolingDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyPoolingDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnPoolingForward(cudnnHandle_t handle, - const cudnnPoolingDescriptor_t poolingDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnPoolingForward( + cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnPoolingForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, poolingDesc, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnPoolingBackward(cudnnHandle_t handle, - const cudnnPoolingDescriptor_t poolingDesc, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnPoolingBackward( + cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnPoolingBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, poolingDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, poolingDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, + beta, dxDesc, dx); } cudnnStatus_t CUDNNWINAPI cudnnCreateActivationDescriptor(cudnnActivationDescriptor_t *activationDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetActivationDescriptor(cudnnActivationDescriptor_t activationDesc, - cudnnActivationMode_t mode, - cudnnNanPropagation_t reluNanOpt, - double coef) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t, cudnnActivationMode_t, cudnnNanPropagation_t, double); +cudnnStatus_t CUDNNWINAPI cudnnSetActivationDescriptor( + cudnnActivationDescriptor_t activationDesc, cudnnActivationMode_t mode, + cudnnNanPropagation_t reluNanOpt, double coef) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t, + cudnnActivationMode_t, + cudnnNanPropagation_t, double); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc, mode, reluNanOpt, coef); @@ -1136,9 +1173,10 @@ cudnnSetActivationDescriptor(cudnnActivationDescriptor_t activationDesc, cudnnStatus_t CUDNNWINAPI cudnnGetActivationDescriptor(const cudnnActivationDescriptor_t activationDesc, cudnnActivationMode_t *mode, - cudnnNanPropagation_t *reluNanOpt, - double *coef) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnActivationDescriptor_t, cudnnActivationMode_t *, cudnnNanPropagation_t *, double *); + cudnnNanPropagation_t *reluNanOpt, double *coef) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnActivationDescriptor_t, cudnnActivationMode_t *, + cudnnNanPropagation_t *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc, mode, reluNanOpt, coef); @@ -1146,65 +1184,68 @@ cudnnGetActivationDescriptor(const cudnnActivationDescriptor_t activationDesc, cudnnStatus_t CUDNNWINAPI cudnnDestroyActivationDescriptor(cudnnActivationDescriptor_t activationDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyActivationDescriptor"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnActivationForward(cudnnHandle_t handle, - cudnnActivationDescriptor_t activationDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnActivationDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnActivationForward( + cudnnHandle_t handle, cudnnActivationDescriptor_t activationDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnActivationDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnActivationForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, activationDesc, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnActivationBackward(cudnnHandle_t handle, - cudnnActivationDescriptor_t activationDesc, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnActivationDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnActivationBackward( + cudnnHandle_t handle, cudnnActivationDescriptor_t activationDesc, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnActivationDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnActivationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, activationDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, activationDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, + beta, dxDesc, dx); } cudnnStatus_t CUDNNWINAPI cudnnCreateLRNDescriptor(cudnnLRNDescriptor_t *normDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnLRNDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned lrnN, double lrnAlpha, double lrnBeta, double lrnK) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t, unsigned int, double, double, double); +cudnnStatus_t CUDNNWINAPI cudnnSetLRNDescriptor(cudnnLRNDescriptor_t normDesc, + unsigned lrnN, double lrnAlpha, + double lrnBeta, double lrnK) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnLRNDescriptor_t, unsigned int, double, double, double); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK); } -cudnnStatus_t CUDNNWINAPI -cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned *lrnN, double *lrnAlpha, double *lrnBeta, double *lrnK) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t, unsigned int *, double *, double *, double *); +cudnnStatus_t CUDNNWINAPI cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, + unsigned *lrnN, + double *lrnAlpha, + double *lrnBeta, double *lrnK) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnLRNDescriptor_t, unsigned int *, double *, double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK); @@ -1212,157 +1253,157 @@ cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned *lrnN, double *lrn cudnnStatus_t CUDNNWINAPI cudnnDestroyLRNDescriptor(cudnnLRNDescriptor_t lrnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnLRNDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(lrnDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnLRNCrossChannelForward(cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnLRNMode_t lrnMode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnLRNCrossChannelForward( + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, cudnnLRNMode_t lrnMode, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnLRNCrossChannelForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, normDesc, lrnMode, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnLRNCrossChannelBackward(cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnLRNMode_t lrnMode, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnLRNCrossChannelBackward( + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, cudnnLRNMode_t lrnMode, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnLRNCrossChannelBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, lrnMode, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, normDesc, lrnMode, alpha, yDesc, y, dyDesc, dy, xDesc, + x, beta, dxDesc, dx); } -cudnnStatus_t CUDNNWINAPI -cudnnDivisiveNormalizationForward(cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnDivNormMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */ - const void *x, - const void *means, /* if NULL, means are assumed to be zero */ - void *temp, - void *temp2, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, const void *, const cudnnTensorDescriptor_t, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationForward"); +cudnnStatus_t CUDNNWINAPI cudnnDivisiveNormalizationForward( + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, + cudnnDivNormMode_t mode, const void *alpha, + const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */ + const void *x, + const void *means, /* if NULL, means are assumed to be zero */ + void *temp, void *temp2, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, + const void *, const cudnnTensorDescriptor_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, temp, temp2, beta, yDesc, y); + return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, temp, temp2, + beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnDivisiveNormalizationBackward(cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnDivNormMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, /* same desc for x, means, dy, temp, temp2 */ - const void *x, - const void *means, /* if NULL, means are assumed to be zero */ - const void *dy, - void *temp, - void *temp2, - const void *beta, - const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */ - void *dx, /* output x differential */ - void *dMeans) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, void *, void *, const void *, const cudnnTensorDescriptor_t, void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationBackward"); +cudnnStatus_t CUDNNWINAPI cudnnDivisiveNormalizationBackward( + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, + cudnnDivNormMode_t mode, const void *alpha, + const cudnnTensorDescriptor_t + xDesc, /* same desc for x, means, dy, temp, temp2 */ + const void *x, + const void *means, /* if NULL, means are assumed to be zero */ + const void *dy, void *temp, void *temp2, const void *beta, + const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */ + void *dx, /* output x differential */ + void *dMeans) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + void *, void *, const void *, const cudnnTensorDescriptor_t, void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, dy, temp, temp2, beta, dXdMeansDesc, dx, dMeans); + return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, dy, temp, + temp2, beta, dXdMeansDesc, dx, dMeans); } -cudnnStatus_t CUDNNWINAPI -cudnnDeriveBNTensorDescriptor(cudnnTensorDescriptor_t derivedBnDesc, - const cudnnTensorDescriptor_t xDesc, - cudnnBatchNormMode_t mode) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, cudnnBatchNormMode_t); +cudnnStatus_t CUDNNWINAPI cudnnDeriveBNTensorDescriptor( + cudnnTensorDescriptor_t derivedBnDesc, const cudnnTensorDescriptor_t xDesc, + cudnnBatchNormMode_t mode) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, + cudnnBatchNormMode_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDeriveBNTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(derivedBnDesc, xDesc, mode); } cudnnStatus_t CUDNNWINAPI -cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - cudnnBatchNormOps_t bnOps, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t zDesc, - const cudnnTensorDescriptor_t yDesc, - const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - const cudnnActivationDescriptor_t activationDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnActivationDescriptor_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize"); +cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, cudnnBatchNormOps_t bnOps, + const cudnnTensorDescriptor_t xDesc, const cudnnTensorDescriptor_t zDesc, + const cudnnTensorDescriptor_t yDesc, + const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, + const cudnnActivationDescriptor_t activationDesc, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnActivationDescriptor_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>( + "cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, bnOps, xDesc, zDesc, yDesc, bnScaleBiasMeanVarDesc, activationDesc, sizeInBytes); + return func_ptr(handle, mode, bnOps, xDesc, zDesc, yDesc, + bnScaleBiasMeanVarDesc, activationDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetBatchNormalizationBackwardExWorkspaceSize(cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - cudnnBatchNormOps_t bnOps, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t yDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnTensorDescriptor_t dzDesc, - const cudnnTensorDescriptor_t dxDesc, - const cudnnTensorDescriptor_t dBnScaleBiasDesc, - const cudnnActivationDescriptor_t activationDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnActivationDescriptor_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetBatchNormalizationBackwardExWorkspaceSize"); +cudnnStatus_t CUDNNWINAPI cudnnGetBatchNormalizationBackwardExWorkspaceSize( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, cudnnBatchNormOps_t bnOps, + const cudnnTensorDescriptor_t xDesc, const cudnnTensorDescriptor_t yDesc, + const cudnnTensorDescriptor_t dyDesc, const cudnnTensorDescriptor_t dzDesc, + const cudnnTensorDescriptor_t dxDesc, + const cudnnTensorDescriptor_t dBnScaleBiasDesc, + const cudnnActivationDescriptor_t activationDesc, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnActivationDescriptor_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetBatchNormalizationBackwardExWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, bnOps, xDesc, yDesc, dyDesc, dzDesc, dxDesc, dBnScaleBiasDesc, activationDesc, sizeInBytes); + return func_ptr(handle, mode, bnOps, xDesc, yDesc, dyDesc, dzDesc, dxDesc, + dBnScaleBiasDesc, activationDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetBatchNormalizationTrainingExReserveSpaceSize(cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - cudnnBatchNormOps_t bnOps, - const cudnnActivationDescriptor_t activationDesc, - const cudnnTensorDescriptor_t xDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetBatchNormalizationTrainingExReserveSpaceSize"); +cudnnStatus_t CUDNNWINAPI cudnnGetBatchNormalizationTrainingExReserveSpaceSize( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, cudnnBatchNormOps_t bnOps, + const cudnnActivationDescriptor_t activationDesc, + const cudnnTensorDescriptor_t xDesc, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, + const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, + size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>( + "cudnnGetBatchNormalizationTrainingExReserveSpaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode, bnOps, activationDesc, xDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnBatchNormalizationForwardTraining( - cudnnHandle_t handle, - cudnnBatchNormMode_t mode, +cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationForwardTraining( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, const void *alpha, /* alpha[0] = result blend factor */ const void *beta, /* beta[0] = dest layer blend factor */ - const cudnnTensorDescriptor_t xDesc, - const void *x, /* NxCxHxW */ - const cudnnTensorDescriptor_t yDesc, - void *y, /* NxCxHxW */ + const cudnnTensorDescriptor_t xDesc, const void *x, /* NxCxHxW */ + const cudnnTensorDescriptor_t yDesc, void *y, /* NxCxHxW */ /* Shared desc for the next 6 tensors in the argument list. Data type to be set as follows: @@ -1370,13 +1411,13 @@ cudnnBatchNormalizationForwardTraining( Dimensions for this descriptor depend on normalization mode - Spatial Normalization : tensors are expected to have dims 1xCx1x1 (normalization is performed across NxHxW) - - Per-Activation Normalization : tensors are expected to have dims of 1xCxHxW - (normalization is performed across N) */ + - Per-Activation Normalization : tensors are expected to have dims of + 1xCxHxW (normalization is performed across N) */ const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - /* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation */ - const void *bnScale, - const void *bnBias, + /* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation + */ + const void *bnScale, const void *bnBias, /* MUST use factor=1 in the very first call of a complete training cycle. Use a factor=1/(1+n) at N-th call to the function to get @@ -1394,248 +1435,261 @@ cudnnBatchNormalizationForwardTraining( of variance[x] (factor is applied in the same way as for runningMean) */ void *resultRunningVariance, - /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */ + /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and + backward functions. */ double epsilon, /* Optionally save intermediate results from the forward pass here - can be reused to speed up backward pass. NULL if unused */ - void *resultSaveMean, - void *resultSaveInvVariance) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, double, void *, void *, double, void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTraining"); + void *resultSaveMean, void *resultSaveInvVariance) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, const void *, double, void *, void *, double, void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTraining"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, bnScale, bnBias, exponentialAverageFactor, resultRunningMean, resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance); + return func_ptr( + handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, + bnScale, bnBias, exponentialAverageFactor, resultRunningMean, + resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance); } -cudnnStatus_t CUDNNWINAPI -cudnnBatchNormalizationForwardTrainingEx( - cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - cudnnBatchNormOps_t bnOps, +cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationForwardTrainingEx( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, cudnnBatchNormOps_t bnOps, const void *alpha, /* alpha[0] = result blend factor */ const void *beta, /* beta[0] = dest layer blend factor */ - const cudnnTensorDescriptor_t xDesc, - const void *xData, - const cudnnTensorDescriptor_t zDesc, - const void *zData, - const cudnnTensorDescriptor_t yDesc, - void *yData, + const cudnnTensorDescriptor_t xDesc, const void *xData, + const cudnnTensorDescriptor_t zDesc, const void *zData, + const cudnnTensorDescriptor_t yDesc, void *yData, - const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - const void *bnScale, + const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, const void *bnScale, const void *bnBias, - double exponentialAverageFactor, - void *resultRunningMean, + double exponentialAverageFactor, void *resultRunningMean, void *resultRunningVariance, - /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */ + /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and + backward functions. */ double epsilon, /* Optionally save intermediate results from the forward pass here - can be reused to speed up backward pass. NULL if unused */ - void *resultSaveMean, - void *resultSaveInvVariance, + void *resultSaveMean, void *resultSaveInvVariance, - cudnnActivationDescriptor_t activationDesc, - void *workspace, - size_t workSpaceSizeInBytes, - void *reserveSpace, + cudnnActivationDescriptor_t activationDesc, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, double, void *, void *, double, void *, void *, cudnnActivationDescriptor_t, void *, size_t, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTrainingEx"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const void *, + const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, const void *, double, void *, void *, double, void *, + void *, cudnnActivationDescriptor_t, void *, size_t, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTrainingEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, bnOps, alpha, beta, xDesc, xData, zDesc, zData, yDesc, yData, bnScaleBiasMeanVarDesc, bnScale, bnBias, exponentialAverageFactor, resultRunningMean, resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance, activationDesc, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, mode, bnOps, alpha, beta, xDesc, xData, zDesc, zData, + yDesc, yData, bnScaleBiasMeanVarDesc, bnScale, bnBias, + exponentialAverageFactor, resultRunningMean, + resultRunningVariance, epsilon, resultSaveMean, + resultSaveInvVariance, activationDesc, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnBatchNormalizationForwardInference(cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - const void *alpha, /* alpha[0] = result blend factor */ - const void *beta, /* beta[0] = dest layer blend factor */ - const cudnnTensorDescriptor_t xDesc, - const void *x, /* NxCxHxW */ - const cudnnTensorDescriptor_t yDesc, - void *y, /* NxCxHxW */ - const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - const void *bnScale, - const void *bnBias, - const void *estimatedMean, - const void *estimatedVariance, - double epsilon) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, const void *, double); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardInference"); +cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationForwardInference( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, + const void *alpha, /* alpha[0] = result blend factor */ + const void *beta, /* beta[0] = dest layer blend factor */ + const cudnnTensorDescriptor_t xDesc, const void *x, /* NxCxHxW */ + const cudnnTensorDescriptor_t yDesc, void *y, /* NxCxHxW */ + const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, const void *bnScale, + const void *bnBias, const void *estimatedMean, + const void *estimatedVariance, double epsilon) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, const void *, const void *, const void *, double); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardInference"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, bnScale, bnBias, estimatedMean, estimatedVariance, epsilon); + return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, + bnScaleBiasMeanVarDesc, bnScale, bnBias, estimatedMean, + estimatedVariance, epsilon); } -cudnnStatus_t CUDNNWINAPI -cudnnBatchNormalizationBackward(cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - const void *alphaDataDiff, - const void *betaDataDiff, - const void *alphaParamDiff, - const void *betaParamDiff, - const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */ - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - /* Shared tensor desc for the 4 tensors below */ - const cudnnTensorDescriptor_t dBnScaleBiasDesc, - const void *bnScale, /* bnBias doesn't affect backpropagation */ - /* scale and bias diff are not backpropagated below this layer */ - void *dBnScaleResult, - void *dBnBiasResult, - /* Same epsilon as forward pass */ - double epsilon, +cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationBackward( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, const void *alphaDataDiff, + const void *betaDataDiff, const void *alphaParamDiff, + const void *betaParamDiff, + const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */ + const void *x, const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t dxDesc, void *dx, + /* Shared tensor desc for the 4 tensors below */ + const cudnnTensorDescriptor_t dBnScaleBiasDesc, + const void *bnScale, /* bnBias doesn't affect backpropagation */ + /* scale and bias diff are not backpropagated below this layer */ + void *dBnScaleResult, void *dBnBiasResult, + /* Same epsilon as forward pass */ + double epsilon, - /* Optionally cached intermediate results from - forward pass */ - const void *savedMean, - const void *savedInvVariance) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, void *, void *, double, const void *, const void *); + /* Optionally cached intermediate results from + forward pass */ + const void *savedMean, const void *savedInvVariance) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const void *, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, void *, void *, double, const void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alphaDataDiff, betaDataDiff, alphaParamDiff, betaParamDiff, xDesc, x, dyDesc, dy, dxDesc, dx, dBnScaleBiasDesc, bnScale, dBnScaleResult, dBnBiasResult, epsilon, savedMean, savedInvVariance); + return func_ptr(handle, mode, alphaDataDiff, betaDataDiff, alphaParamDiff, + betaParamDiff, xDesc, x, dyDesc, dy, dxDesc, dx, + dBnScaleBiasDesc, bnScale, dBnScaleResult, dBnBiasResult, + epsilon, savedMean, savedInvVariance); } -cudnnStatus_t CUDNNWINAPI -cudnnBatchNormalizationBackwardEx(cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - cudnnBatchNormOps_t bnOps, +cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationBackwardEx( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, cudnnBatchNormOps_t bnOps, - const void *alphaDataDiff, - const void *betaDataDiff, - const void *alphaParamDiff, - const void *betaParamDiff, - const cudnnTensorDescriptor_t xDesc, - const void *xData, - const cudnnTensorDescriptor_t yDesc, - const void *yData, - const cudnnTensorDescriptor_t dyDesc, - const void *dyData, - const cudnnTensorDescriptor_t dzDesc, - void *dzData, - const cudnnTensorDescriptor_t dxDesc, - void *dxData, + const void *alphaDataDiff, const void *betaDataDiff, + const void *alphaParamDiff, const void *betaParamDiff, + const cudnnTensorDescriptor_t xDesc, const void *xData, + const cudnnTensorDescriptor_t yDesc, const void *yData, + const cudnnTensorDescriptor_t dyDesc, const void *dyData, + const cudnnTensorDescriptor_t dzDesc, void *dzData, + const cudnnTensorDescriptor_t dxDesc, void *dxData, - /* Shared tensor desc for the 4 tensors below */ - const cudnnTensorDescriptor_t dBnScaleBiasDesc, - const void *bnScaleData, - const void *bnBiasData, /* needed if there is activation */ - void *dBnScaleData, - void *dBnBiasData, - double epsilon, /* Same epsilon as forward pass */ + /* Shared tensor desc for the 4 tensors below */ + const cudnnTensorDescriptor_t dBnScaleBiasDesc, const void *bnScaleData, + const void *bnBiasData, /* needed if there is activation */ + void *dBnScaleData, void *dBnBiasData, + double epsilon, /* Same epsilon as forward pass */ - /* Optionally cached intermediate results from - forward pass */ - const void *savedMean, - const void *savedInvVariance, - cudnnActivationDescriptor_t activationDesc, - void *workSpace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const void *, const void *, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, double, const void *, const void *, cudnnActivationDescriptor_t, void *, size_t, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationBackwardEx"); + /* Optionally cached intermediate results from + forward pass */ + const void *savedMean, const void *savedInvVariance, + cudnnActivationDescriptor_t activationDesc, void *workSpace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const void *, + const void *, const void *, const void *, const cudnnTensorDescriptor_t, + const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, const void *, const void *, void *, + void *, double, const void *, const void *, cudnnActivationDescriptor_t, + void *, size_t, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationBackwardEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, bnOps, alphaDataDiff, betaDataDiff, alphaParamDiff, betaParamDiff, xDesc, xData, yDesc, yData, dyDesc, dyData, dzDesc, dzData, dxDesc, dxData, dBnScaleBiasDesc, bnScaleData, bnBiasData, dBnScaleData, dBnBiasData, epsilon, savedMean, savedInvVariance, activationDesc, workSpace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr( + handle, mode, bnOps, alphaDataDiff, betaDataDiff, alphaParamDiff, + betaParamDiff, xDesc, xData, yDesc, yData, dyDesc, dyData, dzDesc, dzData, + dxDesc, dxData, dBnScaleBiasDesc, bnScaleData, bnBiasData, dBnScaleData, + dBnBiasData, epsilon, savedMean, savedInvVariance, activationDesc, + workSpace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnCreateSpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t *stDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateSpatialTransformerDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnCreateSpatialTransformerDescriptor( + cudnnSpatialTransformerDescriptor_t *stDesc) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateSpatialTransformerDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetSpatialTransformerNdDescriptor(cudnnSpatialTransformerDescriptor_t stDesc, - cudnnSamplerType_t samplerType, - cudnnDataType_t dataType, - const int nbDims, - const int dimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t, cudnnSamplerType_t, cudnnDataType_t, const int, const int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetSpatialTransformerNdDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnSetSpatialTransformerNdDescriptor( + cudnnSpatialTransformerDescriptor_t stDesc, cudnnSamplerType_t samplerType, + cudnnDataType_t dataType, const int nbDims, const int dimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnSpatialTransformerDescriptor_t, cudnnSamplerType_t, cudnnDataType_t, + const int, const int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSetSpatialTransformerNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc, samplerType, dataType, nbDims, dimA); } -cudnnStatus_t CUDNNWINAPI -cudnnDestroySpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t stDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroySpatialTransformerDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnDestroySpatialTransformerDescriptor( + cudnnSpatialTransformerDescriptor_t stDesc) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroySpatialTransformerDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSpatialTfGridGeneratorForward(cudnnHandle_t handle, - const cudnnSpatialTransformerDescriptor_t stDesc, - const void *theta, - void *grid) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorForward"); +cudnnStatus_t CUDNNWINAPI cudnnSpatialTfGridGeneratorForward( + cudnnHandle_t handle, const cudnnSpatialTransformerDescriptor_t stDesc, + const void *theta, void *grid) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, theta, grid); } -cudnnStatus_t CUDNNWINAPI -cudnnSpatialTfGridGeneratorBackward(cudnnHandle_t handle, - const cudnnSpatialTransformerDescriptor_t stDesc, - const void *dgrid, - void *dtheta) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorBackward"); +cudnnStatus_t CUDNNWINAPI cudnnSpatialTfGridGeneratorBackward( + cudnnHandle_t handle, const cudnnSpatialTransformerDescriptor_t stDesc, + const void *dgrid, void *dtheta) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorBackward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, dgrid, dtheta); } -cudnnStatus_t CUDNNWINAPI -cudnnSpatialTfSamplerForward(cudnnHandle_t handle, - cudnnSpatialTransformerDescriptor_t stDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *grid, - const void *beta, - cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnSpatialTfSamplerForward( + cudnnHandle_t handle, cudnnSpatialTransformerDescriptor_t stDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *grid, const void *beta, cudnnTensorDescriptor_t yDesc, + void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfSamplerForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, alpha, xDesc, x, grid, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnSpatialTfSamplerBackward(cudnnHandle_t handle, - cudnnSpatialTransformerDescriptor_t stDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - const void *alphaDgrid, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *grid, - const void *betaDgrid, - void *dgrid) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, void *); +cudnnStatus_t CUDNNWINAPI cudnnSpatialTfSamplerBackward( + cudnnHandle_t handle, cudnnSpatialTransformerDescriptor_t stDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t dxDesc, void *dx, + const void *alphaDgrid, const cudnnTensorDescriptor_t dyDesc, + const void *dy, const void *grid, const void *betaDgrid, void *dgrid) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfSamplerBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, stDesc, alpha, xDesc, x, beta, dxDesc, dx, alphaDgrid, dyDesc, dy, grid, betaDgrid, dgrid); + return func_ptr(handle, stDesc, alpha, xDesc, x, beta, dxDesc, dx, alphaDgrid, + dyDesc, dy, grid, betaDgrid, dgrid); } cudnnStatus_t CUDNNWINAPI cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t *dropoutDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc); @@ -1643,99 +1697,95 @@ cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t *dropoutDesc) { cudnnStatus_t CUDNNWINAPI cudnnDestroyDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnDropoutGetStatesSize(cudnnHandle_t handle, size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnDropoutGetStatesSize(cudnnHandle_t handle, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutGetStatesSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnDropoutGetReserveSpaceSize(cudnnTensorDescriptor_t xdesc, size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnDropoutGetReserveSpaceSize( + cudnnTensorDescriptor_t xdesc, size_t *sizeInBytes) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutGetReserveSpaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(xdesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnSetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float dropout, - void *states, - size_t stateSizeInBytes, - unsigned long long seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float, void *, size_t, unsigned long long); +cudnnStatus_t CUDNNWINAPI cudnnSetDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float dropout, + void *states, size_t stateSizeInBytes, unsigned long long seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float, void *, size_t, unsigned long long); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, stateSizeInBytes, seed); } -cudnnStatus_t CUDNNWINAPI -cudnnRestoreDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float dropout, - void *states, - size_t stateSizeInBytes, - unsigned long long seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float, void *, size_t, unsigned long long); +cudnnStatus_t CUDNNWINAPI cudnnRestoreDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float dropout, + void *states, size_t stateSizeInBytes, unsigned long long seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float, void *, size_t, unsigned long long); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRestoreDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, stateSizeInBytes, seed); } -cudnnStatus_t CUDNNWINAPI -cudnnGetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float *dropout, - void **states, - unsigned long long *seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float *, void **, unsigned long long *); +cudnnStatus_t CUDNNWINAPI cudnnGetDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float *dropout, + void **states, unsigned long long *seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float *, void **, unsigned long long *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, seed); } -cudnnStatus_t CUDNNWINAPI -cudnnDropoutForward(cudnnHandle_t handle, - const cudnnDropoutDescriptor_t dropoutDesc, - const cudnnTensorDescriptor_t xdesc, - const void *x, - const cudnnTensorDescriptor_t ydesc, - void *y, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnDropoutDescriptor_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnDropoutForward( + cudnnHandle_t handle, const cudnnDropoutDescriptor_t dropoutDesc, + const cudnnTensorDescriptor_t xdesc, const void *x, + const cudnnTensorDescriptor_t ydesc, void *y, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnDropoutDescriptor_t, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, dropoutDesc, xdesc, x, ydesc, y, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, dropoutDesc, xdesc, x, ydesc, y, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnDropoutBackward(cudnnHandle_t handle, - const cudnnDropoutDescriptor_t dropoutDesc, - const cudnnTensorDescriptor_t dydesc, - const void *dy, - const cudnnTensorDescriptor_t dxdesc, - void *dx, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnDropoutDescriptor_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnDropoutBackward( + cudnnHandle_t handle, const cudnnDropoutDescriptor_t dropoutDesc, + const cudnnTensorDescriptor_t dydesc, const void *dy, + const cudnnTensorDescriptor_t dxdesc, void *dx, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnDropoutDescriptor_t, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, dropoutDesc, dydesc, dy, dxdesc, dx, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, dropoutDesc, dydesc, dy, dxdesc, dx, reserveSpace, + reserveSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc); @@ -1743,184 +1793,192 @@ cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc) { cudnnStatus_t CUDNNWINAPI cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNForwardInferenceAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNForwardInferenceAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNForwardInferenceAlgorithmMaxCount( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNForwardInferenceAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindRNNForwardInferenceAlgorithmEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t *yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - void *workspace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNForwardInferenceAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNForwardInferenceAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, const float findIntensity, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnAlgorithmPerformance_t *perfResults, void *workspace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNForwardInferenceAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, + requestedAlgoCount, returnedAlgoCount, perfResults, workspace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNForwardTrainingAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNForwardTrainingAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNForwardTrainingAlgorithmMaxCount( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNForwardTrainingAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindRNNForwardTrainingAlgorithmEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t *yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - void *workspace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNForwardTrainingAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNForwardTrainingAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, const float findIntensity, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnAlgorithmPerformance_t *perfResults, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNForwardTrainingAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, + requestedAlgoCount, returnedAlgoCount, perfResults, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNBackwardDataAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNBackwardDataAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNBackwardDataAlgorithmMaxCount( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNBackwardDataAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindRNNBackwardDataAlgorithmEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *yDesc, - const void *y, - const cudnnTensorDescriptor_t *dyDesc, - const void *dy, - const cudnnTensorDescriptor_t dhyDesc, - const void *dhy, - const cudnnTensorDescriptor_t dcyDesc, - const void *dcy, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnTensorDescriptor_t *dxDesc, - void *dx, - const cudnnTensorDescriptor_t dhxDesc, - void *dhx, - const cudnnTensorDescriptor_t dcxDesc, - void *dcx, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - void *workspace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNBackwardDataAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNBackwardDataAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *yDesc, const void *y, + const cudnnTensorDescriptor_t *dyDesc, const void *dy, + const cudnnTensorDescriptor_t dhyDesc, const void *dhy, + const cudnnTensorDescriptor_t dcyDesc, const void *dcy, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnTensorDescriptor_t *dxDesc, void *dx, + const cudnnTensorDescriptor_t dhxDesc, void *dhx, + const cudnnTensorDescriptor_t dcxDesc, void *dcx, const float findIntensity, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnAlgorithmPerformance_t *perfResults, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNBackwardDataAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, + dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, + dx, dhxDesc, dhx, dcxDesc, dcx, findIntensity, + requestedAlgoCount, returnedAlgoCount, perfResults, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNBackwardWeightsAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNBackwardWeightsAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNBackwardWeightsAlgorithmMaxCount( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNBackwardWeightsAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindRNNBackwardWeightsAlgorithmEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t *yDesc, - const void *y, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - const void *workspace, - size_t workSpaceSizeInBytes, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - const void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, const void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, const void *, size_t, const cudnnFilterDescriptor_t, void *, const void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNBackwardWeightsAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNBackwardWeightsAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t *yDesc, const void *y, + const float findIntensity, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnAlgorithmPerformance_t *perfResults, + const void *workspace, size_t workSpaceSizeInBytes, + const cudnnFilterDescriptor_t dwDesc, void *dw, const void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, const void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, const void *, size_t, + const cudnnFilterDescriptor_t, void *, const void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNBackwardWeightsAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, + findIntensity, requestedAlgoCount, returnedAlgoCount, + perfResults, workspace, workSpaceSizeInBytes, dwDesc, dw, + reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnCreatePersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, - const int minibatch, - const cudnnDataType_t dataType, - cudnnPersistentRNNPlan_t *plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, const int, const cudnnDataType_t, cudnnPersistentRNNPlan_t *); +cudnnStatus_t CUDNNWINAPI cudnnCreatePersistentRNNPlan( + cudnnRNNDescriptor_t rnnDesc, const int minibatch, + const cudnnDataType_t dataType, cudnnPersistentRNNPlan_t *plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, const int, + const cudnnDataType_t, + cudnnPersistentRNNPlan_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreatePersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, minibatch, dataType, plan); } -cudnnStatus_t CUDNNWINAPI -cudnnSetPersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, cudnnPersistentRNNPlan_t plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnPersistentRNNPlan_t); +cudnnStatus_t CUDNNWINAPI cudnnSetPersistentRNNPlan( + cudnnRNNDescriptor_t rnnDesc, cudnnPersistentRNNPlan_t plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, + cudnnPersistentRNNPlan_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, plan); @@ -1928,289 +1986,285 @@ cudnnSetPersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, cudnnPersistentRNNPlan_t cudnnStatus_t CUDNNWINAPI cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPersistentRNNPlan_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPersistentRNNPlan_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyPersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(plan); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNDescriptor(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - const int hiddenSize, - const int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, /* Between layers, not between recurrent steps. */ - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnRNNAlgo_t algo, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, const int hiddenSize, + const int numLayers, + cudnnDropoutDescriptor_t + dropoutDesc, /* Between layers, not between recurrent steps. */ + cudnnRNNInputMode_t inputMode, cudnnDirectionMode_t direction, + cudnnRNNMode_t mode, cudnnRNNAlgo_t algo, cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, + cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, + cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, dataType); + return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, dataType); } cudnnStatus_t CUDNNWINAPI -cudnnSetRNNProjectionLayers(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - const int recProjSize, - const int outProjSize) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int); +cudnnSetRNNProjectionLayers(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + const int recProjSize, const int outProjSize) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNProjectionLayers"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, recProjSize, outProjSize); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNProjectionLayers(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - int *recProjSize, - int *outProjSize) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNProjectionLayers( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *recProjSize, + int *outProjSize) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNProjectionLayers"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, recProjSize, outProjSize); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNAlgorithmDescriptor(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, cudnnAlgorithmDescriptor_t algoDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, cudnnAlgorithmDescriptor_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNAlgorithmDescriptor( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + cudnnAlgorithmDescriptor_t algoDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, algoDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNDescriptor(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - int *hiddenSize, - int *numLayers, - cudnnDropoutDescriptor_t *dropoutDesc, - cudnnRNNInputMode_t *inputMode, - cudnnDirectionMode_t *direction, - cudnnRNNMode_t *mode, - cudnnRNNAlgo_t *algo, - cudnnDataType_t *dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, int *, int *, cudnnDropoutDescriptor_t *, cudnnRNNInputMode_t *, cudnnDirectionMode_t *, cudnnRNNMode_t *, cudnnRNNAlgo_t *, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNDescriptor( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, int *hiddenSize, + int *numLayers, cudnnDropoutDescriptor_t *dropoutDesc, + cudnnRNNInputMode_t *inputMode, cudnnDirectionMode_t *direction, + cudnnRNNMode_t *mode, cudnnRNNAlgo_t *algo, cudnnDataType_t *dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, int *, int *, + cudnnDropoutDescriptor_t *, cudnnRNNInputMode_t *, cudnnDirectionMode_t *, + cudnnRNNMode_t *, cudnnRNNAlgo_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, dataType); + return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, dataType); } cudnnStatus_t CUDNNWINAPI cudnnSetRNNMatrixMathType(cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t mType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t); + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNMatrixMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, mType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNMatrixMathType(cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t *mType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNMatrixMathType( + cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t *mType) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNMatrixMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, mType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNWorkspaceSize(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNWorkspaceSize( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNTrainingReserveSize(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNTrainingReserveSize( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNTrainingReserveSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes); } cudnnStatus_t CUDNNWINAPI -cudnnGetRNNParamsSize(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnTensorDescriptor_t xDesc, - size_t *sizeInBytes, +cudnnGetRNNParamsSize(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnTensorDescriptor_t xDesc, size_t *sizeInBytes, cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnTensorDescriptor_t, size_t *, cudnnDataType_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnTensorDescriptor_t, + size_t *, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNParamsSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, xDesc, sizeInBytes, dataType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNLinLayerMatrixParams(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int pseudoLayer, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const int linLayerID, - cudnnFilterDescriptor_t linLayerMatDesc, - void **linLayerMat) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const void *, const int, cudnnFilterDescriptor_t, void **); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerMatrixParams( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int pseudoLayer, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, const void *w, const int linLayerID, + cudnnFilterDescriptor_t linLayerMatDesc, void **linLayerMat) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, + const void *, const int, cudnnFilterDescriptor_t, void **); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNLinLayerMatrixParams"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, linLayerMatDesc, linLayerMat); + return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, + linLayerMatDesc, linLayerMat); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNLinLayerBiasParams(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int pseudoLayer, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const int linLayerID, - cudnnFilterDescriptor_t linLayerBiasDesc, - void **linLayerBias) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const void *, const int, cudnnFilterDescriptor_t, void **); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerBiasParams( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int pseudoLayer, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, const void *w, const int linLayerID, + cudnnFilterDescriptor_t linLayerBiasDesc, void **linLayerBias) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, + const void *, const int, cudnnFilterDescriptor_t, void **); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNLinLayerBiasParams"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, linLayerBiasDesc, linLayerBias); + return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, + linLayerBiasDesc, linLayerBias); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNForwardInference(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t *yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - void *workspace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardInference( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, void *workspace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardInference"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, workSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNForwardTraining(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t *yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - void *workspace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardTraining( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardTraining"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI -cudnnRNNBackwardData(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *yDesc, - const void *y, - const cudnnTensorDescriptor_t *dyDesc, - const void *dy, - const cudnnTensorDescriptor_t dhyDesc, - const void *dhy, - const cudnnTensorDescriptor_t dcyDesc, - const void *dcy, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnTensorDescriptor_t *dxDesc, - void *dx, - const cudnnTensorDescriptor_t dhxDesc, - void *dhx, - const cudnnTensorDescriptor_t dcxDesc, - void *dcx, - void *workspace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnRNNBackwardData(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *yDesc, + const void *y, const cudnnTensorDescriptor_t *dyDesc, + const void *dy, const cudnnTensorDescriptor_t dhyDesc, + const void *dhy, const cudnnTensorDescriptor_t dcyDesc, + const void *dcy, const cudnnFilterDescriptor_t wDesc, + const void *w, const cudnnTensorDescriptor_t hxDesc, + const void *hx, const cudnnTensorDescriptor_t cxDesc, + const void *cx, const cudnnTensorDescriptor_t *dxDesc, + void *dx, const cudnnTensorDescriptor_t dhxDesc, void *dhx, + const cudnnTensorDescriptor_t dcxDesc, void *dcx, + void *workspace, size_t workSpaceSizeInBytes, + void *reserveSpace, size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardData"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, + dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, + dx, dhxDesc, dhx, dcxDesc, dcx, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNBackwardWeights(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t *yDesc, - const void *y, - const void *workspace, - size_t workSpaceSizeInBytes, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - const void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, const void *, const void *, size_t, const cudnnFilterDescriptor_t, void *, const void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeights( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t *yDesc, const void *y, const void *workspace, + size_t workSpaceSizeInBytes, const cudnnFilterDescriptor_t dwDesc, void *dw, + const void *reserveSpace, size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, const void *, const void *, size_t, + const cudnnFilterDescriptor_t, void *, const void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardWeights"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, + workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, + reserveSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnCreateCTCLossDescriptor(cudnnCTCLossDescriptor_t *ctcLossDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetCTCLossDescriptor( + cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc, compType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetCTCLossDescriptor( + cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc, compType); @@ -2218,82 +2272,102 @@ cudnnGetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t cudnnStatus_t CUDNNWINAPI cudnnDestroyCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnCTCLoss( +cudnnStatus_t CUDNNWINAPI cudnnCTCLoss( cudnnHandle_t handle, const cudnnTensorDescriptor_t - probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the - mini batch size, A is the alphabet size) */ - const void *probs, /* probabilities after softmax, in GPU memory */ - const int *labels, /* labels, in CPU memory */ - const int *labelLengths, /* the length of each label, in CPU memory */ - const int *inputLengths, /* the lengths of timing steps in each batch, in CPU memory */ - void *costs, /* the returned costs of CTC, in GPU memory */ - const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */ - const void *gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */ + probsDesc, /* Tensor descriptor for probabilities, the dimensions are + T,N,A (T is the timing steps, N is the + mini batch size, A is the alphabet size) */ + const void *probs, /* probabilities after softmax, in GPU memory */ + const int *labels, /* labels, in CPU memory */ + const int *labelLengths, /* the length of each label, in CPU memory */ + const int *inputLengths, /* the lengths of timing steps in each batch, in + CPU memory */ + void *costs, /* the returned costs of CTC, in GPU memory */ + const cudnnTensorDescriptor_t + gradientsDesc, /* Tensor descriptor for gradients, the dimensions are + T,N,A */ + const void *gradients, /* the returned CTC gradients, in GPU memory, to + compute costs only, set it to NULL */ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ cudnnCTCLossDescriptor_t ctcLossDesc, - void *workspace, /* pointer to the workspace, in GPU memory */ + void *workspace, /* pointer to the workspace, in GPU memory */ size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const int *, const int *, const int *, void *, const cudnnTensorDescriptor_t, const void *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, void *, size_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const int *, + const int *, const int *, void *, const cudnnTensorDescriptor_t, + const void *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCTCLoss"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, probsDesc, probs, labels, labelLengths, inputLengths, costs, gradientsDesc, gradients, algo, ctcLossDesc, workspace, workSpaceSizeInBytes); + return func_ptr(handle, probsDesc, probs, labels, labelLengths, inputLengths, + costs, gradientsDesc, gradients, algo, ctcLossDesc, workspace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetCTCLossWorkspaceSize( +cudnnStatus_t CUDNNWINAPI cudnnGetCTCLossWorkspaceSize( cudnnHandle_t handle, - const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the - timing steps, N is the mini batch size, A is the alphabet size) */ - const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the - dimensions are T,N,A. To compute costs - only, set it to NULL */ - const int *labels, /* labels, in CPU memory */ - const int *labelLengths, /* the length of each label, in CPU memory */ - const int *inputLengths, /* the lengths of timing steps in each batch, in CPU memory */ - cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ - cudnnCTCLossDescriptor_t ctcLossDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const int *, const int *, const int *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, size_t *); + const cudnnTensorDescriptor_t + probsDesc, /* Tensor descriptor for probabilities, the dimensions are + T,N,A (T is the + timing steps, N is the mini batch size, A is the alphabet + size) */ + const cudnnTensorDescriptor_t + gradientsDesc, /* Tensor descriptor for gradients, the + dimensions are T,N,A. To compute costs + only, set it to NULL */ + const int *labels, /* labels, in CPU memory */ + const int *labelLengths, /* the length of each label, in CPU memory */ + const int *inputLengths, /* the lengths of timing steps in each batch, in + CPU memory */ + cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ + cudnnCTCLossDescriptor_t ctcLossDesc, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const int *, const int *, const int *, + cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCTCLossWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, probsDesc, gradientsDesc, labels, labelLengths, inputLengths, algo, ctcLossDesc, sizeInBytes); + return func_ptr(handle, probsDesc, gradientsDesc, labels, labelLengths, + inputLengths, algo, ctcLossDesc, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnCreateAlgorithmDescriptor(cudnnAlgorithmDescriptor_t *algoDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetAlgorithmDescriptor(cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t algorithm) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t, cudnnAlgorithm_t); +cudnnStatus_t CUDNNWINAPI cudnnSetAlgorithmDescriptor( + cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t algorithm) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t, + cudnnAlgorithm_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc, algorithm); } -cudnnStatus_t CUDNNWINAPI -cudnnGetAlgorithmDescriptor(const cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t *algorithm) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, cudnnAlgorithm_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetAlgorithmDescriptor( + const cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t *algorithm) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, + cudnnAlgorithm_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc, algorithm); } -cudnnStatus_t CUDNNWINAPI -cudnnCopyAlgorithmDescriptor(const cudnnAlgorithmDescriptor_t src, cudnnAlgorithmDescriptor_t dest) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, cudnnAlgorithmDescriptor_t); +cudnnStatus_t CUDNNWINAPI cudnnCopyAlgorithmDescriptor( + const cudnnAlgorithmDescriptor_t src, cudnnAlgorithmDescriptor_t dest) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, + cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCopyAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(src, dest); @@ -2301,135 +2375,141 @@ cudnnCopyAlgorithmDescriptor(const cudnnAlgorithmDescriptor_t src, cudnnAlgorith cudnnStatus_t CUDNNWINAPI cudnnDestroyAlgorithmDescriptor(cudnnAlgorithmDescriptor_t algoDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnCreateAlgorithmPerformance(cudnnAlgorithmPerformance_t *algoPerf, int numberToCreate) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); +cudnnStatus_t CUDNNWINAPI cudnnCreateAlgorithmPerformance( + cudnnAlgorithmPerformance_t *algoPerf, int numberToCreate) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, numberToCreate); } -cudnnStatus_t CUDNNWINAPI -cudnnSetAlgorithmPerformance(cudnnAlgorithmPerformance_t algoPerf, - cudnnAlgorithmDescriptor_t algoDesc, - cudnnStatus_t status, - float time, - size_t memory) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmPerformance_t, cudnnAlgorithmDescriptor_t, cudnnStatus_t, float, size_t); +cudnnStatus_t CUDNNWINAPI cudnnSetAlgorithmPerformance( + cudnnAlgorithmPerformance_t algoPerf, cudnnAlgorithmDescriptor_t algoDesc, + cudnnStatus_t status, float time, size_t memory) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmPerformance_t, + cudnnAlgorithmDescriptor_t, + cudnnStatus_t, float, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, algoDesc, status, time, memory); } -cudnnStatus_t CUDNNWINAPI -cudnnGetAlgorithmPerformance(const cudnnAlgorithmPerformance_t algoPerf, - cudnnAlgorithmDescriptor_t *algoDesc, - cudnnStatus_t *status, - float *time, - size_t *memory) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnAlgorithmPerformance_t, cudnnAlgorithmDescriptor_t *, cudnnStatus_t *, float *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetAlgorithmPerformance( + const cudnnAlgorithmPerformance_t algoPerf, + cudnnAlgorithmDescriptor_t *algoDesc, cudnnStatus_t *status, float *time, + size_t *memory) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnAlgorithmPerformance_t, cudnnAlgorithmDescriptor_t *, + cudnnStatus_t *, float *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, algoDesc, status, time, memory); } -cudnnStatus_t CUDNNWINAPI -cudnnDestroyAlgorithmPerformance(cudnnAlgorithmPerformance_t *algoPerf, int numberToDestroy) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyAlgorithmPerformance"); +cudnnStatus_t CUDNNWINAPI cudnnDestroyAlgorithmPerformance( + cudnnAlgorithmPerformance_t *algoPerf, int numberToDestroy) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, numberToDestroy); } -cudnnStatus_t CUDNNWINAPI -cudnnGetAlgorithmSpaceSize(cudnnHandle_t handle, cudnnAlgorithmDescriptor_t algoDesc, size_t *algoSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnAlgorithmDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetAlgorithmSpaceSize( + cudnnHandle_t handle, cudnnAlgorithmDescriptor_t algoDesc, + size_t *algoSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnAlgorithmDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAlgorithmSpaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algoDesc, algoSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI -cudnnSaveAlgorithm(cudnnHandle_t handle, - cudnnAlgorithmDescriptor_t algoDesc, - void *algoSpace, - size_t algoSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnAlgorithmDescriptor_t, void *, size_t); +cudnnSaveAlgorithm(cudnnHandle_t handle, cudnnAlgorithmDescriptor_t algoDesc, + void *algoSpace, size_t algoSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnAlgorithmDescriptor_t, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSaveAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algoDesc, algoSpace, algoSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRestoreAlgorithm(cudnnHandle_t handle, - void *algoSpace, - size_t algoSpaceSizeInBytes, - cudnnAlgorithmDescriptor_t algoDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, void *, size_t, cudnnAlgorithmDescriptor_t); +cudnnStatus_t CUDNNWINAPI cudnnRestoreAlgorithm( + cudnnHandle_t handle, void *algoSpace, size_t algoSpaceSizeInBytes, + cudnnAlgorithmDescriptor_t algoDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, void *, size_t, + cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRestoreAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algoSpace, algoSpaceSizeInBytes, algoDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNSetClip(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - cudnnRNNClipMode_t clipMode, - cudnnNanPropagation_t clipNanOpt, - double lclip, - double rclip) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, cudnnRNNClipMode_t, cudnnNanPropagation_t, double, double); +cudnnStatus_t CUDNNWINAPI cudnnRNNSetClip(cudnnHandle_t handle, + cudnnRNNDescriptor_t rnnDesc, + cudnnRNNClipMode_t clipMode, + cudnnNanPropagation_t clipNanOpt, + double lclip, double rclip) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnRNNClipMode_t, + cudnnNanPropagation_t, double, double); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNSetClip"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, clipMode, clipNanOpt, lclip, rclip); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNGetClip(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - cudnnRNNClipMode_t *clipMode, - cudnnNanPropagation_t *clipNanOpt, - double *lclip, - double *rclip) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, cudnnRNNClipMode_t *, cudnnNanPropagation_t *, double *, double *); +cudnnStatus_t CUDNNWINAPI cudnnRNNGetClip(cudnnHandle_t handle, + cudnnRNNDescriptor_t rnnDesc, + cudnnRNNClipMode_t *clipMode, + cudnnNanPropagation_t *clipNanOpt, + double *lclip, double *rclip) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnRNNClipMode_t *, + cudnnNanPropagation_t *, double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNGetClip"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, clipMode, clipNanOpt, lclip, rclip); } -cudnnStatus_t CUDNNWINAPI -cudnnSetCallback(unsigned mask, void *udata, cudnnCallback_t fptr) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(unsigned int, void *, cudnnCallback_t); +cudnnStatus_t CUDNNWINAPI cudnnSetCallback(unsigned mask, void *udata, + cudnnCallback_t fptr) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(unsigned int, void *, cudnnCallback_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetCallback"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(mask, udata, fptr); } -cudnnStatus_t CUDNNWINAPI -cudnnGetCallback(unsigned *mask, void **udata, cudnnCallback_t *fptr) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(unsigned int *, void **, cudnnCallback_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetCallback(unsigned *mask, void **udata, + cudnnCallback_t *fptr) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(unsigned int *, void **, cudnnCallback_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCallback"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(mask, udata, fptr); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNPaddingMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t paddingMode) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNPaddingMode_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNPaddingMode( + cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t paddingMode) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNPaddingMode_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNPaddingMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, paddingMode); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNPaddingMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t *paddingMode) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNPaddingMode_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNPaddingMode( + cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t *paddingMode) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, + cudnnRNNPaddingMode_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNPaddingMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, paddingMode); @@ -2437,7 +2517,7 @@ cudnnGetRNNPaddingMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t *padd cudnnStatus_t CUDNNWINAPI cudnnCreateRNNDataDescriptor(cudnnRNNDataDescriptor_t *RNNDataDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDataDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateRNNDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(RNNDataDesc); @@ -2445,199 +2525,202 @@ cudnnCreateRNNDataDescriptor(cudnnRNNDataDescriptor_t *RNNDataDesc) { cudnnStatus_t CUDNNWINAPI cudnnDestroyRNNDataDescriptor(cudnnRNNDataDescriptor_t RNNDataDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDataDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyRNNDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(RNNDataDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNDataDescriptor(cudnnRNNDataDescriptor_t RNNDataDesc, - cudnnDataType_t dataType, - cudnnRNNDataLayout_t layout, - int maxSeqLength, - int batchSize, - int vectorSize, - const int seqLengthArray[], /* length of each sequence in the batch */ - void *paddingFill) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t, cudnnDataType_t, cudnnRNNDataLayout_t, int, int, int, const int [], void *); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDataDescriptor( + cudnnRNNDataDescriptor_t RNNDataDesc, cudnnDataType_t dataType, + cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, + int vectorSize, + const int seqLengthArray[], /* length of each sequence in the batch */ + void *paddingFill) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnRNNDataDescriptor_t, cudnnDataType_t, cudnnRNNDataLayout_t, int, int, + int, const int[], void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(RNNDataDesc, dataType, layout, maxSeqLength, batchSize, vectorSize, seqLengthArray, paddingFill); + return func_ptr(RNNDataDesc, dataType, layout, maxSeqLength, batchSize, + vectorSize, seqLengthArray, paddingFill); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNDataDescriptor(cudnnRNNDataDescriptor_t RNNDataDesc, - cudnnDataType_t *dataType, - cudnnRNNDataLayout_t *layout, - int *maxSeqLength, - int *batchSize, - int *vectorSize, - int arrayLengthRequested, - int seqLengthArray[], - void *paddingFill) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t, cudnnDataType_t *, cudnnRNNDataLayout_t *, int *, int *, int *, int, int [], void *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNDataDescriptor( + cudnnRNNDataDescriptor_t RNNDataDesc, cudnnDataType_t *dataType, + cudnnRNNDataLayout_t *layout, int *maxSeqLength, int *batchSize, + int *vectorSize, int arrayLengthRequested, int seqLengthArray[], + void *paddingFill) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnRNNDataDescriptor_t, cudnnDataType_t *, cudnnRNNDataLayout_t *, + int *, int *, int *, int, int[], void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(RNNDataDesc, dataType, layout, maxSeqLength, batchSize, vectorSize, arrayLengthRequested, seqLengthArray, paddingFill); + return func_ptr(RNNDataDesc, dataType, layout, maxSeqLength, batchSize, + vectorSize, arrayLengthRequested, seqLengthArray, + paddingFill); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNForwardTrainingEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnRNNDataDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnRNNDataDescriptor_t yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */ - const void *keys, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */ - void *cAttn, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */ - void *iAttn, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */ - void *queries, /* reserved, should pass NULL */ - void *workSpace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardTrainingEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDataDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnRNNDataDescriptor_t yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, + const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */ + const void *keys, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */ + void *cAttn, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */ + void *iAttn, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */ + void *queries, /* reserved, should pass NULL */ + void *workSpace, size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, + const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, + const cudnnRNNDataDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, + void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardTrainingEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, kDesc, keys, cDesc, cAttn, iDesc, iAttn, qDesc, queries, workSpace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, + yDesc, y, hyDesc, hy, cyDesc, cy, kDesc, keys, cDesc, cAttn, + iDesc, iAttn, qDesc, queries, workSpace, workSpaceSizeInBytes, + reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNForwardInferenceEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnRNNDataDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnRNNDataDescriptor_t yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */ - const void *keys, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */ - void *cAttn, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */ - void *iAttn, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */ - void *queries, /* reserved, should pass NULL */ - void *workSpace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardInferenceEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDataDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnRNNDataDescriptor_t yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, + const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */ + const void *keys, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */ + void *cAttn, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */ + void *iAttn, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */ + void *queries, /* reserved, should pass NULL */ + void *workSpace, size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, + const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, + const cudnnRNNDataDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, + void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardInferenceEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, kDesc, keys, cDesc, cAttn, iDesc, iAttn, qDesc, queries, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, + yDesc, y, hyDesc, hy, cyDesc, cy, kDesc, keys, cDesc, cAttn, + iDesc, iAttn, qDesc, queries, workSpace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNBackwardDataEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnRNNDataDescriptor_t yDesc, - const void *y, - const cudnnRNNDataDescriptor_t dyDesc, - const void *dy, - const cudnnRNNDataDescriptor_t dcDesc, /* reserved, should pass NULL */ - const void *dcAttn, /* reserved, should pass NULL */ - const cudnnTensorDescriptor_t dhyDesc, - const void *dhy, - const cudnnTensorDescriptor_t dcyDesc, - const void *dcy, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnRNNDataDescriptor_t dxDesc, - void *dx, - const cudnnTensorDescriptor_t dhxDesc, - void *dhx, - const cudnnTensorDescriptor_t dcxDesc, - void *dcx, - const cudnnRNNDataDescriptor_t dkDesc, /* reserved, should pass NULL */ - void *dkeys, /* reserved, should pass NULL */ - void *workSpace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardDataEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDataDescriptor_t yDesc, const void *y, + const cudnnRNNDataDescriptor_t dyDesc, const void *dy, + const cudnnRNNDataDescriptor_t dcDesc, /* reserved, should pass NULL */ + const void *dcAttn, /* reserved, should pass NULL */ + const cudnnTensorDescriptor_t dhyDesc, const void *dhy, + const cudnnTensorDescriptor_t dcyDesc, const void *dcy, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnRNNDataDescriptor_t dxDesc, void *dx, + const cudnnTensorDescriptor_t dhxDesc, void *dhx, + const cudnnTensorDescriptor_t dcxDesc, void *dcx, + const cudnnRNNDataDescriptor_t dkDesc, /* reserved, should pass NULL */ + void *dkeys, /* reserved, should pass NULL */ + void *workSpace, size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, + const void *, const cudnnRNNDataDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, + const cudnnRNNDataDescriptor_t, void *, void *, size_t, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardDataEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, yDesc, y, dyDesc, dy, dcDesc, dcAttn, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, dkDesc, dkeys, workSpace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, yDesc, y, dyDesc, dy, dcDesc, dcAttn, + dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, + dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, dkDesc, dkeys, + workSpace, workSpaceSizeInBytes, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNBackwardWeightsEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnRNNDataDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnRNNDataDescriptor_t yDesc, - const void *y, - void *workSpace, - size_t workSpaceSizeInBytes, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, const void *, void *, size_t, const cudnnFilterDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeightsEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDataDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnRNNDataDescriptor_t yDesc, const void *y, void *workSpace, + size_t workSpaceSizeInBytes, const cudnnFilterDescriptor_t dwDesc, void *dw, + void *reserveSpace, size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, + const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, const void *, void *, size_t, + const cudnnFilterDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardWeightsEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, yDesc, y, workSpace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, yDesc, y, workSpace, + workSpaceSizeInBytes, dwDesc, dw, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNDescriptor_v6(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - const int hiddenSize, - const int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnRNNAlgo_t algo, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v6( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, const int hiddenSize, + const int numLayers, cudnnDropoutDescriptor_t dropoutDesc, + cudnnRNNInputMode_t inputMode, cudnnDirectionMode_t direction, + cudnnRNNMode_t mode, cudnnRNNAlgo_t algo, cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, + cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, + cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor_v6"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, dataType); + return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, dataType); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNDescriptor_v5(cudnnRNNDescriptor_t rnnDesc, - int hiddenSize, - int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, int, int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v5( + cudnnRNNDescriptor_t rnnDesc, int hiddenSize, int numLayers, + cudnnDropoutDescriptor_t dropoutDesc, cudnnRNNInputMode_t inputMode, + cudnnDirectionMode_t direction, cudnnRNNMode_t mode, + cudnnDataType_t dataType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnRNNDescriptor_t, int, int, cudnnDropoutDescriptor_t, + cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, + cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor_v5"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, dataType); + return func_ptr(rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, + direction, mode, dataType); } } // extern "C" diff --git a/tensorflow/stream_executor/cuda/cudnn_7_6.inc b/tensorflow/stream_executor/cuda/cudnn_7_6.inc index 7a5f1c9751d..9dd420a9022 100644 --- a/tensorflow/stream_executor/cuda/cudnn_7_6.inc +++ b/tensorflow/stream_executor/cuda/cudnn_7_6.inc @@ -2,73 +2,71 @@ extern "C" { -size_t CUDNNWINAPI -cudnnGetVersion(void) { - using FuncPtr = size_t (CUDNNWINAPI *)(); +size_t CUDNNWINAPI cudnnGetVersion(void) { + using FuncPtr = size_t(CUDNNWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetVersion"); if (!func_ptr) return 0; return func_ptr(); } -size_t CUDNNWINAPI -cudnnGetCudartVersion(void) { - using FuncPtr = size_t (CUDNNWINAPI *)(); +size_t CUDNNWINAPI cudnnGetCudartVersion(void) { + using FuncPtr = size_t(CUDNNWINAPI *)(); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCudartVersion"); if (!func_ptr) return 0; return func_ptr(); } -const char *CUDNNWINAPI -cudnnGetErrorString(cudnnStatus_t status) { - using FuncPtr = const char * (CUDNNWINAPI *)(cudnnStatus_t); +const char *CUDNNWINAPI cudnnGetErrorString(cudnnStatus_t status) { + using FuncPtr = const char *(CUDNNWINAPI *)(cudnnStatus_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetErrorString"); if (!func_ptr) return "cudnnGetErrorString symbol not found."; return func_ptr(status); } -cudnnStatus_t CUDNNWINAPI -cudnnQueryRuntimeError(cudnnHandle_t handle, cudnnStatus_t *rstatus, cudnnErrQueryMode_t mode, cudnnRuntimeTag_t *tag) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnStatus_t *, cudnnErrQueryMode_t, cudnnRuntimeTag_t *); +cudnnStatus_t CUDNNWINAPI cudnnQueryRuntimeError(cudnnHandle_t handle, + cudnnStatus_t *rstatus, + cudnnErrQueryMode_t mode, + cudnnRuntimeTag_t *tag) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnStatus_t *, cudnnErrQueryMode_t, cudnnRuntimeTag_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnQueryRuntimeError"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rstatus, mode, tag); } -cudnnStatus_t CUDNNWINAPI -cudnnGetProperty(libraryPropertyType type, int *value) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(libraryPropertyType, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetProperty(libraryPropertyType type, + int *value) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(libraryPropertyType, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetProperty"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(type, value); } -cudnnStatus_t CUDNNWINAPI -cudnnCreate(cudnnHandle_t *handle) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t *); +cudnnStatus_t CUDNNWINAPI cudnnCreate(cudnnHandle_t *handle) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreate"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cudnnStatus_t CUDNNWINAPI -cudnnDestroy(cudnnHandle_t handle) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t); +cudnnStatus_t CUDNNWINAPI cudnnDestroy(cudnnHandle_t handle) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroy"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle); } -cudnnStatus_t CUDNNWINAPI -cudnnSetStream(cudnnHandle_t handle, cudaStream_t streamId) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t); +cudnnStatus_t CUDNNWINAPI cudnnSetStream(cudnnHandle_t handle, + cudaStream_t streamId) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetStream"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); } -cudnnStatus_t CUDNNWINAPI -cudnnGetStream(cudnnHandle_t handle, cudaStream_t *streamId) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetStream(cudnnHandle_t handle, + cudaStream_t *streamId) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, cudaStream_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetStream"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, streamId); @@ -76,100 +74,97 @@ cudnnGetStream(cudnnHandle_t handle, cudaStream_t *streamId) { cudnnStatus_t CUDNNWINAPI cudnnCreateTensorDescriptor(cudnnTensorDescriptor_t *tensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensor4dDescriptor(cudnnTensorDescriptor_t tensorDesc, - cudnnTensorFormat_t format, - cudnnDataType_t dataType, /* image data type */ - int n, /* number of inputs (batch size) */ - int c, /* number of input feature maps */ - int h, /* height of input section */ - int w) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, cudnnDataType_t, int, int, int, int); +cudnnStatus_t CUDNNWINAPI cudnnSetTensor4dDescriptor( + cudnnTensorDescriptor_t tensorDesc, cudnnTensorFormat_t format, + cudnnDataType_t dataType, /* image data type */ + int n, /* number of inputs (batch size) */ + int c, /* number of input feature maps */ + int h, /* height of input section */ + int w) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, + cudnnDataType_t, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, format, dataType, n, c, h, w); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensor4dDescriptorEx(cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t dataType, /* image data type */ - int n, /* number of inputs (batch size) */ - int c, /* number of input feature maps */ - int h, /* height of input section */ - int w, /* width of input section */ - int nStride, - int cStride, - int hStride, - int wStride) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, int, int, int, int, int, int, int, int); +cudnnStatus_t CUDNNWINAPI cudnnSetTensor4dDescriptorEx( + cudnnTensorDescriptor_t tensorDesc, + cudnnDataType_t dataType, /* image data type */ + int n, /* number of inputs (batch size) */ + int c, /* number of input feature maps */ + int h, /* height of input section */ + int w, /* width of input section */ + int nStride, int cStride, int hStride, int wStride) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, + int, int, int, int, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor4dDescriptorEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, wStride); + return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, + wStride); } -cudnnStatus_t CUDNNWINAPI -cudnnGetTensor4dDescriptor(const cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t *dataType, /* image data type */ - int *n, /* number of inputs (batch size) */ - int *c, /* number of input feature maps */ - int *h, /* height of input section */ - int *w, /* width of input section */ - int *nStride, - int *cStride, - int *hStride, - int *wStride) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, cudnnDataType_t *, int *, int *, int *, int *, int *, int *, int *, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetTensor4dDescriptor( + const cudnnTensorDescriptor_t tensorDesc, + cudnnDataType_t *dataType, /* image data type */ + int *n, /* number of inputs (batch size) */ + int *c, /* number of input feature maps */ + int *h, /* height of input section */ + int *w, /* width of input section */ + int *nStride, int *cStride, int *hStride, int *wStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnTensorDescriptor_t, cudnnDataType_t *, int *, int *, int *, + int *, int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensor4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, wStride); + return func_ptr(tensorDesc, dataType, n, c, h, w, nStride, cStride, hStride, + wStride); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensorNdDescriptor(cudnnTensorDescriptor_t tensorDesc, - cudnnDataType_t dataType, - int nbDims, - const int dimA[], - const int strideA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnDataType_t, int, const int [], const int []); +cudnnStatus_t CUDNNWINAPI cudnnSetTensorNdDescriptor( + cudnnTensorDescriptor_t tensorDesc, cudnnDataType_t dataType, int nbDims, + const int dimA[], const int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnTensorDescriptor_t, cudnnDataType_t, int, const int[], const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, dataType, nbDims, dimA, strideA); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensorNdDescriptorEx(cudnnTensorDescriptor_t tensorDesc, - cudnnTensorFormat_t format, - cudnnDataType_t dataType, - int nbDims, - const int dimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, cudnnDataType_t, int, const int []); +cudnnStatus_t CUDNNWINAPI cudnnSetTensorNdDescriptorEx( + cudnnTensorDescriptor_t tensorDesc, cudnnTensorFormat_t format, + cudnnDataType_t dataType, int nbDims, const int dimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, cudnnTensorFormat_t, + cudnnDataType_t, int, const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorNdDescriptorEx"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, format, dataType, nbDims, dimA); } -cudnnStatus_t CUDNNWINAPI -cudnnGetTensorNdDescriptor(const cudnnTensorDescriptor_t tensorDesc, - int nbDimsRequested, - cudnnDataType_t *dataType, - int *nbDims, - int dimA[], - int strideA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, int, cudnnDataType_t *, int *, int [], int []); +cudnnStatus_t CUDNNWINAPI cudnnGetTensorNdDescriptor( + const cudnnTensorDescriptor_t tensorDesc, int nbDimsRequested, + cudnnDataType_t *dataType, int *nbDims, int dimA[], int strideA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnTensorDescriptor_t, int, + cudnnDataType_t *, int *, int[], int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, nbDimsRequested, dataType, nbDims, dimA, strideA); } -cudnnStatus_t CUDNNWINAPI -cudnnGetTensorSizeInBytes(const cudnnTensorDescriptor_t tensorDesc, size_t *size) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetTensorSizeInBytes( + const cudnnTensorDescriptor_t tensorDesc, size_t *size) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorSizeInBytes"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc, size); @@ -177,126 +172,141 @@ cudnnGetTensorSizeInBytes(const cudnnTensorDescriptor_t tensorDesc, size_t *size cudnnStatus_t CUDNNWINAPI cudnnDestroyTensorDescriptor(cudnnTensorDescriptor_t tensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(tensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnInitTransformDest(const cudnnTensorTransformDescriptor_t transformDesc, - const cudnnTensorDescriptor_t srcDesc, - cudnnTensorDescriptor_t destDesc, - size_t *destSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnTensorTransformDescriptor_t, const cudnnTensorDescriptor_t, cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnInitTransformDest( + const cudnnTensorTransformDescriptor_t transformDesc, + const cudnnTensorDescriptor_t srcDesc, cudnnTensorDescriptor_t destDesc, + size_t *destSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnTensorTransformDescriptor_t, const cudnnTensorDescriptor_t, + cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnInitTransformDest"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(transformDesc, srcDesc, destDesc, destSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnCreateTensorTransformDescriptor(cudnnTensorTransformDescriptor_t *transformDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorTransformDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateTensorTransformDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnCreateTensorTransformDescriptor( + cudnnTensorTransformDescriptor_t *transformDesc) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorTransformDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateTensorTransformDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(transformDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc, - const uint32_t nbDims, - const cudnnTensorFormat_t destFormat, - const int32_t padBeforeA[], - const int32_t padAfterA[], - const uint32_t foldA[], - const cudnnFoldingDirection_t direction) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorTransformDescriptor_t, const uint32_t, const cudnnTensorFormat_t, const int32_t [], const int32_t [], const uint32_t [], const cudnnFoldingDirection_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensorTransformDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnSetTensorTransformDescriptor( + cudnnTensorTransformDescriptor_t transformDesc, const uint32_t nbDims, + const cudnnTensorFormat_t destFormat, const int32_t padBeforeA[], + const int32_t padAfterA[], const uint32_t foldA[], + const cudnnFoldingDirection_t direction) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnTensorTransformDescriptor_t, const uint32_t, + const cudnnTensorFormat_t, const int32_t[], const int32_t[], + const uint32_t[], const cudnnFoldingDirection_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSetTensorTransformDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(transformDesc, nbDims, destFormat, padBeforeA, padAfterA, foldA, direction); + return func_ptr(transformDesc, nbDims, destFormat, padBeforeA, padAfterA, + foldA, direction); } -cudnnStatus_t CUDNNWINAPI -cudnnGetTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc, - uint32_t nbDimsRequested, - cudnnTensorFormat_t *destFormat, - int32_t padBeforeA[], - int32_t padAfterA[], - uint32_t foldA[], - cudnnFoldingDirection_t *direction) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorTransformDescriptor_t, uint32_t, cudnnTensorFormat_t *, int32_t [], int32_t [], uint32_t [], cudnnFoldingDirection_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetTensorTransformDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnGetTensorTransformDescriptor( + cudnnTensorTransformDescriptor_t transformDesc, uint32_t nbDimsRequested, + cudnnTensorFormat_t *destFormat, int32_t padBeforeA[], int32_t padAfterA[], + uint32_t foldA[], cudnnFoldingDirection_t *direction) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnTensorTransformDescriptor_t, uint32_t, cudnnTensorFormat_t *, + int32_t[], int32_t[], uint32_t[], cudnnFoldingDirection_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetTensorTransformDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(transformDesc, nbDimsRequested, destFormat, padBeforeA, padAfterA, foldA, direction); + return func_ptr(transformDesc, nbDimsRequested, destFormat, padBeforeA, + padAfterA, foldA, direction); } -cudnnStatus_t CUDNNWINAPI -cudnnDestroyTensorTransformDescriptor(cudnnTensorTransformDescriptor_t transformDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorTransformDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyTensorTransformDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnDestroyTensorTransformDescriptor( + cudnnTensorTransformDescriptor_t transformDesc) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorTransformDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyTensorTransformDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(transformDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnTransformTensor(cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnTransformTensor( + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnTransformTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnTransformTensorEx(cudnnHandle_t handle, - const cudnnTensorTransformDescriptor_t transDesc, - const void *alpha, - const cudnnTensorDescriptor_t srcDesc, - const void *srcData, - const void *beta, - const cudnnTensorDescriptor_t destDesc, - void *destData) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorTransformDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnTransformTensorEx( + cudnnHandle_t handle, const cudnnTensorTransformDescriptor_t transDesc, + const void *alpha, const cudnnTensorDescriptor_t srcDesc, + const void *srcData, const void *beta, + const cudnnTensorDescriptor_t destDesc, void *destData) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorTransformDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnTransformTensorEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transDesc, alpha, srcDesc, srcData, beta, destDesc, destData); + return func_ptr(handle, transDesc, alpha, srcDesc, srcData, beta, destDesc, + destData); } -cudnnStatus_t CUDNNWINAPI -cudnnGetFoldedConvBackwardDataDescriptors(const cudnnHandle_t handle, - const cudnnFilterDescriptor_t filterDesc, - const cudnnTensorDescriptor_t diffDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t gradDesc, - const cudnnTensorFormat_t transformFormat, - cudnnFilterDescriptor_t foldedFilterDesc, - cudnnTensorDescriptor_t paddedDiffDesc, - cudnnConvolutionDescriptor_t foldedConvDesc, - cudnnTensorDescriptor_t foldedGradDesc, - cudnnTensorTransformDescriptor_t filterFoldTransDesc, - cudnnTensorTransformDescriptor_t diffPadTransDesc, - cudnnTensorTransformDescriptor_t gradFoldTransDesc, - cudnnTensorTransformDescriptor_t gradUnfoldTransDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorFormat_t, cudnnFilterDescriptor_t, cudnnTensorDescriptor_t, cudnnConvolutionDescriptor_t, cudnnTensorDescriptor_t, cudnnTensorTransformDescriptor_t, cudnnTensorTransformDescriptor_t, cudnnTensorTransformDescriptor_t, cudnnTensorTransformDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFoldedConvBackwardDataDescriptors"); +cudnnStatus_t CUDNNWINAPI cudnnGetFoldedConvBackwardDataDescriptors( + const cudnnHandle_t handle, const cudnnFilterDescriptor_t filterDesc, + const cudnnTensorDescriptor_t diffDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t gradDesc, + const cudnnTensorFormat_t transformFormat, + cudnnFilterDescriptor_t foldedFilterDesc, + cudnnTensorDescriptor_t paddedDiffDesc, + cudnnConvolutionDescriptor_t foldedConvDesc, + cudnnTensorDescriptor_t foldedGradDesc, + cudnnTensorTransformDescriptor_t filterFoldTransDesc, + cudnnTensorTransformDescriptor_t diffPadTransDesc, + cudnnTensorTransformDescriptor_t gradFoldTransDesc, + cudnnTensorTransformDescriptor_t gradUnfoldTransDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorFormat_t, + cudnnFilterDescriptor_t, cudnnTensorDescriptor_t, + cudnnConvolutionDescriptor_t, cudnnTensorDescriptor_t, + cudnnTensorTransformDescriptor_t, cudnnTensorTransformDescriptor_t, + cudnnTensorTransformDescriptor_t, cudnnTensorTransformDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetFoldedConvBackwardDataDescriptors"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, filterDesc, diffDesc, convDesc, gradDesc, transformFormat, foldedFilterDesc, paddedDiffDesc, foldedConvDesc, foldedGradDesc, filterFoldTransDesc, diffPadTransDesc, gradFoldTransDesc, gradUnfoldTransDesc); + return func_ptr(handle, filterDesc, diffDesc, convDesc, gradDesc, + transformFormat, foldedFilterDesc, paddedDiffDesc, + foldedConvDesc, foldedGradDesc, filterFoldTransDesc, + diffPadTransDesc, gradFoldTransDesc, gradUnfoldTransDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnAddTensor(cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnAddTensor(cudnnHandle_t handle, + const void *alpha, + const cudnnTensorDescriptor_t aDesc, + const void *A, const void *beta, + const cudnnTensorDescriptor_t cDesc, + void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnAddTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, aDesc, A, beta, cDesc, C); @@ -304,29 +314,29 @@ cudnnAddTensor(cudnnHandle_t handle, cudnnStatus_t CUDNNWINAPI cudnnCreateOpTensorDescriptor(cudnnOpTensorDescriptor_t *opTensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc, - cudnnOpTensorOp_t opTensorOp, - cudnnDataType_t opTensorCompType, - cudnnNanPropagation_t opTensorNanOpt) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t, cudnnDataType_t, cudnnNanPropagation_t); +cudnnStatus_t CUDNNWINAPI cudnnSetOpTensorDescriptor( + cudnnOpTensorDescriptor_t opTensorDesc, cudnnOpTensorOp_t opTensorOp, + cudnnDataType_t opTensorCompType, cudnnNanPropagation_t opTensorNanOpt) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t, + cudnnDataType_t, cudnnNanPropagation_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc, opTensorOp, opTensorCompType, opTensorNanOpt); } -cudnnStatus_t CUDNNWINAPI -cudnnGetOpTensorDescriptor(const cudnnOpTensorDescriptor_t opTensorDesc, - cudnnOpTensorOp_t *opTensorOp, - cudnnDataType_t *opTensorCompType, - cudnnNanPropagation_t *opTensorNanOpt) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t *, cudnnDataType_t *, cudnnNanPropagation_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetOpTensorDescriptor( + const cudnnOpTensorDescriptor_t opTensorDesc, cudnnOpTensorOp_t *opTensorOp, + cudnnDataType_t *opTensorCompType, cudnnNanPropagation_t *opTensorNanOpt) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnOpTensorDescriptor_t, cudnnOpTensorOp_t *, cudnnDataType_t *, + cudnnNanPropagation_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc, opTensorOp, opTensorCompType, opTensorNanOpt); @@ -334,126 +344,136 @@ cudnnGetOpTensorDescriptor(const cudnnOpTensorDescriptor_t opTensorDesc, cudnnStatus_t CUDNNWINAPI cudnnDestroyOpTensorDescriptor(cudnnOpTensorDescriptor_t opTensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnOpTensorDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnOpTensorDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyOpTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(opTensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnOpTensor(cudnnHandle_t handle, - const cudnnOpTensorDescriptor_t opTensorDesc, - const void *alpha1, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *alpha2, - const cudnnTensorDescriptor_t bDesc, - const void *B, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnOpTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnOpTensor( + cudnnHandle_t handle, const cudnnOpTensorDescriptor_t opTensorDesc, + const void *alpha1, const cudnnTensorDescriptor_t aDesc, const void *A, + const void *alpha2, const cudnnTensorDescriptor_t bDesc, const void *B, + const void *beta, const cudnnTensorDescriptor_t cDesc, void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnOpTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnOpTensor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, opTensorDesc, alpha1, aDesc, A, alpha2, bDesc, B, beta, cDesc, C); + return func_ptr(handle, opTensorDesc, alpha1, aDesc, A, alpha2, bDesc, B, + beta, cDesc, C); } -cudnnStatus_t CUDNNWINAPI -cudnnCreateReduceTensorDescriptor(cudnnReduceTensorDescriptor_t *reduceTensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateReduceTensorDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnCreateReduceTensorDescriptor( + cudnnReduceTensorDescriptor_t *reduceTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(reduceTensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc, - cudnnReduceTensorOp_t reduceTensorOp, - cudnnDataType_t reduceTensorCompType, - cudnnNanPropagation_t reduceTensorNanOpt, - cudnnReduceTensorIndices_t reduceTensorIndices, - cudnnIndicesType_t reduceTensorIndicesType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t, cudnnDataType_t, cudnnNanPropagation_t, cudnnReduceTensorIndices_t, cudnnIndicesType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetReduceTensorDescriptor( + cudnnReduceTensorDescriptor_t reduceTensorDesc, + cudnnReduceTensorOp_t reduceTensorOp, cudnnDataType_t reduceTensorCompType, + cudnnNanPropagation_t reduceTensorNanOpt, + cudnnReduceTensorIndices_t reduceTensorIndices, + cudnnIndicesType_t reduceTensorIndicesType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t, cudnnDataType_t, + cudnnNanPropagation_t, cudnnReduceTensorIndices_t, cudnnIndicesType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, reduceTensorNanOpt, reduceTensorIndices, reduceTensorIndicesType); + return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, + reduceTensorNanOpt, reduceTensorIndices, + reduceTensorIndicesType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetReduceTensorDescriptor(const cudnnReduceTensorDescriptor_t reduceTensorDesc, - cudnnReduceTensorOp_t *reduceTensorOp, - cudnnDataType_t *reduceTensorCompType, - cudnnNanPropagation_t *reduceTensorNanOpt, - cudnnReduceTensorIndices_t *reduceTensorIndices, - cudnnIndicesType_t *reduceTensorIndicesType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t *, cudnnDataType_t *, cudnnNanPropagation_t *, cudnnReduceTensorIndices_t *, cudnnIndicesType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetReduceTensorDescriptor( + const cudnnReduceTensorDescriptor_t reduceTensorDesc, + cudnnReduceTensorOp_t *reduceTensorOp, + cudnnDataType_t *reduceTensorCompType, + cudnnNanPropagation_t *reduceTensorNanOpt, + cudnnReduceTensorIndices_t *reduceTensorIndices, + cudnnIndicesType_t *reduceTensorIndicesType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnReduceTensorDescriptor_t, cudnnReduceTensorOp_t *, + cudnnDataType_t *, cudnnNanPropagation_t *, cudnnReduceTensorIndices_t *, + cudnnIndicesType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, reduceTensorNanOpt, reduceTensorIndices, reduceTensorIndicesType); + return func_ptr(reduceTensorDesc, reduceTensorOp, reduceTensorCompType, + reduceTensorNanOpt, reduceTensorIndices, + reduceTensorIndicesType); } -cudnnStatus_t CUDNNWINAPI -cudnnDestroyReduceTensorDescriptor(cudnnReduceTensorDescriptor_t reduceTensorDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyReduceTensorDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnDestroyReduceTensorDescriptor( + cudnnReduceTensorDescriptor_t reduceTensorDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnReduceTensorDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyReduceTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(reduceTensorDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnGetReductionIndicesSize(cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - const cudnnTensorDescriptor_t aDesc, - const cudnnTensorDescriptor_t cDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetReductionIndicesSize( + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + const cudnnTensorDescriptor_t aDesc, const cudnnTensorDescriptor_t cDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReductionIndicesSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, reduceTensorDesc, aDesc, cDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetReductionWorkspaceSize(cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - const cudnnTensorDescriptor_t aDesc, - const cudnnTensorDescriptor_t cDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetReductionWorkspaceSize( + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + const cudnnTensorDescriptor_t aDesc, const cudnnTensorDescriptor_t cDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetReductionWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, reduceTensorDesc, aDesc, cDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnReduceTensor(cudnnHandle_t handle, - const cudnnReduceTensorDescriptor_t reduceTensorDesc, - void *indices, - size_t indicesSizeInBytes, - void *workspace, - size_t workspaceSizeInBytes, - const void *alpha, - const cudnnTensorDescriptor_t aDesc, - const void *A, - const void *beta, - const cudnnTensorDescriptor_t cDesc, - void *C) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnReduceTensorDescriptor_t, void *, size_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnReduceTensor( + cudnnHandle_t handle, const cudnnReduceTensorDescriptor_t reduceTensorDesc, + void *indices, size_t indicesSizeInBytes, void *workspace, + size_t workspaceSizeInBytes, const void *alpha, + const cudnnTensorDescriptor_t aDesc, const void *A, const void *beta, + const cudnnTensorDescriptor_t cDesc, void *C) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnReduceTensorDescriptor_t, void *, size_t, + void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnReduceTensor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, reduceTensorDesc, indices, indicesSizeInBytes, workspace, workspaceSizeInBytes, alpha, aDesc, A, beta, cDesc, C); + return func_ptr(handle, reduceTensorDesc, indices, indicesSizeInBytes, + workspace, workspaceSizeInBytes, alpha, aDesc, A, beta, cDesc, + C); } -cudnnStatus_t CUDNNWINAPI -cudnnSetTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *valuePtr) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); +cudnnStatus_t CUDNNWINAPI cudnnSetTensor(cudnnHandle_t handle, + const cudnnTensorDescriptor_t yDesc, + void *y, const void *valuePtr) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, yDesc, y, valuePtr); } -cudnnStatus_t CUDNNWINAPI -cudnnScaleTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void *y, const void *alpha) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); +cudnnStatus_t CUDNNWINAPI cudnnScaleTensor(cudnnHandle_t handle, + const cudnnTensorDescriptor_t yDesc, + void *y, const void *alpha) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnScaleTensor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, yDesc, y, alpha); @@ -461,745 +481,785 @@ cudnnScaleTensor(cudnnHandle_t handle, const cudnnTensorDescriptor_t yDesc, void cudnnStatus_t CUDNNWINAPI cudnnCreateFilterDescriptor(cudnnFilterDescriptor_t *filterDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateFilterDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetFilter4dDescriptor(cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, /* image data type */ - cudnnTensorFormat_t format, - int k, /* number of output feature maps */ - int c, /* number of input feature maps */ - int h, /* height of each input filter */ - int w) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, cudnnTensorFormat_t, int, int, int, int); +cudnnStatus_t CUDNNWINAPI cudnnSetFilter4dDescriptor( + cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t dataType, /* image data type */ + cudnnTensorFormat_t format, int k, /* number of output feature maps */ + int c, /* number of input feature maps */ + int h, /* height of each input filter */ + int w) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, + cudnnTensorFormat_t, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFilter4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, k, c, h, w); } -cudnnStatus_t CUDNNWINAPI -cudnnGetFilter4dDescriptor(const cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t *dataType, /* image data type */ - cudnnTensorFormat_t *format, - int *k, /* number of output feature maps */ - int *c, /* number of input feature maps */ - int *h, /* height of each input filter */ - int *w) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, cudnnDataType_t *, cudnnTensorFormat_t *, int *, int *, int *, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetFilter4dDescriptor( + const cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t *dataType, /* image data type */ + cudnnTensorFormat_t *format, int *k, /* number of output feature maps */ + int *c, /* number of input feature maps */ + int *h, /* height of each input filter */ + int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnFilterDescriptor_t, cudnnDataType_t *, cudnnTensorFormat_t *, + int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilter4dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, k, c, h, w); } -cudnnStatus_t CUDNNWINAPI -cudnnSetFilterNdDescriptor(cudnnFilterDescriptor_t filterDesc, - cudnnDataType_t dataType, /* image data type */ - cudnnTensorFormat_t format, - int nbDims, - const int filterDimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, cudnnTensorFormat_t, int, const int []); +cudnnStatus_t CUDNNWINAPI cudnnSetFilterNdDescriptor( + cudnnFilterDescriptor_t filterDesc, + cudnnDataType_t dataType, /* image data type */ + cudnnTensorFormat_t format, int nbDims, const int filterDimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t, cudnnDataType_t, + cudnnTensorFormat_t, int, const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFilterNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, dataType, format, nbDims, filterDimA); } -cudnnStatus_t CUDNNWINAPI -cudnnGetFilterNdDescriptor(const cudnnFilterDescriptor_t filterDesc, - int nbDimsRequested, - cudnnDataType_t *dataType, /* image data type */ - cudnnTensorFormat_t *format, - int *nbDims, - int filterDimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, int, cudnnDataType_t *, cudnnTensorFormat_t *, int *, int []); +cudnnStatus_t CUDNNWINAPI cudnnGetFilterNdDescriptor( + const cudnnFilterDescriptor_t filterDesc, int nbDimsRequested, + cudnnDataType_t *dataType, /* image data type */ + cudnnTensorFormat_t *format, int *nbDims, int filterDimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnFilterDescriptor_t, int, cudnnDataType_t *, + cudnnTensorFormat_t *, int *, int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilterNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(filterDesc, nbDimsRequested, dataType, format, nbDims, filterDimA); + return func_ptr(filterDesc, nbDimsRequested, dataType, format, nbDims, + filterDimA); } -cudnnStatus_t CUDNNWINAPI -cudnnGetFilterSizeInBytes(const cudnnFilterDescriptor_t filterDesc, size_t *size) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFilterDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetFilterSizeInBytes( + const cudnnFilterDescriptor_t filterDesc, size_t *size) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnFilterDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFilterSizeInBytes"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc, size); } -cudnnStatus_t CUDNNWINAPI -cudnnTransformFilter(cudnnHandle_t handle, - const cudnnTensorTransformDescriptor_t transDesc, - const void *alpha, - const cudnnFilterDescriptor_t srcDesc, - const void *srcData, - const void *beta, - const cudnnFilterDescriptor_t destDesc, - void *destData) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorTransformDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const void *, const cudnnFilterDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnTransformFilter( + cudnnHandle_t handle, const cudnnTensorTransformDescriptor_t transDesc, + const void *alpha, const cudnnFilterDescriptor_t srcDesc, + const void *srcData, const void *beta, + const cudnnFilterDescriptor_t destDesc, void *destData) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorTransformDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, const void *, + const cudnnFilterDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnTransformFilter"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, transDesc, alpha, srcDesc, srcData, beta, destDesc, destData); + return func_ptr(handle, transDesc, alpha, srcDesc, srcData, beta, destDesc, + destData); } cudnnStatus_t CUDNNWINAPI cudnnDestroyFilterDescriptor(cudnnFilterDescriptor_t filterDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFilterDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFilterDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyFilterDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(filterDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnReorderFilterAndBias(cudnnHandle_t handle, - const cudnnFilterDescriptor_t filterDesc, - cudnnReorderType_t reorderType, - const void *filterData, - void *reorderedFilterData, - int reorderBias, - const void *biasData, - void *reorderedBiasData) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, cudnnReorderType_t, const void *, void *, int, const void *, void *); +cudnnStatus_t CUDNNWINAPI cudnnReorderFilterAndBias( + cudnnHandle_t handle, const cudnnFilterDescriptor_t filterDesc, + cudnnReorderType_t reorderType, const void *filterData, + void *reorderedFilterData, int reorderBias, const void *biasData, + void *reorderedBiasData) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, cudnnReorderType_t, + const void *, void *, int, const void *, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnReorderFilterAndBias"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, filterDesc, reorderType, filterData, reorderedFilterData, reorderBias, biasData, reorderedBiasData); + return func_ptr(handle, filterDesc, reorderType, filterData, + reorderedFilterData, reorderBias, biasData, + reorderedBiasData); } cudnnStatus_t CUDNNWINAPI cudnnCreateConvolutionDescriptor(cudnnConvolutionDescriptor_t *convDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateConvolutionDescriptor"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateConvolutionDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t mathType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnMathType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionMathType( + cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t mathType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, + cudnnMathType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, mathType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionMathType(cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t *mathType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnMathType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionMathType( + cudnnConvolutionDescriptor_t convDesc, cudnnMathType_t *mathType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, + cudnnMathType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, mathType); } -cudnnStatus_t CUDNNWINAPI -cudnnSetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int groupCount) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionGroupCount( + cudnnConvolutionDescriptor_t convDesc, int groupCount) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionGroupCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, groupCount); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionGroupCount(cudnnConvolutionDescriptor_t convDesc, int *groupCount) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionGroupCount( + cudnnConvolutionDescriptor_t convDesc, int *groupCount) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionGroupCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, groupCount); } -cudnnStatus_t CUDNNWINAPI -cudnnSetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t reorderType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnReorderType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionReorderType( + cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t reorderType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, + cudnnReorderType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionReorderType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, reorderType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionReorderType(cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t *reorderType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, cudnnReorderType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionReorderType( + cudnnConvolutionDescriptor_t convDesc, cudnnReorderType_t *reorderType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, + cudnnReorderType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionReorderType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, reorderType); } -cudnnStatus_t CUDNNWINAPI -cudnnSetConvolution2dDescriptor(cudnnConvolutionDescriptor_t convDesc, - int pad_h, /* zero-padding height */ - int pad_w, /* zero-padding width */ - int u, /* vertical filter stride */ - int v, /* horizontal filter stride */ - int dilation_h, /* filter dilation in the vertical dimension */ - int dilation_w, /* filter dilation in the horizontal dimension */ - cudnnConvolutionMode_t mode, - cudnnDataType_t computeType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, int, int, int, int, int, cudnnConvolutionMode_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolution2dDescriptor( + cudnnConvolutionDescriptor_t convDesc, int pad_h, /* zero-padding height */ + int pad_w, /* zero-padding width */ + int u, /* vertical filter stride */ + int v, /* horizontal filter stride */ + int dilation_h, /* filter dilation in the vertical dimension */ + int dilation_w, /* filter dilation in the horizontal dimension */ + cudnnConvolutionMode_t mode, cudnnDataType_t computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnConvolutionDescriptor_t, int, int, int, int, int, int, + cudnnConvolutionMode_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolution2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType); + return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, + computeType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolution2dDescriptor(const cudnnConvolutionDescriptor_t convDesc, - int *pad_h, /* zero-padding height */ - int *pad_w, /* zero-padding width */ - int *u, /* vertical filter stride */ - int *v, /* horizontal filter stride */ - int *dilation_h, /* filter dilation in the vertical dimension */ - int *dilation_w, /* filter dilation in the horizontal dimension */ - cudnnConvolutionMode_t *mode, - cudnnDataType_t *computeType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, int *, cudnnConvolutionMode_t *, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dDescriptor( + const cudnnConvolutionDescriptor_t convDesc, + int *pad_h, /* zero-padding height */ + int *pad_w, /* zero-padding width */ + int *u, /* vertical filter stride */ + int *v, /* horizontal filter stride */ + int *dilation_h, /* filter dilation in the vertical dimension */ + int *dilation_w, /* filter dilation in the horizontal dimension */ + cudnnConvolutionMode_t *mode, cudnnDataType_t *computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, int *, int *, int *, int *, int *, + int *, cudnnConvolutionMode_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, computeType); + return func_ptr(convDesc, pad_h, pad_w, u, v, dilation_h, dilation_w, mode, + computeType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolution2dForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - const cudnnFilterDescriptor_t filterDesc, - int *n, - int *c, - int *h, - int *w) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, int *, int *, int *, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolution2dForwardOutputDim"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolution2dForwardOutputDim( + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + const cudnnFilterDescriptor_t filterDesc, int *n, int *c, int *h, int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, int *, int *, int *, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolution2dForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc, inputTensorDesc, filterDesc, n, c, h, w); } -cudnnStatus_t CUDNNWINAPI -cudnnSetConvolutionNdDescriptor(cudnnConvolutionDescriptor_t convDesc, - int arrayLength, /* nbDims-2 size */ - const int padA[], - const int filterStrideA[], - const int dilationA[], - cudnnConvolutionMode_t mode, - cudnnDataType_t computeType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t, int, const int [], const int [], const int [], cudnnConvolutionMode_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetConvolutionNdDescriptor( + cudnnConvolutionDescriptor_t convDesc, int arrayLength, /* nbDims-2 size */ + const int padA[], const int filterStrideA[], const int dilationA[], + cudnnConvolutionMode_t mode, cudnnDataType_t computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnConvolutionDescriptor_t, int, const int[], const int[], const int[], + cudnnConvolutionMode_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetConvolutionNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, arrayLength, padA, filterStrideA, dilationA, mode, computeType); + return func_ptr(convDesc, arrayLength, padA, filterStrideA, dilationA, mode, + computeType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionNdDescriptor(const cudnnConvolutionDescriptor_t convDesc, - int arrayLengthRequested, - int *arrayLength, - int padA[], - int strideA[], - int dilationA[], - cudnnConvolutionMode_t *mode, - cudnnDataType_t *computeType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, int, int *, int [], int [], int [], cudnnConvolutionMode_t *, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionNdDescriptor( + const cudnnConvolutionDescriptor_t convDesc, int arrayLengthRequested, + int *arrayLength, int padA[], int strideA[], int dilationA[], + cudnnConvolutionMode_t *mode, cudnnDataType_t *computeType) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, int, int *, int[], int[], int[], + cudnnConvolutionMode_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, arrayLengthRequested, arrayLength, padA, strideA, dilationA, mode, computeType); + return func_ptr(convDesc, arrayLengthRequested, arrayLength, padA, strideA, + dilationA, mode, computeType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionNdForwardOutputDim(const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t inputTensorDesc, - const cudnnFilterDescriptor_t filterDesc, - int nbDims, - int tensorOutputDimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, int, int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionNdForwardOutputDim"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionNdForwardOutputDim( + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t inputTensorDesc, + const cudnnFilterDescriptor_t filterDesc, int nbDims, + int tensorOutputDimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, int, int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionNdForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(convDesc, inputTensorDesc, filterDesc, nbDims, tensorOutputDimA); + return func_ptr(convDesc, inputTensorDesc, filterDesc, nbDims, + tensorOutputDimA); } cudnnStatus_t CUDNNWINAPI cudnnDestroyConvolutionDescriptor(cudnnConvolutionDescriptor_t convDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnConvolutionDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyConvolutionDescriptor"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnConvolutionDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyConvolutionDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(convDesc); } cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithmMaxCount(cudnnHandle_t handle, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithmMaxCount"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionForwardAlgorithm(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionFwdAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionForwardAlgorithm( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionFwdAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionForwardAlgorithmEx(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - void *y, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, const int, int *, cudnnConvolutionFwdAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionForwardAlgorithmEx( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, void *y, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults, + void *workSpace, size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, + const int, int *, cudnnConvolutionFwdAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionForwardAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, x, wDesc, w, convDesc, yDesc, y, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, xDesc, x, wDesc, w, convDesc, yDesc, y, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionForwardAlgorithm(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - cudnnConvolutionFwdPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionFwdAlgo_t *algo) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionFwdPreference_t, size_t, cudnnConvolutionFwdAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithm( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, + cudnnConvolutionFwdPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionFwdAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionFwdPreference_t, size_t, + cudnnConvolutionFwdAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, preference, + memoryLimitInBytes, algo); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionForwardAlgorithm_v7(cudnnHandle_t handle, - const cudnnTensorDescriptor_t srcDesc, - const cudnnFilterDescriptor_t filterDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t destDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionFwdAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionFwdAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm_v7"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardAlgorithm_v7( + cudnnHandle_t handle, const cudnnTensorDescriptor_t srcDesc, + const cudnnFilterDescriptor_t filterDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t destDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionFwdAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionFwdAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardAlgorithm_v7"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, srcDesc, filterDesc, convDesc, destDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, srcDesc, filterDesc, convDesc, destDesc, + requestedAlgoCount, returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t yDesc, - cudnnConvolutionFwdAlgo_t algo, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionFwdAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardWorkspaceSize"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionForwardWorkspaceSize( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t yDesc, cudnnConvolutionFwdAlgo_t algo, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionFwdAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionForwardWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, wDesc, convDesc, yDesc, algo, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnConvolutionForward(cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionFwdAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnConvolutionForward( + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo, + void *workSpace, size_t workSpaceSizeInBytes, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, yDesc, y); + return func_ptr(handle, alpha, xDesc, x, wDesc, w, convDesc, algo, workSpace, + workSpaceSizeInBytes, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnConvolutionBiasActivationForward(cudnnHandle_t handle, - const void *alpha1, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionFwdAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *alpha2, - const cudnnTensorDescriptor_t zDesc, - const void *z, - const cudnnTensorDescriptor_t biasDesc, - const void *bias, - const cudnnActivationDescriptor_t activationDesc, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBiasActivationForward"); +cudnnStatus_t CUDNNWINAPI cudnnConvolutionBiasActivationForward( + cudnnHandle_t handle, const void *alpha1, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnConvolutionDescriptor_t convDesc, cudnnConvolutionFwdAlgo_t algo, + void *workSpace, size_t workSpaceSizeInBytes, const void *alpha2, + const cudnnTensorDescriptor_t zDesc, const void *z, + const cudnnTensorDescriptor_t biasDesc, const void *bias, + const cudnnActivationDescriptor_t activationDesc, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionFwdAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnConvolutionBiasActivationForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha1, xDesc, x, wDesc, w, convDesc, algo, workSpace, workSpaceSizeInBytes, alpha2, zDesc, z, biasDesc, bias, activationDesc, yDesc, y); + return func_ptr(handle, alpha1, xDesc, x, wDesc, w, convDesc, algo, workSpace, + workSpaceSizeInBytes, alpha2, zDesc, z, biasDesc, bias, + activationDesc, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnConvolutionBackwardBias(cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *beta, - const cudnnTensorDescriptor_t dbDesc, - void *db) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardBias( + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t dyDesc, const void *dy, const void *beta, + const cudnnTensorDescriptor_t dbDesc, void *db) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardBias"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, alpha, dyDesc, dy, beta, dbDesc, db); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnnHandle_t handle, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( + cudnnHandle_t handle, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardFilterAlgorithm( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, const int, int *, + cudnnConvolutionBwdFilterAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionBackwardFilterAlgorithmEx(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *y, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, void *, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardFilterAlgorithmEx( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t dyDesc, const void *y, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, void *dw, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnConvolutionBwdFilterAlgoPerf_t *perfResults, void *workSpace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, void *, + const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardFilterAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, x, dyDesc, y, convDesc, dwDesc, dw, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, xDesc, x, dyDesc, y, convDesc, dwDesc, dw, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardFilterAlgorithm(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t dwDesc, - cudnnConvolutionBwdFilterPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionBwdFilterAlgo_t *algo) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterPreference_t, size_t, cudnnConvolutionBwdFilterAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithm( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t dwDesc, + cudnnConvolutionBwdFilterPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionBwdFilterAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterPreference_t, + size_t, cudnnConvolutionBwdFilterAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, xDesc, dyDesc, convDesc, dwDesc, preference, + memoryLimitInBytes, algo); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnnHandle_t handle, - const cudnnTensorDescriptor_t srcDesc, - const cudnnTensorDescriptor_t diffDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t gradDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, const int, int *, cudnnConvolutionBwdFilterAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm_v7"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterAlgorithm_v7( + cudnnHandle_t handle, const cudnnTensorDescriptor_t srcDesc, + const cudnnTensorDescriptor_t diffDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t gradDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdFilterAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, const int, int *, + cudnnConvolutionBwdFilterAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterAlgorithm_v7"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, srcDesc, diffDesc, convDesc, gradDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, srcDesc, diffDesc, convDesc, gradDesc, + requestedAlgoCount, returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnFilterDescriptor_t gradDesc, - cudnnConvolutionBwdFilterAlgo_t algo, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterWorkspaceSize"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardFilterWorkspaceSize( + cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnFilterDescriptor_t gradDesc, + cudnnConvolutionBwdFilterAlgo_t algo, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnFilterDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardFilterWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, dyDesc, convDesc, gradDesc, algo, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnConvolutionBackwardFilter(cudnnHandle_t handle, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionBwdFilterAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnFilterDescriptor_t dwDesc, - void *dw) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, void *, size_t, const void *, const cudnnFilterDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardFilter( + cudnnHandle_t handle, const void *alpha, + const cudnnTensorDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + cudnnConvolutionBwdFilterAlgo_t algo, void *workSpace, + size_t workSpaceSizeInBytes, const void *beta, + const cudnnFilterDescriptor_t dwDesc, void *dw) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdFilterAlgo_t, + void *, size_t, const void *, const cudnnFilterDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardFilter"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, xDesc, x, dyDesc, dy, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, dwDesc, dw); + return func_ptr(handle, alpha, xDesc, x, dyDesc, dy, convDesc, algo, + workSpace, workSpaceSizeInBytes, beta, dwDesc, dw); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnnHandle_t handle, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithmMaxCount( + cudnnHandle_t handle, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionBackwardDataAlgorithm(cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardDataAlgorithm( + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionBwdDataAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, requestedAlgoCount, + returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnFindConvolutionBackwardDataAlgorithmEx(cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults, - void *workSpace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindConvolutionBackwardDataAlgorithmEx( + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, void *dx, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnConvolutionBwdDataAlgoPerf_t *perfResults, void *workSpace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, void *, + const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindConvolutionBackwardDataAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, wDesc, w, dyDesc, dy, convDesc, dxDesc, dx, + requestedAlgoCount, returnedAlgoCount, perfResults, workSpace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardDataAlgorithm(cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - cudnnConvolutionBwdDataPreference_t preference, - size_t memoryLimitInBytes, - cudnnConvolutionBwdDataAlgo_t *algo) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataPreference_t, size_t, cudnnConvolutionBwdDataAlgo_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithm( + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, + cudnnConvolutionBwdDataPreference_t preference, size_t memoryLimitInBytes, + cudnnConvolutionBwdDataAlgo_t *algo) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataPreference_t, + size_t, cudnnConvolutionBwdDataAlgo_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, preference, memoryLimitInBytes, algo); + return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, preference, + memoryLimitInBytes, algo); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnnHandle_t handle, - const cudnnFilterDescriptor_t filterDesc, - const cudnnTensorDescriptor_t diffDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t gradDesc, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, const int, int *, cudnnConvolutionBwdDataAlgoPerf_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm_v7"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataAlgorithm_v7( + cudnnHandle_t handle, const cudnnFilterDescriptor_t filterDesc, + const cudnnTensorDescriptor_t diffDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t gradDesc, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnConvolutionBwdDataAlgoPerf_t *perfResults) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, const int, int *, + cudnnConvolutionBwdDataAlgoPerf_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataAlgorithm_v7"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, filterDesc, diffDesc, convDesc, gradDesc, requestedAlgoCount, returnedAlgoCount, perfResults); + return func_ptr(handle, filterDesc, diffDesc, convDesc, gradDesc, + requestedAlgoCount, returnedAlgoCount, perfResults); } -cudnnStatus_t CUDNNWINAPI -cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnHandle_t handle, - const cudnnFilterDescriptor_t wDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnConvolutionDescriptor_t convDesc, - const cudnnTensorDescriptor_t dxDesc, - cudnnConvolutionBwdDataAlgo_t algo, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFilterDescriptor_t, const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataAlgo_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataWorkspaceSize"); +cudnnStatus_t CUDNNWINAPI cudnnGetConvolutionBackwardDataWorkspaceSize( + cudnnHandle_t handle, const cudnnFilterDescriptor_t wDesc, + const cudnnTensorDescriptor_t dyDesc, + const cudnnConvolutionDescriptor_t convDesc, + const cudnnTensorDescriptor_t dxDesc, cudnnConvolutionBwdDataAlgo_t algo, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnFilterDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnConvolutionDescriptor_t, + const cudnnTensorDescriptor_t, cudnnConvolutionBwdDataAlgo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetConvolutionBackwardDataWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, wDesc, dyDesc, convDesc, dxDesc, algo, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnConvolutionBackwardData(cudnnHandle_t handle, - const void *alpha, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnConvolutionDescriptor_t convDesc, - cudnnConvolutionBwdDataAlgo_t algo, - void *workSpace, - size_t workSpaceSizeInBytes, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdDataAlgo_t, void *, size_t, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnConvolutionBackwardData( + cudnnHandle_t handle, const void *alpha, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnConvolutionDescriptor_t convDesc, + cudnnConvolutionBwdDataAlgo_t algo, void *workSpace, + size_t workSpaceSizeInBytes, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const void *, const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnConvolutionDescriptor_t, cudnnConvolutionBwdDataAlgo_t, void *, + size_t, const void *, const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnConvolutionBackwardData"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, alpha, wDesc, w, dyDesc, dy, convDesc, algo, workSpace, workSpaceSizeInBytes, beta, dxDesc, dx); + return func_ptr(handle, alpha, wDesc, w, dyDesc, dy, convDesc, algo, + workSpace, workSpaceSizeInBytes, beta, dxDesc, dx); } cudnnStatus_t CUDNNWINAPI -cudnnIm2Col(cudnnHandle_t handle, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const cudnnFilterDescriptor_t wDesc, - const cudnnConvolutionDescriptor_t convDesc, - void *colBuffer) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const cudnnConvolutionDescriptor_t, void *); +cudnnIm2Col(cudnnHandle_t handle, const cudnnTensorDescriptor_t xDesc, + const void *x, const cudnnFilterDescriptor_t wDesc, + const cudnnConvolutionDescriptor_t convDesc, void *colBuffer) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, + const void *, const cudnnFilterDescriptor_t, + const cudnnConvolutionDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnIm2Col"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, xDesc, x, wDesc, convDesc, colBuffer); } -cudnnStatus_t CUDNNWINAPI -cudnnSoftmaxForward(cudnnHandle_t handle, - cudnnSoftmaxAlgorithm_t algo, - cudnnSoftmaxMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnSoftmaxForward( + cudnnHandle_t handle, cudnnSoftmaxAlgorithm_t algo, cudnnSoftmaxMode_t mode, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSoftmaxForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algo, mode, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnSoftmaxBackward(cudnnHandle_t handle, - cudnnSoftmaxAlgorithm_t algo, - cudnnSoftmaxMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnSoftmaxBackward( + cudnnHandle_t handle, cudnnSoftmaxAlgorithm_t algo, cudnnSoftmaxMode_t mode, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSoftmaxAlgorithm_t, cudnnSoftmaxMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSoftmaxBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, algo, mode, alpha, yDesc, y, dyDesc, dy, beta, dxDesc, dx); + return func_ptr(handle, algo, mode, alpha, yDesc, y, dyDesc, dy, beta, dxDesc, + dx); } cudnnStatus_t CUDNNWINAPI cudnnCreatePoolingDescriptor(cudnnPoolingDescriptor_t *poolingDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPoolingDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreatePoolingDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetPooling2dDescriptor(cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t mode, - cudnnNanPropagation_t maxpoolingNanOpt, - int windowHeight, - int windowWidth, - int verticalPadding, - int horizontalPadding, - int verticalStride, - int horizontalStride) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t, cudnnPoolingMode_t, cudnnNanPropagation_t, int, int, int, int, int, int); +cudnnStatus_t CUDNNWINAPI cudnnSetPooling2dDescriptor( + cudnnPoolingDescriptor_t poolingDesc, cudnnPoolingMode_t mode, + cudnnNanPropagation_t maxpoolingNanOpt, int windowHeight, int windowWidth, + int verticalPadding, int horizontalPadding, int verticalStride, + int horizontalStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnPoolingDescriptor_t, cudnnPoolingMode_t, cudnnNanPropagation_t, int, + int, int, int, int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPooling2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, windowWidth, verticalPadding, horizontalPadding, verticalStride, horizontalStride); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, + windowWidth, verticalPadding, horizontalPadding, + verticalStride, horizontalStride); } -cudnnStatus_t CUDNNWINAPI -cudnnGetPooling2dDescriptor(const cudnnPoolingDescriptor_t poolingDesc, - cudnnPoolingMode_t *mode, - cudnnNanPropagation_t *maxpoolingNanOpt, - int *windowHeight, - int *windowWidth, - int *verticalPadding, - int *horizontalPadding, - int *verticalStride, - int *horizontalStride) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, cudnnPoolingMode_t *, cudnnNanPropagation_t *, int *, int *, int *, int *, int *, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetPooling2dDescriptor( + const cudnnPoolingDescriptor_t poolingDesc, cudnnPoolingMode_t *mode, + cudnnNanPropagation_t *maxpoolingNanOpt, int *windowHeight, + int *windowWidth, int *verticalPadding, int *horizontalPadding, + int *verticalStride, int *horizontalStride) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnPoolingDescriptor_t, cudnnPoolingMode_t *, + cudnnNanPropagation_t *, int *, int *, int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPooling2dDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, windowWidth, verticalPadding, horizontalPadding, verticalStride, horizontalStride); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, windowHeight, + windowWidth, verticalPadding, horizontalPadding, + verticalStride, horizontalStride); } -cudnnStatus_t CUDNNWINAPI -cudnnSetPoolingNdDescriptor(cudnnPoolingDescriptor_t poolingDesc, - const cudnnPoolingMode_t mode, - const cudnnNanPropagation_t maxpoolingNanOpt, - int nbDims, - const int windowDimA[], - const int paddingA[], - const int strideA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t, const cudnnPoolingMode_t, const cudnnNanPropagation_t, int, const int [], const int [], const int []); +cudnnStatus_t CUDNNWINAPI cudnnSetPoolingNdDescriptor( + cudnnPoolingDescriptor_t poolingDesc, const cudnnPoolingMode_t mode, + const cudnnNanPropagation_t maxpoolingNanOpt, int nbDims, + const int windowDimA[], const int paddingA[], const int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnPoolingDescriptor_t, const cudnnPoolingMode_t, + const cudnnNanPropagation_t, int, const int[], const int[], const int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPoolingNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA); + return func_ptr(poolingDesc, mode, maxpoolingNanOpt, nbDims, windowDimA, + paddingA, strideA); } -cudnnStatus_t CUDNNWINAPI -cudnnGetPoolingNdDescriptor(const cudnnPoolingDescriptor_t poolingDesc, - int nbDimsRequested, - cudnnPoolingMode_t *mode, - cudnnNanPropagation_t *maxpoolingNanOpt, - int *nbDims, - int windowDimA[], - int paddingA[], - int strideA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, int, cudnnPoolingMode_t *, cudnnNanPropagation_t *, int *, int [], int [], int []); +cudnnStatus_t CUDNNWINAPI cudnnGetPoolingNdDescriptor( + const cudnnPoolingDescriptor_t poolingDesc, int nbDimsRequested, + cudnnPoolingMode_t *mode, cudnnNanPropagation_t *maxpoolingNanOpt, + int *nbDims, int windowDimA[], int paddingA[], int strideA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnPoolingDescriptor_t, int, cudnnPoolingMode_t *, + cudnnNanPropagation_t *, int *, int[], int[], int[]); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPoolingNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(poolingDesc, nbDimsRequested, mode, maxpoolingNanOpt, nbDims, windowDimA, paddingA, strideA); + return func_ptr(poolingDesc, nbDimsRequested, mode, maxpoolingNanOpt, nbDims, + windowDimA, paddingA, strideA); } cudnnStatus_t CUDNNWINAPI cudnnGetPoolingNdForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc, const cudnnTensorDescriptor_t inputTensorDesc, - int nbDims, - int outputTensorDimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, const cudnnTensorDescriptor_t, int, int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPoolingNdForwardOutputDim"); + int nbDims, int outputTensorDimA[]) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, + const cudnnTensorDescriptor_t, int, int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetPoolingNdForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc, inputTensorDesc, nbDims, outputTensorDimA); } @@ -1207,72 +1267,69 @@ cudnnGetPoolingNdForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc, cudnnStatus_t CUDNNWINAPI cudnnGetPooling2dForwardOutputDim(const cudnnPoolingDescriptor_t poolingDesc, const cudnnTensorDescriptor_t inputTensorDesc, - int *n, - int *c, - int *h, - int *w) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, const cudnnTensorDescriptor_t, int *, int *, int *, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetPooling2dForwardOutputDim"); + int *n, int *c, int *h, int *w) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(const cudnnPoolingDescriptor_t, + const cudnnTensorDescriptor_t, + int *, int *, int *, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetPooling2dForwardOutputDim"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc, inputTensorDesc, n, c, h, w); } cudnnStatus_t CUDNNWINAPI cudnnDestroyPoolingDescriptor(cudnnPoolingDescriptor_t poolingDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPoolingDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPoolingDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyPoolingDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(poolingDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnPoolingForward(cudnnHandle_t handle, - const cudnnPoolingDescriptor_t poolingDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnPoolingForward( + cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnPoolingForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, poolingDesc, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnPoolingBackward(cudnnHandle_t handle, - const cudnnPoolingDescriptor_t poolingDesc, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnPoolingBackward( + cudnnHandle_t handle, const cudnnPoolingDescriptor_t poolingDesc, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnPoolingDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnPoolingBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, poolingDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, poolingDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, + beta, dxDesc, dx); } cudnnStatus_t CUDNNWINAPI cudnnCreateActivationDescriptor(cudnnActivationDescriptor_t *activationDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetActivationDescriptor(cudnnActivationDescriptor_t activationDesc, - cudnnActivationMode_t mode, - cudnnNanPropagation_t reluNanOpt, - double coef) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t, cudnnActivationMode_t, cudnnNanPropagation_t, double); +cudnnStatus_t CUDNNWINAPI cudnnSetActivationDescriptor( + cudnnActivationDescriptor_t activationDesc, cudnnActivationMode_t mode, + cudnnNanPropagation_t reluNanOpt, double coef) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t, + cudnnActivationMode_t, + cudnnNanPropagation_t, double); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc, mode, reluNanOpt, coef); @@ -1281,9 +1338,10 @@ cudnnSetActivationDescriptor(cudnnActivationDescriptor_t activationDesc, cudnnStatus_t CUDNNWINAPI cudnnGetActivationDescriptor(const cudnnActivationDescriptor_t activationDesc, cudnnActivationMode_t *mode, - cudnnNanPropagation_t *reluNanOpt, - double *coef) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnActivationDescriptor_t, cudnnActivationMode_t *, cudnnNanPropagation_t *, double *); + cudnnNanPropagation_t *reluNanOpt, double *coef) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnActivationDescriptor_t, cudnnActivationMode_t *, + cudnnNanPropagation_t *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc, mode, reluNanOpt, coef); @@ -1291,65 +1349,68 @@ cudnnGetActivationDescriptor(const cudnnActivationDescriptor_t activationDesc, cudnnStatus_t CUDNNWINAPI cudnnDestroyActivationDescriptor(cudnnActivationDescriptor_t activationDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnActivationDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyActivationDescriptor"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnActivationDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyActivationDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(activationDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnActivationForward(cudnnHandle_t handle, - cudnnActivationDescriptor_t activationDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnActivationDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnActivationForward( + cudnnHandle_t handle, cudnnActivationDescriptor_t activationDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnActivationDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnActivationForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, activationDesc, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnActivationBackward(cudnnHandle_t handle, - cudnnActivationDescriptor_t activationDesc, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnActivationDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnActivationBackward( + cudnnHandle_t handle, cudnnActivationDescriptor_t activationDesc, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnActivationDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnActivationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, activationDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, activationDesc, alpha, yDesc, y, dyDesc, dy, xDesc, x, + beta, dxDesc, dx); } cudnnStatus_t CUDNNWINAPI cudnnCreateLRNDescriptor(cudnnLRNDescriptor_t *normDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnLRNDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned lrnN, double lrnAlpha, double lrnBeta, double lrnK) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t, unsigned int, double, double, double); +cudnnStatus_t CUDNNWINAPI cudnnSetLRNDescriptor(cudnnLRNDescriptor_t normDesc, + unsigned lrnN, double lrnAlpha, + double lrnBeta, double lrnK) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnLRNDescriptor_t, unsigned int, double, double, double); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK); } -cudnnStatus_t CUDNNWINAPI -cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned *lrnN, double *lrnAlpha, double *lrnBeta, double *lrnK) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t, unsigned int *, double *, double *, double *); +cudnnStatus_t CUDNNWINAPI cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, + unsigned *lrnN, + double *lrnAlpha, + double *lrnBeta, double *lrnK) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnLRNDescriptor_t, unsigned int *, double *, double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(normDesc, lrnN, lrnAlpha, lrnBeta, lrnK); @@ -1357,157 +1418,157 @@ cudnnGetLRNDescriptor(cudnnLRNDescriptor_t normDesc, unsigned *lrnN, double *lrn cudnnStatus_t CUDNNWINAPI cudnnDestroyLRNDescriptor(cudnnLRNDescriptor_t lrnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnLRNDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnLRNDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyLRNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(lrnDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnLRNCrossChannelForward(cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnLRNMode_t lrnMode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnLRNCrossChannelForward( + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, cudnnLRNMode_t lrnMode, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnLRNCrossChannelForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, normDesc, lrnMode, alpha, xDesc, x, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnLRNCrossChannelBackward(cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnLRNMode_t lrnMode, - const void *alpha, - const cudnnTensorDescriptor_t yDesc, - const void *y, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnLRNCrossChannelBackward( + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, cudnnLRNMode_t lrnMode, + const void *alpha, const cudnnTensorDescriptor_t yDesc, const void *y, + const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t xDesc, const void *x, const void *beta, + const cudnnTensorDescriptor_t dxDesc, void *dx) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnLRNMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnLRNCrossChannelBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, lrnMode, alpha, yDesc, y, dyDesc, dy, xDesc, x, beta, dxDesc, dx); + return func_ptr(handle, normDesc, lrnMode, alpha, yDesc, y, dyDesc, dy, xDesc, + x, beta, dxDesc, dx); } -cudnnStatus_t CUDNNWINAPI -cudnnDivisiveNormalizationForward(cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnDivNormMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */ - const void *x, - const void *means, /* if NULL, means are assumed to be zero */ - void *temp, - void *temp2, - const void *beta, - const cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, const void *, const cudnnTensorDescriptor_t, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationForward"); +cudnnStatus_t CUDNNWINAPI cudnnDivisiveNormalizationForward( + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, + cudnnDivNormMode_t mode, const void *alpha, + const cudnnTensorDescriptor_t xDesc, /* same desc for means, temp, temp2 */ + const void *x, + const void *means, /* if NULL, means are assumed to be zero */ + void *temp, void *temp2, const void *beta, + const cudnnTensorDescriptor_t yDesc, void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, + const void *, const cudnnTensorDescriptor_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, temp, temp2, beta, yDesc, y); + return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, temp, temp2, + beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnDivisiveNormalizationBackward(cudnnHandle_t handle, - cudnnLRNDescriptor_t normDesc, - cudnnDivNormMode_t mode, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, /* same desc for x, means, dy, temp, temp2 */ - const void *x, - const void *means, /* if NULL, means are assumed to be zero */ - const void *dy, - void *temp, - void *temp2, - const void *beta, - const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */ - void *dx, /* output x differential */ - void *dMeans) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, void *, void *, const void *, const cudnnTensorDescriptor_t, void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationBackward"); +cudnnStatus_t CUDNNWINAPI cudnnDivisiveNormalizationBackward( + cudnnHandle_t handle, cudnnLRNDescriptor_t normDesc, + cudnnDivNormMode_t mode, const void *alpha, + const cudnnTensorDescriptor_t + xDesc, /* same desc for x, means, dy, temp, temp2 */ + const void *x, + const void *means, /* if NULL, means are assumed to be zero */ + const void *dy, void *temp, void *temp2, const void *beta, + const cudnnTensorDescriptor_t dXdMeansDesc, /* same desc for dx, dMeans */ + void *dx, /* output x differential */ + void *dMeans) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnLRNDescriptor_t, cudnnDivNormMode_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + void *, void *, const void *, const cudnnTensorDescriptor_t, void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDivisiveNormalizationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, dy, temp, temp2, beta, dXdMeansDesc, dx, dMeans); + return func_ptr(handle, normDesc, mode, alpha, xDesc, x, means, dy, temp, + temp2, beta, dXdMeansDesc, dx, dMeans); } -cudnnStatus_t CUDNNWINAPI -cudnnDeriveBNTensorDescriptor(cudnnTensorDescriptor_t derivedBnDesc, - const cudnnTensorDescriptor_t xDesc, - cudnnBatchNormMode_t mode) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, cudnnBatchNormMode_t); +cudnnStatus_t CUDNNWINAPI cudnnDeriveBNTensorDescriptor( + cudnnTensorDescriptor_t derivedBnDesc, const cudnnTensorDescriptor_t xDesc, + cudnnBatchNormMode_t mode) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, + cudnnBatchNormMode_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDeriveBNTensorDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(derivedBnDesc, xDesc, mode); } cudnnStatus_t CUDNNWINAPI -cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - cudnnBatchNormOps_t bnOps, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t zDesc, - const cudnnTensorDescriptor_t yDesc, - const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - const cudnnActivationDescriptor_t activationDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnActivationDescriptor_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize"); +cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, cudnnBatchNormOps_t bnOps, + const cudnnTensorDescriptor_t xDesc, const cudnnTensorDescriptor_t zDesc, + const cudnnTensorDescriptor_t yDesc, + const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, + const cudnnActivationDescriptor_t activationDesc, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnActivationDescriptor_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>( + "cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, bnOps, xDesc, zDesc, yDesc, bnScaleBiasMeanVarDesc, activationDesc, sizeInBytes); + return func_ptr(handle, mode, bnOps, xDesc, zDesc, yDesc, + bnScaleBiasMeanVarDesc, activationDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetBatchNormalizationBackwardExWorkspaceSize(cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - cudnnBatchNormOps_t bnOps, - const cudnnTensorDescriptor_t xDesc, - const cudnnTensorDescriptor_t yDesc, - const cudnnTensorDescriptor_t dyDesc, - const cudnnTensorDescriptor_t dzDesc, - const cudnnTensorDescriptor_t dxDesc, - const cudnnTensorDescriptor_t dBnScaleBiasDesc, - const cudnnActivationDescriptor_t activationDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const cudnnActivationDescriptor_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetBatchNormalizationBackwardExWorkspaceSize"); +cudnnStatus_t CUDNNWINAPI cudnnGetBatchNormalizationBackwardExWorkspaceSize( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, cudnnBatchNormOps_t bnOps, + const cudnnTensorDescriptor_t xDesc, const cudnnTensorDescriptor_t yDesc, + const cudnnTensorDescriptor_t dyDesc, const cudnnTensorDescriptor_t dzDesc, + const cudnnTensorDescriptor_t dxDesc, + const cudnnTensorDescriptor_t dBnScaleBiasDesc, + const cudnnActivationDescriptor_t activationDesc, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, + const cudnnActivationDescriptor_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetBatchNormalizationBackwardExWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, bnOps, xDesc, yDesc, dyDesc, dzDesc, dxDesc, dBnScaleBiasDesc, activationDesc, sizeInBytes); + return func_ptr(handle, mode, bnOps, xDesc, yDesc, dyDesc, dzDesc, dxDesc, + dBnScaleBiasDesc, activationDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetBatchNormalizationTrainingExReserveSpaceSize(cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - cudnnBatchNormOps_t bnOps, - const cudnnActivationDescriptor_t activationDesc, - const cudnnTensorDescriptor_t xDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, size_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetBatchNormalizationTrainingExReserveSpaceSize"); +cudnnStatus_t CUDNNWINAPI cudnnGetBatchNormalizationTrainingExReserveSpaceSize( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, cudnnBatchNormOps_t bnOps, + const cudnnActivationDescriptor_t activationDesc, + const cudnnTensorDescriptor_t xDesc, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, + const cudnnActivationDescriptor_t, const cudnnTensorDescriptor_t, + size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>( + "cudnnGetBatchNormalizationTrainingExReserveSpaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, mode, bnOps, activationDesc, xDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnBatchNormalizationForwardTraining( - cudnnHandle_t handle, - cudnnBatchNormMode_t mode, +cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationForwardTraining( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, const void *alpha, /* alpha[0] = result blend factor */ const void *beta, /* beta[0] = dest layer blend factor */ - const cudnnTensorDescriptor_t xDesc, - const void *x, /* NxCxHxW */ - const cudnnTensorDescriptor_t yDesc, - void *y, /* NxCxHxW */ + const cudnnTensorDescriptor_t xDesc, const void *x, /* NxCxHxW */ + const cudnnTensorDescriptor_t yDesc, void *y, /* NxCxHxW */ /* Shared desc for the next 6 tensors in the argument list. Data type to be set as follows: @@ -1515,13 +1576,13 @@ cudnnBatchNormalizationForwardTraining( Dimensions for this descriptor depend on normalization mode - Spatial Normalization : tensors are expected to have dims 1xCx1x1 (normalization is performed across NxHxW) - - Per-Activation Normalization : tensors are expected to have dims of 1xCxHxW - (normalization is performed across N) */ + - Per-Activation Normalization : tensors are expected to have dims of + 1xCxHxW (normalization is performed across N) */ const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - /* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation */ - const void *bnScale, - const void *bnBias, + /* 'Gamma' and 'Beta' respectively in Ioffe and Szegedy's paper's notation + */ + const void *bnScale, const void *bnBias, /* MUST use factor=1 in the very first call of a complete training cycle. Use a factor=1/(1+n) at N-th call to the function to get @@ -1539,248 +1600,261 @@ cudnnBatchNormalizationForwardTraining( of variance[x] (factor is applied in the same way as for runningMean) */ void *resultRunningVariance, - /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */ + /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and + backward functions. */ double epsilon, /* Optionally save intermediate results from the forward pass here - can be reused to speed up backward pass. NULL if unused */ - void *resultSaveMean, - void *resultSaveInvVariance) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, double, void *, void *, double, void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTraining"); + void *resultSaveMean, void *resultSaveInvVariance) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, const void *, double, void *, void *, double, void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTraining"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, bnScale, bnBias, exponentialAverageFactor, resultRunningMean, resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance); + return func_ptr( + handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, + bnScale, bnBias, exponentialAverageFactor, resultRunningMean, + resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance); } -cudnnStatus_t CUDNNWINAPI -cudnnBatchNormalizationForwardTrainingEx( - cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - cudnnBatchNormOps_t bnOps, +cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationForwardTrainingEx( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, cudnnBatchNormOps_t bnOps, const void *alpha, /* alpha[0] = result blend factor */ const void *beta, /* beta[0] = dest layer blend factor */ - const cudnnTensorDescriptor_t xDesc, - const void *xData, - const cudnnTensorDescriptor_t zDesc, - const void *zData, - const cudnnTensorDescriptor_t yDesc, - void *yData, + const cudnnTensorDescriptor_t xDesc, const void *xData, + const cudnnTensorDescriptor_t zDesc, const void *zData, + const cudnnTensorDescriptor_t yDesc, void *yData, - const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - const void *bnScale, + const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, const void *bnScale, const void *bnBias, - double exponentialAverageFactor, - void *resultRunningMean, + double exponentialAverageFactor, void *resultRunningMean, void *resultRunningVariance, - /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and backward functions. */ + /* Has to be >= CUDNN_BN_MIN_EPSILON. Should be the same in forward and + backward functions. */ double epsilon, /* Optionally save intermediate results from the forward pass here - can be reused to speed up backward pass. NULL if unused */ - void *resultSaveMean, - void *resultSaveInvVariance, + void *resultSaveMean, void *resultSaveInvVariance, - cudnnActivationDescriptor_t activationDesc, - void *workspace, - size_t workSpaceSizeInBytes, - void *reserveSpace, + cudnnActivationDescriptor_t activationDesc, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, double, void *, void *, double, void *, void *, cudnnActivationDescriptor_t, void *, size_t, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTrainingEx"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const void *, + const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, const void *, double, void *, void *, double, void *, + void *, cudnnActivationDescriptor_t, void *, size_t, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardTrainingEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, bnOps, alpha, beta, xDesc, xData, zDesc, zData, yDesc, yData, bnScaleBiasMeanVarDesc, bnScale, bnBias, exponentialAverageFactor, resultRunningMean, resultRunningVariance, epsilon, resultSaveMean, resultSaveInvVariance, activationDesc, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, mode, bnOps, alpha, beta, xDesc, xData, zDesc, zData, + yDesc, yData, bnScaleBiasMeanVarDesc, bnScale, bnBias, + exponentialAverageFactor, resultRunningMean, + resultRunningVariance, epsilon, resultSaveMean, + resultSaveInvVariance, activationDesc, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnBatchNormalizationForwardInference(cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - const void *alpha, /* alpha[0] = result blend factor */ - const void *beta, /* beta[0] = dest layer blend factor */ - const cudnnTensorDescriptor_t xDesc, - const void *x, /* NxCxHxW */ - const cudnnTensorDescriptor_t yDesc, - void *y, /* NxCxHxW */ - const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, - const void *bnScale, - const void *bnBias, - const void *estimatedMean, - const void *estimatedVariance, - double epsilon) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, const void *, double); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardInference"); +cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationForwardInference( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, + const void *alpha, /* alpha[0] = result blend factor */ + const void *beta, /* beta[0] = dest layer blend factor */ + const cudnnTensorDescriptor_t xDesc, const void *x, /* NxCxHxW */ + const cudnnTensorDescriptor_t yDesc, void *y, /* NxCxHxW */ + const cudnnTensorDescriptor_t bnScaleBiasMeanVarDesc, const void *bnScale, + const void *bnBias, const void *estimatedMean, + const void *estimatedVariance, double epsilon) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, const void *, const void *, const void *, double); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationForwardInference"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, bnScaleBiasMeanVarDesc, bnScale, bnBias, estimatedMean, estimatedVariance, epsilon); + return func_ptr(handle, mode, alpha, beta, xDesc, x, yDesc, y, + bnScaleBiasMeanVarDesc, bnScale, bnBias, estimatedMean, + estimatedVariance, epsilon); } -cudnnStatus_t CUDNNWINAPI -cudnnBatchNormalizationBackward(cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - const void *alphaDataDiff, - const void *betaDataDiff, - const void *alphaParamDiff, - const void *betaParamDiff, - const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */ - const void *x, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - /* Shared tensor desc for the 4 tensors below */ - const cudnnTensorDescriptor_t dBnScaleBiasDesc, - const void *bnScale, /* bnBias doesn't affect backpropagation */ - /* scale and bias diff are not backpropagated below this layer */ - void *dBnScaleResult, - void *dBnBiasResult, - /* Same epsilon as forward pass */ - double epsilon, +cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationBackward( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, const void *alphaDataDiff, + const void *betaDataDiff, const void *alphaParamDiff, + const void *betaParamDiff, + const cudnnTensorDescriptor_t xDesc, /* same desc for x, dx, dy */ + const void *x, const cudnnTensorDescriptor_t dyDesc, const void *dy, + const cudnnTensorDescriptor_t dxDesc, void *dx, + /* Shared tensor desc for the 4 tensors below */ + const cudnnTensorDescriptor_t dBnScaleBiasDesc, + const void *bnScale, /* bnBias doesn't affect backpropagation */ + /* scale and bias diff are not backpropagated below this layer */ + void *dBnScaleResult, void *dBnBiasResult, + /* Same epsilon as forward pass */ + double epsilon, - /* Optionally cached intermediate results from - forward pass */ - const void *savedMean, - const void *savedInvVariance) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, void *, void *, double, const void *, const void *); + /* Optionally cached intermediate results from + forward pass */ + const void *savedMean, const void *savedInvVariance) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, const void *, const void *, + const void *, const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + const void *, void *, void *, double, const void *, const void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, alphaDataDiff, betaDataDiff, alphaParamDiff, betaParamDiff, xDesc, x, dyDesc, dy, dxDesc, dx, dBnScaleBiasDesc, bnScale, dBnScaleResult, dBnBiasResult, epsilon, savedMean, savedInvVariance); + return func_ptr(handle, mode, alphaDataDiff, betaDataDiff, alphaParamDiff, + betaParamDiff, xDesc, x, dyDesc, dy, dxDesc, dx, + dBnScaleBiasDesc, bnScale, dBnScaleResult, dBnBiasResult, + epsilon, savedMean, savedInvVariance); } -cudnnStatus_t CUDNNWINAPI -cudnnBatchNormalizationBackwardEx(cudnnHandle_t handle, - cudnnBatchNormMode_t mode, - cudnnBatchNormOps_t bnOps, +cudnnStatus_t CUDNNWINAPI cudnnBatchNormalizationBackwardEx( + cudnnHandle_t handle, cudnnBatchNormMode_t mode, cudnnBatchNormOps_t bnOps, - const void *alphaDataDiff, - const void *betaDataDiff, - const void *alphaParamDiff, - const void *betaParamDiff, - const cudnnTensorDescriptor_t xDesc, - const void *xData, - const cudnnTensorDescriptor_t yDesc, - const void *yData, - const cudnnTensorDescriptor_t dyDesc, - const void *dyData, - const cudnnTensorDescriptor_t dzDesc, - void *dzData, - const cudnnTensorDescriptor_t dxDesc, - void *dxData, + const void *alphaDataDiff, const void *betaDataDiff, + const void *alphaParamDiff, const void *betaParamDiff, + const cudnnTensorDescriptor_t xDesc, const void *xData, + const cudnnTensorDescriptor_t yDesc, const void *yData, + const cudnnTensorDescriptor_t dyDesc, const void *dyData, + const cudnnTensorDescriptor_t dzDesc, void *dzData, + const cudnnTensorDescriptor_t dxDesc, void *dxData, - /* Shared tensor desc for the 4 tensors below */ - const cudnnTensorDescriptor_t dBnScaleBiasDesc, - const void *bnScaleData, - const void *bnBiasData, /* needed if there is activation */ - void *dBnScaleData, - void *dBnBiasData, - double epsilon, /* Same epsilon as forward pass */ + /* Shared tensor desc for the 4 tensors below */ + const cudnnTensorDescriptor_t dBnScaleBiasDesc, const void *bnScaleData, + const void *bnBiasData, /* needed if there is activation */ + void *dBnScaleData, void *dBnBiasData, + double epsilon, /* Same epsilon as forward pass */ - /* Optionally cached intermediate results from - forward pass */ - const void *savedMean, - const void *savedInvVariance, - cudnnActivationDescriptor_t activationDesc, - void *workSpace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const void *, const void *, const void *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, const void *, const void *, void *, void *, double, const void *, const void *, cudnnActivationDescriptor_t, void *, size_t, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnBatchNormalizationBackwardEx"); + /* Optionally cached intermediate results from + forward pass */ + const void *savedMean, const void *savedInvVariance, + cudnnActivationDescriptor_t activationDesc, void *workSpace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnBatchNormMode_t, cudnnBatchNormOps_t, const void *, + const void *, const void *, const void *, const cudnnTensorDescriptor_t, + const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, const void *, const void *, void *, + void *, double, const void *, const void *, cudnnActivationDescriptor_t, + void *, size_t, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnBatchNormalizationBackwardEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, mode, bnOps, alphaDataDiff, betaDataDiff, alphaParamDiff, betaParamDiff, xDesc, xData, yDesc, yData, dyDesc, dyData, dzDesc, dzData, dxDesc, dxData, dBnScaleBiasDesc, bnScaleData, bnBiasData, dBnScaleData, dBnBiasData, epsilon, savedMean, savedInvVariance, activationDesc, workSpace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr( + handle, mode, bnOps, alphaDataDiff, betaDataDiff, alphaParamDiff, + betaParamDiff, xDesc, xData, yDesc, yData, dyDesc, dyData, dzDesc, dzData, + dxDesc, dxData, dBnScaleBiasDesc, bnScaleData, bnBiasData, dBnScaleData, + dBnBiasData, epsilon, savedMean, savedInvVariance, activationDesc, + workSpace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnCreateSpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t *stDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateSpatialTransformerDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnCreateSpatialTransformerDescriptor( + cudnnSpatialTransformerDescriptor_t *stDesc) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateSpatialTransformerDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetSpatialTransformerNdDescriptor(cudnnSpatialTransformerDescriptor_t stDesc, - cudnnSamplerType_t samplerType, - cudnnDataType_t dataType, - const int nbDims, - const int dimA[]) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t, cudnnSamplerType_t, cudnnDataType_t, const int, const int []); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetSpatialTransformerNdDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnSetSpatialTransformerNdDescriptor( + cudnnSpatialTransformerDescriptor_t stDesc, cudnnSamplerType_t samplerType, + cudnnDataType_t dataType, const int nbDims, const int dimA[]) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnSpatialTransformerDescriptor_t, cudnnSamplerType_t, cudnnDataType_t, + const int, const int[]); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSetSpatialTransformerNdDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc, samplerType, dataType, nbDims, dimA); } -cudnnStatus_t CUDNNWINAPI -cudnnDestroySpatialTransformerDescriptor(cudnnSpatialTransformerDescriptor_t stDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroySpatialTransformerDescriptor"); +cudnnStatus_t CUDNNWINAPI cudnnDestroySpatialTransformerDescriptor( + cudnnSpatialTransformerDescriptor_t stDesc) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnSpatialTransformerDescriptor_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroySpatialTransformerDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(stDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSpatialTfGridGeneratorForward(cudnnHandle_t handle, - const cudnnSpatialTransformerDescriptor_t stDesc, - const void *theta, - void *grid) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorForward"); +cudnnStatus_t CUDNNWINAPI cudnnSpatialTfGridGeneratorForward( + cudnnHandle_t handle, const cudnnSpatialTransformerDescriptor_t stDesc, + const void *theta, void *grid) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, theta, grid); } -cudnnStatus_t CUDNNWINAPI -cudnnSpatialTfGridGeneratorBackward(cudnnHandle_t handle, - const cudnnSpatialTransformerDescriptor_t stDesc, - const void *dgrid, - void *dtheta) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorBackward"); +cudnnStatus_t CUDNNWINAPI cudnnSpatialTfGridGeneratorBackward( + cudnnHandle_t handle, const cudnnSpatialTransformerDescriptor_t stDesc, + const void *dgrid, void *dtheta) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnSpatialTransformerDescriptor_t, const void *, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSpatialTfGridGeneratorBackward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, dgrid, dtheta); } -cudnnStatus_t CUDNNWINAPI -cudnnSpatialTfSamplerForward(cudnnHandle_t handle, - cudnnSpatialTransformerDescriptor_t stDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *grid, - const void *beta, - cudnnTensorDescriptor_t yDesc, - void *y) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, cudnnTensorDescriptor_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnSpatialTfSamplerForward( + cudnnHandle_t handle, cudnnSpatialTransformerDescriptor_t stDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *grid, const void *beta, cudnnTensorDescriptor_t yDesc, + void *y) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + cudnnTensorDescriptor_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfSamplerForward"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, stDesc, alpha, xDesc, x, grid, beta, yDesc, y); } -cudnnStatus_t CUDNNWINAPI -cudnnSpatialTfSamplerBackward(cudnnHandle_t handle, - cudnnSpatialTransformerDescriptor_t stDesc, - const void *alpha, - const cudnnTensorDescriptor_t xDesc, - const void *x, - const void *beta, - const cudnnTensorDescriptor_t dxDesc, - void *dx, - const void *alphaDgrid, - const cudnnTensorDescriptor_t dyDesc, - const void *dy, - const void *grid, - const void *betaDgrid, - void *dgrid) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const cudnnTensorDescriptor_t, void *, const void *, const cudnnTensorDescriptor_t, const void *, const void *, const void *, void *); +cudnnStatus_t CUDNNWINAPI cudnnSpatialTfSamplerBackward( + cudnnHandle_t handle, cudnnSpatialTransformerDescriptor_t stDesc, + const void *alpha, const cudnnTensorDescriptor_t xDesc, const void *x, + const void *beta, const cudnnTensorDescriptor_t dxDesc, void *dx, + const void *alphaDgrid, const cudnnTensorDescriptor_t dyDesc, + const void *dy, const void *grid, const void *betaDgrid, void *dgrid) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnSpatialTransformerDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, + const cudnnTensorDescriptor_t, void *, const void *, + const cudnnTensorDescriptor_t, const void *, const void *, const void *, + void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSpatialTfSamplerBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, stDesc, alpha, xDesc, x, beta, dxDesc, dx, alphaDgrid, dyDesc, dy, grid, betaDgrid, dgrid); + return func_ptr(handle, stDesc, alpha, xDesc, x, beta, dxDesc, dx, alphaDgrid, + dyDesc, dy, grid, betaDgrid, dgrid); } cudnnStatus_t CUDNNWINAPI cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t *dropoutDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc); @@ -1788,99 +1862,95 @@ cudnnCreateDropoutDescriptor(cudnnDropoutDescriptor_t *dropoutDesc) { cudnnStatus_t CUDNNWINAPI cudnnDestroyDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnDropoutGetStatesSize(cudnnHandle_t handle, size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnDropoutGetStatesSize(cudnnHandle_t handle, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutGetStatesSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnDropoutGetReserveSpaceSize(cudnnTensorDescriptor_t xdesc, size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnTensorDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnDropoutGetReserveSpaceSize( + cudnnTensorDescriptor_t xdesc, size_t *sizeInBytes) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnTensorDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutGetReserveSpaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(xdesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnSetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float dropout, - void *states, - size_t stateSizeInBytes, - unsigned long long seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float, void *, size_t, unsigned long long); +cudnnStatus_t CUDNNWINAPI cudnnSetDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float dropout, + void *states, size_t stateSizeInBytes, unsigned long long seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float, void *, size_t, unsigned long long); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, stateSizeInBytes, seed); } -cudnnStatus_t CUDNNWINAPI -cudnnRestoreDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float dropout, - void *states, - size_t stateSizeInBytes, - unsigned long long seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float, void *, size_t, unsigned long long); +cudnnStatus_t CUDNNWINAPI cudnnRestoreDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float dropout, + void *states, size_t stateSizeInBytes, unsigned long long seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float, void *, size_t, unsigned long long); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRestoreDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, stateSizeInBytes, seed); } -cudnnStatus_t CUDNNWINAPI -cudnnGetDropoutDescriptor(cudnnDropoutDescriptor_t dropoutDesc, - cudnnHandle_t handle, - float *dropout, - void **states, - unsigned long long *seed) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, float *, void **, unsigned long long *); +cudnnStatus_t CUDNNWINAPI cudnnGetDropoutDescriptor( + cudnnDropoutDescriptor_t dropoutDesc, cudnnHandle_t handle, float *dropout, + void **states, unsigned long long *seed) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnDropoutDescriptor_t, cudnnHandle_t, + float *, void **, unsigned long long *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetDropoutDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(dropoutDesc, handle, dropout, states, seed); } -cudnnStatus_t CUDNNWINAPI -cudnnDropoutForward(cudnnHandle_t handle, - const cudnnDropoutDescriptor_t dropoutDesc, - const cudnnTensorDescriptor_t xdesc, - const void *x, - const cudnnTensorDescriptor_t ydesc, - void *y, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnDropoutDescriptor_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnDropoutForward( + cudnnHandle_t handle, const cudnnDropoutDescriptor_t dropoutDesc, + const cudnnTensorDescriptor_t xdesc, const void *x, + const cudnnTensorDescriptor_t ydesc, void *y, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnDropoutDescriptor_t, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, dropoutDesc, xdesc, x, ydesc, y, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, dropoutDesc, xdesc, x, ydesc, y, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnDropoutBackward(cudnnHandle_t handle, - const cudnnDropoutDescriptor_t dropoutDesc, - const cudnnTensorDescriptor_t dydesc, - const void *dy, - const cudnnTensorDescriptor_t dxdesc, - void *dx, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnDropoutDescriptor_t, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnDropoutBackward( + cudnnHandle_t handle, const cudnnDropoutDescriptor_t dropoutDesc, + const cudnnTensorDescriptor_t dydesc, const void *dy, + const cudnnTensorDescriptor_t dxdesc, void *dx, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnDropoutDescriptor_t, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDropoutBackward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, dropoutDesc, dydesc, dy, dxdesc, dx, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, dropoutDesc, dydesc, dy, dxdesc, dx, reserveSpace, + reserveSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc); @@ -1888,132 +1958,130 @@ cudnnCreateRNNDescriptor(cudnnRNNDescriptor_t *rnnDesc) { cudnnStatus_t CUDNNWINAPI cudnnDestroyRNNDescriptor(cudnnRNNDescriptor_t rnnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNDescriptor(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - const int hiddenSize, - const int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnRNNAlgo_t algo, - cudnnDataType_t mathPrec) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, const int hiddenSize, + const int numLayers, cudnnDropoutDescriptor_t dropoutDesc, + cudnnRNNInputMode_t inputMode, cudnnDirectionMode_t direction, + cudnnRNNMode_t mode, cudnnRNNAlgo_t algo, cudnnDataType_t mathPrec) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, + cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, + cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, mathPrec); + return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, mathPrec); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNDescriptor(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - int *hiddenSize, - int *numLayers, - cudnnDropoutDescriptor_t *dropoutDesc, - cudnnRNNInputMode_t *inputMode, - cudnnDirectionMode_t *direction, - cudnnRNNMode_t *mode, - cudnnRNNAlgo_t *algo, - cudnnDataType_t *mathPrec) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, int *, int *, cudnnDropoutDescriptor_t *, cudnnRNNInputMode_t *, cudnnDirectionMode_t *, cudnnRNNMode_t *, cudnnRNNAlgo_t *, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNDescriptor( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, int *hiddenSize, + int *numLayers, cudnnDropoutDescriptor_t *dropoutDesc, + cudnnRNNInputMode_t *inputMode, cudnnDirectionMode_t *direction, + cudnnRNNMode_t *mode, cudnnRNNAlgo_t *algo, cudnnDataType_t *mathPrec) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, int *, int *, + cudnnDropoutDescriptor_t *, cudnnRNNInputMode_t *, cudnnDirectionMode_t *, + cudnnRNNMode_t *, cudnnRNNAlgo_t *, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, mathPrec); + return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, mathPrec); } cudnnStatus_t CUDNNWINAPI cudnnSetRNNMatrixMathType(cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t mType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t); + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNMatrixMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, mType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNMatrixMathType(cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t *mType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNMatrixMathType( + cudnnRNNDescriptor_t rnnDesc, cudnnMathType_t *mType) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnMathType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNMatrixMathType"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, mType); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNBiasMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNBiasMode_t biasMode) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNBiasMode_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNBiasMode(cudnnRNNDescriptor_t rnnDesc, + cudnnRNNBiasMode_t biasMode) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNBiasMode_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNBiasMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, biasMode); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNBiasMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNBiasMode_t *biasMode) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNBiasMode_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNBiasMode(cudnnRNNDescriptor_t rnnDesc, + cudnnRNNBiasMode_t *biasMode) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNBiasMode_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNBiasMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, biasMode); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNSetClip(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - cudnnRNNClipMode_t clipMode, - cudnnNanPropagation_t clipNanOpt, - double lclip, - double rclip) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, cudnnRNNClipMode_t, cudnnNanPropagation_t, double, double); +cudnnStatus_t CUDNNWINAPI cudnnRNNSetClip(cudnnHandle_t handle, + cudnnRNNDescriptor_t rnnDesc, + cudnnRNNClipMode_t clipMode, + cudnnNanPropagation_t clipNanOpt, + double lclip, double rclip) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnRNNClipMode_t, + cudnnNanPropagation_t, double, double); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNSetClip"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, clipMode, clipNanOpt, lclip, rclip); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNGetClip(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - cudnnRNNClipMode_t *clipMode, - cudnnNanPropagation_t *clipNanOpt, - double *lclip, - double *rclip) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, cudnnRNNClipMode_t *, cudnnNanPropagation_t *, double *, double *); +cudnnStatus_t CUDNNWINAPI cudnnRNNGetClip(cudnnHandle_t handle, + cudnnRNNDescriptor_t rnnDesc, + cudnnRNNClipMode_t *clipMode, + cudnnNanPropagation_t *clipNanOpt, + double *lclip, double *rclip) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnRNNClipMode_t *, + cudnnNanPropagation_t *, double *, double *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNGetClip"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, clipMode, clipNanOpt, lclip, rclip); } cudnnStatus_t CUDNNWINAPI -cudnnSetRNNProjectionLayers(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - const int recProjSize, - const int outProjSize) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int); +cudnnSetRNNProjectionLayers(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + const int recProjSize, const int outProjSize) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNProjectionLayers"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, recProjSize, outProjSize); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNProjectionLayers(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - int *recProjSize, - int *outProjSize) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNProjectionLayers( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *recProjSize, + int *outProjSize) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNProjectionLayers"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, recProjSize, outProjSize); } -cudnnStatus_t CUDNNWINAPI -cudnnCreatePersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, - const int minibatch, - const cudnnDataType_t dataType, - cudnnPersistentRNNPlan_t *plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, const int, const cudnnDataType_t, cudnnPersistentRNNPlan_t *); +cudnnStatus_t CUDNNWINAPI cudnnCreatePersistentRNNPlan( + cudnnRNNDescriptor_t rnnDesc, const int minibatch, + const cudnnDataType_t dataType, cudnnPersistentRNNPlan_t *plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, const int, + const cudnnDataType_t, + cudnnPersistentRNNPlan_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreatePersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, minibatch, dataType, plan); @@ -2021,209 +2089,206 @@ cudnnCreatePersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, cudnnStatus_t CUDNNWINAPI cudnnDestroyPersistentRNNPlan(cudnnPersistentRNNPlan_t plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnPersistentRNNPlan_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnPersistentRNNPlan_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyPersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(plan); } -cudnnStatus_t CUDNNWINAPI -cudnnSetPersistentRNNPlan(cudnnRNNDescriptor_t rnnDesc, cudnnPersistentRNNPlan_t plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnPersistentRNNPlan_t); +cudnnStatus_t CUDNNWINAPI cudnnSetPersistentRNNPlan( + cudnnRNNDescriptor_t rnnDesc, cudnnPersistentRNNPlan_t plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, + cudnnPersistentRNNPlan_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetPersistentRNNPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, plan); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNWorkspaceSize(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNWorkspaceSize( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNTrainingReserveSize(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNTrainingReserveSize( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, + size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNTrainingReserveSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, seqLength, xDesc, sizeInBytes); } cudnnStatus_t CUDNNWINAPI -cudnnGetRNNParamsSize(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnTensorDescriptor_t xDesc, - size_t *sizeInBytes, +cudnnGetRNNParamsSize(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnTensorDescriptor_t xDesc, size_t *sizeInBytes, cudnnDataType_t dataType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnTensorDescriptor_t, size_t *, cudnnDataType_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnTensorDescriptor_t, + size_t *, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNParamsSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, xDesc, sizeInBytes, dataType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNLinLayerMatrixParams(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int pseudoLayer, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const int linLayerID, - cudnnFilterDescriptor_t linLayerMatDesc, - void **linLayerMat) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const void *, const int, cudnnFilterDescriptor_t, void **); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerMatrixParams( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int pseudoLayer, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, const void *w, const int linLayerID, + cudnnFilterDescriptor_t linLayerMatDesc, void **linLayerMat) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, + const void *, const int, cudnnFilterDescriptor_t, void **); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNLinLayerMatrixParams"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, linLayerMatDesc, linLayerMat); + return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, + linLayerMatDesc, linLayerMat); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNLinLayerBiasParams(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int pseudoLayer, - const cudnnTensorDescriptor_t xDesc, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const int linLayerID, - cudnnFilterDescriptor_t linLayerBiasDesc, - void **linLayerBias) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, const void *, const int, cudnnFilterDescriptor_t, void **); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNLinLayerBiasParams( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int pseudoLayer, const cudnnTensorDescriptor_t xDesc, + const cudnnFilterDescriptor_t wDesc, const void *w, const int linLayerID, + cudnnFilterDescriptor_t linLayerBiasDesc, void **linLayerBias) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t, const cudnnFilterDescriptor_t, + const void *, const int, cudnnFilterDescriptor_t, void **); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNLinLayerBiasParams"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, linLayerBiasDesc, linLayerBias); + return func_ptr(handle, rnnDesc, pseudoLayer, xDesc, wDesc, w, linLayerID, + linLayerBiasDesc, linLayerBias); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNForwardInference(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t *yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - void *workspace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardInference( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, void *workspace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardInference"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, workSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNForwardTraining(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t *yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - void *workspace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardTraining( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardTraining"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI -cudnnRNNBackwardData(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *yDesc, - const void *y, - const cudnnTensorDescriptor_t *dyDesc, - const void *dy, - const cudnnTensorDescriptor_t dhyDesc, - const void *dhy, - const cudnnTensorDescriptor_t dcyDesc, - const void *dcy, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnTensorDescriptor_t *dxDesc, - void *dx, - const cudnnTensorDescriptor_t dhxDesc, - void *dhx, - const cudnnTensorDescriptor_t dcxDesc, - void *dcx, - void *workspace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnRNNBackwardData(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *yDesc, + const void *y, const cudnnTensorDescriptor_t *dyDesc, + const void *dy, const cudnnTensorDescriptor_t dhyDesc, + const void *dhy, const cudnnTensorDescriptor_t dcyDesc, + const void *dcy, const cudnnFilterDescriptor_t wDesc, + const void *w, const cudnnTensorDescriptor_t hxDesc, + const void *hx, const cudnnTensorDescriptor_t cxDesc, + const void *cx, const cudnnTensorDescriptor_t *dxDesc, + void *dx, const cudnnTensorDescriptor_t dhxDesc, void *dhx, + const cudnnTensorDescriptor_t dcxDesc, void *dcx, + void *workspace, size_t workSpaceSizeInBytes, + void *reserveSpace, size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, void *, size_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardData"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, + dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, + dx, dhxDesc, dhx, dcxDesc, dcx, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNBackwardWeights(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t *yDesc, - const void *y, - const void *workspace, - size_t workSpaceSizeInBytes, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - const void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, const void *, const void *, size_t, const cudnnFilterDescriptor_t, void *, const void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeights( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t *yDesc, const void *y, const void *workspace, + size_t workSpaceSizeInBytes, const cudnnFilterDescriptor_t dwDesc, void *dw, + const void *reserveSpace, size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, const void *, const void *, size_t, + const cudnnFilterDescriptor_t, void *, const void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardWeights"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, + workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNPaddingMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t paddingMode) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNPaddingMode_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNPaddingMode( + cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t paddingMode) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNPaddingMode_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNPaddingMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, paddingMode); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNPaddingMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t *paddingMode) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, cudnnRNNPaddingMode_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNPaddingMode( + cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t *paddingMode) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDescriptor_t, + cudnnRNNPaddingMode_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNPaddingMode"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDesc, paddingMode); @@ -2231,7 +2296,7 @@ cudnnGetRNNPaddingMode(cudnnRNNDescriptor_t rnnDesc, cudnnRNNPaddingMode_t *padd cudnnStatus_t CUDNNWINAPI cudnnCreateRNNDataDescriptor(cudnnRNNDataDescriptor_t *rnnDataDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDataDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateRNNDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDataDesc); @@ -2239,338 +2304,352 @@ cudnnCreateRNNDataDescriptor(cudnnRNNDataDescriptor_t *rnnDataDesc) { cudnnStatus_t CUDNNWINAPI cudnnDestroyRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnRNNDataDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyRNNDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(rnnDataDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc, - cudnnDataType_t dataType, - cudnnRNNDataLayout_t layout, - int maxSeqLength, - int batchSize, - int vectorSize, - const int seqLengthArray[], /* length of each sequence in the batch */ - void *paddingFill) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t, cudnnDataType_t, cudnnRNNDataLayout_t, int, int, int, const int [], void *); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDataDescriptor( + cudnnRNNDataDescriptor_t rnnDataDesc, cudnnDataType_t dataType, + cudnnRNNDataLayout_t layout, int maxSeqLength, int batchSize, + int vectorSize, + const int seqLengthArray[], /* length of each sequence in the batch */ + void *paddingFill) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnRNNDataDescriptor_t, cudnnDataType_t, cudnnRNNDataLayout_t, int, int, + int, const int[], void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(rnnDataDesc, dataType, layout, maxSeqLength, batchSize, vectorSize, seqLengthArray, paddingFill); + return func_ptr(rnnDataDesc, dataType, layout, maxSeqLength, batchSize, + vectorSize, seqLengthArray, paddingFill); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNDataDescriptor(cudnnRNNDataDescriptor_t rnnDataDesc, - cudnnDataType_t *dataType, - cudnnRNNDataLayout_t *layout, - int *maxSeqLength, - int *batchSize, - int *vectorSize, - int arrayLengthRequested, - int seqLengthArray[], - void *paddingFill) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDataDescriptor_t, cudnnDataType_t *, cudnnRNNDataLayout_t *, int *, int *, int *, int, int [], void *); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNDataDescriptor( + cudnnRNNDataDescriptor_t rnnDataDesc, cudnnDataType_t *dataType, + cudnnRNNDataLayout_t *layout, int *maxSeqLength, int *batchSize, + int *vectorSize, int arrayLengthRequested, int seqLengthArray[], + void *paddingFill) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnRNNDataDescriptor_t, cudnnDataType_t *, cudnnRNNDataLayout_t *, + int *, int *, int *, int, int[], void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(rnnDataDesc, dataType, layout, maxSeqLength, batchSize, vectorSize, arrayLengthRequested, seqLengthArray, paddingFill); + return func_ptr(rnnDataDesc, dataType, layout, maxSeqLength, batchSize, + vectorSize, arrayLengthRequested, seqLengthArray, + paddingFill); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNForwardTrainingEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnRNNDataDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnRNNDataDescriptor_t yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */ - const void *keys, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */ - void *cAttn, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */ - void *iAttn, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */ - void *queries, /* reserved, should pass NULL */ - void *workSpace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardTrainingEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDataDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnRNNDataDescriptor_t yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, + const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */ + const void *keys, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */ + void *cAttn, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */ + void *iAttn, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */ + void *queries, /* reserved, should pass NULL */ + void *workSpace, size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, + const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, + const cudnnRNNDataDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, + void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardTrainingEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, kDesc, keys, cDesc, cAttn, iDesc, iAttn, qDesc, queries, workSpace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, + yDesc, y, hyDesc, hy, cyDesc, cy, kDesc, keys, cDesc, cAttn, + iDesc, iAttn, qDesc, queries, workSpace, workSpaceSizeInBytes, + reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNForwardInferenceEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnRNNDataDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnRNNDataDescriptor_t yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */ - const void *keys, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */ - void *cAttn, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */ - void *iAttn, /* reserved, should pass NULL */ - const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */ - void *queries, /* reserved, should pass NULL */ - void *workSpace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNForwardInferenceEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDataDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnRNNDataDescriptor_t yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, + const cudnnRNNDataDescriptor_t kDesc, /* reserved, should pass NULL */ + const void *keys, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t cDesc, /* reserved, should pass NULL */ + void *cAttn, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t iDesc, /* reserved, should pass NULL */ + void *iAttn, /* reserved, should pass NULL */ + const cudnnRNNDataDescriptor_t qDesc, /* reserved, should pass NULL */ + void *queries, /* reserved, should pass NULL */ + void *workSpace, size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, + const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, + const cudnnRNNDataDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, void *, const cudnnRNNDataDescriptor_t, + void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNForwardInferenceEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, kDesc, keys, cDesc, cAttn, iDesc, iAttn, qDesc, queries, workSpace, workSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, + yDesc, y, hyDesc, hy, cyDesc, cy, kDesc, keys, cDesc, cAttn, + iDesc, iAttn, qDesc, queries, workSpace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNBackwardDataEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnRNNDataDescriptor_t yDesc, - const void *y, - const cudnnRNNDataDescriptor_t dyDesc, - const void *dy, - const cudnnRNNDataDescriptor_t dcDesc, /* reserved, should pass NULL */ - const void *dcAttn, /* reserved, should pass NULL */ - const cudnnTensorDescriptor_t dhyDesc, - const void *dhy, - const cudnnTensorDescriptor_t dcyDesc, - const void *dcy, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnRNNDataDescriptor_t dxDesc, - void *dx, - const cudnnTensorDescriptor_t dhxDesc, - void *dhx, - const cudnnTensorDescriptor_t dcxDesc, - void *dcx, - const cudnnRNNDataDescriptor_t dkDesc, /* reserved, should pass NULL */ - void *dkeys, /* reserved, should pass NULL */ - void *workSpace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const cudnnRNNDataDescriptor_t, void *, void *, size_t, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardDataEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDataDescriptor_t yDesc, const void *y, + const cudnnRNNDataDescriptor_t dyDesc, const void *dy, + const cudnnRNNDataDescriptor_t dcDesc, /* reserved, should pass NULL */ + const void *dcAttn, /* reserved, should pass NULL */ + const cudnnTensorDescriptor_t dhyDesc, const void *dhy, + const cudnnTensorDescriptor_t dcyDesc, const void *dcy, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnRNNDataDescriptor_t dxDesc, void *dx, + const cudnnTensorDescriptor_t dhxDesc, void *dhx, + const cudnnTensorDescriptor_t dcxDesc, void *dcx, + const cudnnRNNDataDescriptor_t dkDesc, /* reserved, should pass NULL */ + void *dkeys, /* reserved, should pass NULL */ + void *workSpace, size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, + const void *, const cudnnRNNDataDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, + const cudnnRNNDataDescriptor_t, void *, void *, size_t, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardDataEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, yDesc, y, dyDesc, dy, dcDesc, dcAttn, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, dkDesc, dkeys, workSpace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, yDesc, y, dyDesc, dy, dcDesc, dcAttn, + dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, + dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, dkDesc, dkeys, + workSpace, workSpaceSizeInBytes, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRNNBackwardWeightsEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const cudnnRNNDataDescriptor_t xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnRNNDataDescriptor_t yDesc, - const void *y, - void *workSpace, - size_t workSpaceSizeInBytes, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnRNNDataDescriptor_t, const void *, void *, size_t, const cudnnFilterDescriptor_t, void *, void *, size_t); +cudnnStatus_t CUDNNWINAPI cudnnRNNBackwardWeightsEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const cudnnRNNDataDescriptor_t xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnRNNDataDescriptor_t yDesc, const void *y, void *workSpace, + size_t workSpaceSizeInBytes, const cudnnFilterDescriptor_t dwDesc, void *dw, + void *reserveSpace, size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const cudnnRNNDataDescriptor_t, + const void *, const cudnnTensorDescriptor_t, const void *, + const cudnnRNNDataDescriptor_t, const void *, void *, size_t, + const cudnnFilterDescriptor_t, void *, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRNNBackwardWeightsEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, yDesc, y, workSpace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, xDesc, x, hxDesc, hx, yDesc, y, workSpace, + workSpaceSizeInBytes, dwDesc, dw, reserveSpace, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNAlgorithmDescriptor(cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, cudnnAlgorithmDescriptor_t algoDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, cudnnAlgorithmDescriptor_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNAlgorithmDescriptor( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, + cudnnAlgorithmDescriptor_t algoDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, algoDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNForwardInferenceAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNForwardInferenceAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNForwardInferenceAlgorithmMaxCount( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNForwardInferenceAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindRNNForwardInferenceAlgorithmEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t *yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - void *workspace, - size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNForwardInferenceAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNForwardInferenceAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, const float findIntensity, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnAlgorithmPerformance_t *perfResults, void *workspace, + size_t workSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNForwardInferenceAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, + requestedAlgoCount, returnedAlgoCount, perfResults, workspace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNForwardTrainingAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNForwardTrainingAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNForwardTrainingAlgorithmMaxCount( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNForwardTrainingAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindRNNForwardTrainingAlgorithmEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t *yDesc, - void *y, - const cudnnTensorDescriptor_t hyDesc, - void *hy, - const cudnnTensorDescriptor_t cyDesc, - void *cy, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - void *workspace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNForwardTrainingAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNForwardTrainingAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t *yDesc, void *y, + const cudnnTensorDescriptor_t hyDesc, void *hy, + const cudnnTensorDescriptor_t cyDesc, void *cy, const float findIntensity, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnAlgorithmPerformance_t *perfResults, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNForwardTrainingAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, cxDesc, cx, + wDesc, w, yDesc, y, hyDesc, hy, cyDesc, cy, findIntensity, + requestedAlgoCount, returnedAlgoCount, perfResults, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNBackwardDataAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNBackwardDataAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNBackwardDataAlgorithmMaxCount( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNBackwardDataAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindRNNBackwardDataAlgorithmEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *yDesc, - const void *y, - const cudnnTensorDescriptor_t *dyDesc, - const void *dy, - const cudnnTensorDescriptor_t dhyDesc, - const void *dhy, - const cudnnTensorDescriptor_t dcyDesc, - const void *dcy, - const cudnnFilterDescriptor_t wDesc, - const void *w, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t cxDesc, - const void *cx, - const cudnnTensorDescriptor_t *dxDesc, - void *dx, - const cudnnTensorDescriptor_t dhxDesc, - void *dhx, - const cudnnTensorDescriptor_t dcxDesc, - void *dcx, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - void *workspace, - size_t workSpaceSizeInBytes, - void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnFilterDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, void *, const cudnnTensorDescriptor_t, void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNBackwardDataAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNBackwardDataAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *yDesc, const void *y, + const cudnnTensorDescriptor_t *dyDesc, const void *dy, + const cudnnTensorDescriptor_t dhyDesc, const void *dhy, + const cudnnTensorDescriptor_t dcyDesc, const void *dcy, + const cudnnFilterDescriptor_t wDesc, const void *w, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t cxDesc, const void *cx, + const cudnnTensorDescriptor_t *dxDesc, void *dx, + const cudnnTensorDescriptor_t dhxDesc, void *dhx, + const cudnnTensorDescriptor_t dcxDesc, void *dcx, const float findIntensity, + const int requestedAlgoCount, int *returnedAlgoCount, + cudnnAlgorithmPerformance_t *perfResults, void *workspace, + size_t workSpaceSizeInBytes, void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnFilterDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, void *, const cudnnTensorDescriptor_t, + void *, const cudnnTensorDescriptor_t, void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, void *, size_t, void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNBackwardDataAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, dx, dhxDesc, dhx, dcxDesc, dcx, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, yDesc, y, dyDesc, dy, dhyDesc, + dhy, dcyDesc, dcy, wDesc, w, hxDesc, hx, cxDesc, cx, dxDesc, + dx, dhxDesc, dhx, dcxDesc, dcx, findIntensity, + requestedAlgoCount, returnedAlgoCount, perfResults, workspace, + workSpaceSizeInBytes, reserveSpace, reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetRNNBackwardWeightsAlgorithmMaxCount(cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetRNNBackwardWeightsAlgorithmMaxCount"); +cudnnStatus_t CUDNNWINAPI cudnnGetRNNBackwardWeightsAlgorithmMaxCount( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, int *count) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetRNNBackwardWeightsAlgorithmMaxCount"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, rnnDesc, count); } -cudnnStatus_t CUDNNWINAPI -cudnnFindRNNBackwardWeightsAlgorithmEx(cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnnDesc, - const int seqLength, - const cudnnTensorDescriptor_t *xDesc, - const void *x, - const cudnnTensorDescriptor_t hxDesc, - const void *hx, - const cudnnTensorDescriptor_t *yDesc, - const void *y, - const float findIntensity, - const int requestedAlgoCount, - int *returnedAlgoCount, - cudnnAlgorithmPerformance_t *perfResults, - const void *workspace, - size_t workSpaceSizeInBytes, - const cudnnFilterDescriptor_t dwDesc, - void *dw, - const void *reserveSpace, - size_t reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnRNNDescriptor_t, const int, const cudnnTensorDescriptor_t *, const void *, const cudnnTensorDescriptor_t, const void *, const cudnnTensorDescriptor_t *, const void *, const float, const int, int *, cudnnAlgorithmPerformance_t *, const void *, size_t, const cudnnFilterDescriptor_t, void *, const void *, size_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFindRNNBackwardWeightsAlgorithmEx"); +cudnnStatus_t CUDNNWINAPI cudnnFindRNNBackwardWeightsAlgorithmEx( + cudnnHandle_t handle, const cudnnRNNDescriptor_t rnnDesc, + const int seqLength, const cudnnTensorDescriptor_t *xDesc, const void *x, + const cudnnTensorDescriptor_t hxDesc, const void *hx, + const cudnnTensorDescriptor_t *yDesc, const void *y, + const float findIntensity, const int requestedAlgoCount, + int *returnedAlgoCount, cudnnAlgorithmPerformance_t *perfResults, + const void *workspace, size_t workSpaceSizeInBytes, + const cudnnFilterDescriptor_t dwDesc, void *dw, const void *reserveSpace, + size_t reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnRNNDescriptor_t, const int, + const cudnnTensorDescriptor_t *, const void *, + const cudnnTensorDescriptor_t, const void *, + const cudnnTensorDescriptor_t *, const void *, const float, const int, + int *, cudnnAlgorithmPerformance_t *, const void *, size_t, + const cudnnFilterDescriptor_t, void *, const void *, size_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnFindRNNBackwardWeightsAlgorithmEx"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, findIntensity, requestedAlgoCount, returnedAlgoCount, perfResults, workspace, workSpaceSizeInBytes, dwDesc, dw, reserveSpace, reserveSpaceSizeInBytes); + return func_ptr(handle, rnnDesc, seqLength, xDesc, x, hxDesc, hx, yDesc, y, + findIntensity, requestedAlgoCount, returnedAlgoCount, + perfResults, workspace, workSpaceSizeInBytes, dwDesc, dw, + reserveSpace, reserveSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnCreateSeqDataDescriptor(cudnnSeqDataDescriptor_t *seqDataDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSeqDataDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnSeqDataDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateSeqDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(seqDataDesc); @@ -2578,47 +2657,43 @@ cudnnCreateSeqDataDescriptor(cudnnSeqDataDescriptor_t *seqDataDesc) { cudnnStatus_t CUDNNWINAPI cudnnDestroySeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSeqDataDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnSeqDataDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroySeqDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(seqDataDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetSeqDataDescriptor(cudnnSeqDataDescriptor_t seqDataDesc, - cudnnDataType_t dataType, - int nbDims, - const int dimA[], - const cudnnSeqDataAxis_t axes[], - size_t seqLengthArraySize, - const int seqLengthArray[], - void *paddingFill) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnSeqDataDescriptor_t, cudnnDataType_t, int, const int [], const cudnnSeqDataAxis_t [], size_t, const int [], void *); +cudnnStatus_t CUDNNWINAPI cudnnSetSeqDataDescriptor( + cudnnSeqDataDescriptor_t seqDataDesc, cudnnDataType_t dataType, int nbDims, + const int dimA[], const cudnnSeqDataAxis_t axes[], + size_t seqLengthArraySize, const int seqLengthArray[], void *paddingFill) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnSeqDataDescriptor_t, cudnnDataType_t, int, const int[], + const cudnnSeqDataAxis_t[], size_t, const int[], void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetSeqDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(seqDataDesc, dataType, nbDims, dimA, axes, seqLengthArraySize, seqLengthArray, paddingFill); + return func_ptr(seqDataDesc, dataType, nbDims, dimA, axes, seqLengthArraySize, + seqLengthArray, paddingFill); } -cudnnStatus_t CUDNNWINAPI -cudnnGetSeqDataDescriptor(const cudnnSeqDataDescriptor_t seqDataDesc, - cudnnDataType_t *dataType, - int *nbDims, - int nbDimsRequested, - int dimA[], - cudnnSeqDataAxis_t axes[], - size_t *seqLengthArraySize, - size_t seqLengthSizeRequested, - int seqLengthArray[], - void *paddingFill) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnSeqDataDescriptor_t, cudnnDataType_t *, int *, int, int [], cudnnSeqDataAxis_t [], size_t *, size_t, int [], void *); +cudnnStatus_t CUDNNWINAPI cudnnGetSeqDataDescriptor( + const cudnnSeqDataDescriptor_t seqDataDesc, cudnnDataType_t *dataType, + int *nbDims, int nbDimsRequested, int dimA[], cudnnSeqDataAxis_t axes[], + size_t *seqLengthArraySize, size_t seqLengthSizeRequested, + int seqLengthArray[], void *paddingFill) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnSeqDataDescriptor_t, cudnnDataType_t *, int *, int, int[], + cudnnSeqDataAxis_t[], size_t *, size_t, int[], void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetSeqDataDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(seqDataDesc, dataType, nbDims, nbDimsRequested, dimA, axes, seqLengthArraySize, seqLengthSizeRequested, seqLengthArray, paddingFill); + return func_ptr(seqDataDesc, dataType, nbDims, nbDimsRequested, dimA, axes, + seqLengthArraySize, seqLengthSizeRequested, seqLengthArray, + paddingFill); } cudnnStatus_t CUDNNWINAPI cudnnCreateAttnDescriptor(cudnnAttnDescriptor_t *attnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAttnDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAttnDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateAttnDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(attnDesc); @@ -2626,217 +2701,198 @@ cudnnCreateAttnDescriptor(cudnnAttnDescriptor_t *attnDesc) { cudnnStatus_t CUDNNWINAPI cudnnDestroyAttnDescriptor(cudnnAttnDescriptor_t attnDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAttnDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAttnDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyAttnDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(attnDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetAttnDescriptor(cudnnAttnDescriptor_t attnDesc, - cudnnAttnQueryMap_t queryMap, - int nHeads, - double smScaler, - cudnnDataType_t dataType, - cudnnDataType_t computePrec, - cudnnMathType_t mathType, - cudnnDropoutDescriptor_t attnDropoutDesc, - cudnnDropoutDescriptor_t postDropoutDesc, - int qSize, - int kSize, - int vSize, - int qProjSize, - int kProjSize, - int vProjSize, - int oProjSize, - int qoMaxSeqLength, - int kvMaxSeqLength, - int maxBatchSize, - int maxBeamSize) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAttnDescriptor_t, cudnnAttnQueryMap_t, int, double, cudnnDataType_t, cudnnDataType_t, cudnnMathType_t, cudnnDropoutDescriptor_t, cudnnDropoutDescriptor_t, int, int, int, int, int, int, int, int, int, int, int); +cudnnStatus_t CUDNNWINAPI cudnnSetAttnDescriptor( + cudnnAttnDescriptor_t attnDesc, cudnnAttnQueryMap_t queryMap, int nHeads, + double smScaler, cudnnDataType_t dataType, cudnnDataType_t computePrec, + cudnnMathType_t mathType, cudnnDropoutDescriptor_t attnDropoutDesc, + cudnnDropoutDescriptor_t postDropoutDesc, int qSize, int kSize, int vSize, + int qProjSize, int kProjSize, int vProjSize, int oProjSize, + int qoMaxSeqLength, int kvMaxSeqLength, int maxBatchSize, int maxBeamSize) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnAttnDescriptor_t, cudnnAttnQueryMap_t, int, double, cudnnDataType_t, + cudnnDataType_t, cudnnMathType_t, cudnnDropoutDescriptor_t, + cudnnDropoutDescriptor_t, int, int, int, int, int, int, int, int, int, + int, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetAttnDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(attnDesc, queryMap, nHeads, smScaler, dataType, computePrec, mathType, attnDropoutDesc, postDropoutDesc, qSize, kSize, vSize, qProjSize, kProjSize, vProjSize, oProjSize, qoMaxSeqLength, kvMaxSeqLength, maxBatchSize, maxBeamSize); + return func_ptr(attnDesc, queryMap, nHeads, smScaler, dataType, computePrec, + mathType, attnDropoutDesc, postDropoutDesc, qSize, kSize, + vSize, qProjSize, kProjSize, vProjSize, oProjSize, + qoMaxSeqLength, kvMaxSeqLength, maxBatchSize, maxBeamSize); } -cudnnStatus_t CUDNNWINAPI -cudnnGetAttnDescriptor(cudnnAttnDescriptor_t attnDesc, - cudnnAttnQueryMap_t *queryMap, - int *nHeads, - double *smScaler, - cudnnDataType_t *dataType, - cudnnDataType_t *computePrec, - cudnnMathType_t *mathType, - cudnnDropoutDescriptor_t *attnDropoutDesc, - cudnnDropoutDescriptor_t *postDropoutDesc, - int *qSize, - int *kSize, - int *vSize, - int *qProjSize, - int *kProjSize, - int *vProjSize, - int *oProjSize, - int *qoMaxSeqLength, - int *kvMaxSeqLength, - int *maxBatchSize, - int *maxBeamSize) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAttnDescriptor_t, cudnnAttnQueryMap_t *, int *, double *, cudnnDataType_t *, cudnnDataType_t *, cudnnMathType_t *, cudnnDropoutDescriptor_t *, cudnnDropoutDescriptor_t *, int *, int *, int *, int *, int *, int *, int *, int *, int *, int *, int *); +cudnnStatus_t CUDNNWINAPI cudnnGetAttnDescriptor( + cudnnAttnDescriptor_t attnDesc, cudnnAttnQueryMap_t *queryMap, int *nHeads, + double *smScaler, cudnnDataType_t *dataType, cudnnDataType_t *computePrec, + cudnnMathType_t *mathType, cudnnDropoutDescriptor_t *attnDropoutDesc, + cudnnDropoutDescriptor_t *postDropoutDesc, int *qSize, int *kSize, + int *vSize, int *qProjSize, int *kProjSize, int *vProjSize, int *oProjSize, + int *qoMaxSeqLength, int *kvMaxSeqLength, int *maxBatchSize, + int *maxBeamSize) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnAttnDescriptor_t, cudnnAttnQueryMap_t *, int *, double *, + cudnnDataType_t *, cudnnDataType_t *, cudnnMathType_t *, + cudnnDropoutDescriptor_t *, cudnnDropoutDescriptor_t *, int *, int *, + int *, int *, int *, int *, int *, int *, int *, int *, int *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAttnDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(attnDesc, queryMap, nHeads, smScaler, dataType, computePrec, mathType, attnDropoutDesc, postDropoutDesc, qSize, kSize, vSize, qProjSize, kProjSize, vProjSize, oProjSize, qoMaxSeqLength, kvMaxSeqLength, maxBatchSize, maxBeamSize); + return func_ptr(attnDesc, queryMap, nHeads, smScaler, dataType, computePrec, + mathType, attnDropoutDesc, postDropoutDesc, qSize, kSize, + vSize, qProjSize, kProjSize, vProjSize, oProjSize, + qoMaxSeqLength, kvMaxSeqLength, maxBatchSize, maxBeamSize); } -cudnnStatus_t CUDNNWINAPI -cudnnGetMultiHeadAttnBuffers(cudnnHandle_t handle, - const cudnnAttnDescriptor_t attnDesc, - size_t *weightSizeInBytes, - size_t *workSpaceSizeInBytes, - size_t *reserveSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnAttnDescriptor_t, size_t *, size_t *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetMultiHeadAttnBuffers( + cudnnHandle_t handle, const cudnnAttnDescriptor_t attnDesc, + size_t *weightSizeInBytes, size_t *workSpaceSizeInBytes, + size_t *reserveSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnAttnDescriptor_t, size_t *, size_t *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetMultiHeadAttnBuffers"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, attnDesc, weightSizeInBytes, workSpaceSizeInBytes, reserveSpaceSizeInBytes); + return func_ptr(handle, attnDesc, weightSizeInBytes, workSpaceSizeInBytes, + reserveSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetMultiHeadAttnWeights(cudnnHandle_t handle, - const cudnnAttnDescriptor_t attnDesc, - cudnnMultiHeadAttnWeightKind_t wKind, - size_t weightSizeInBytes, - const void *w, - cudnnTensorDescriptor_t wDesc, - void **wAddr) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnAttnDescriptor_t, cudnnMultiHeadAttnWeightKind_t, size_t, const void *, cudnnTensorDescriptor_t, void **); +cudnnStatus_t CUDNNWINAPI cudnnGetMultiHeadAttnWeights( + cudnnHandle_t handle, const cudnnAttnDescriptor_t attnDesc, + cudnnMultiHeadAttnWeightKind_t wKind, size_t weightSizeInBytes, + const void *w, cudnnTensorDescriptor_t wDesc, void **wAddr) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnAttnDescriptor_t, + cudnnMultiHeadAttnWeightKind_t, size_t, const void *, + cudnnTensorDescriptor_t, void **); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetMultiHeadAttnWeights"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, attnDesc, wKind, weightSizeInBytes, w, wDesc, wAddr); } -cudnnStatus_t CUDNNWINAPI -cudnnMultiHeadAttnForward(cudnnHandle_t handle, - const cudnnAttnDescriptor_t attnDesc, - int currIdx, - const int *loWinIdx, - const int *hiWinIdx, - const int *seqLengthArrayQRO, - const int *seqLengthArrayKV, - const cudnnSeqDataDescriptor_t qDesc, - const void *queries, - const void *residuals, - const cudnnSeqDataDescriptor_t kDesc, - const void *keys, - const cudnnSeqDataDescriptor_t vDesc, - const void *values, - const cudnnSeqDataDescriptor_t oDesc, - void *out, - size_t weightSizeInBytes, - const void *w, - size_t workSpaceSizeInBytes, - void *workSpace, - size_t reserveSpaceSizeInBytes, - void *reserveSpace) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnAttnDescriptor_t, int, const int *, const int *, const int *, const int *, const cudnnSeqDataDescriptor_t, const void *, const void *, const cudnnSeqDataDescriptor_t, const void *, const cudnnSeqDataDescriptor_t, const void *, const cudnnSeqDataDescriptor_t, void *, size_t, const void *, size_t, void *, size_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnMultiHeadAttnForward( + cudnnHandle_t handle, const cudnnAttnDescriptor_t attnDesc, int currIdx, + const int *loWinIdx, const int *hiWinIdx, const int *seqLengthArrayQRO, + const int *seqLengthArrayKV, const cudnnSeqDataDescriptor_t qDesc, + const void *queries, const void *residuals, + const cudnnSeqDataDescriptor_t kDesc, const void *keys, + const cudnnSeqDataDescriptor_t vDesc, const void *values, + const cudnnSeqDataDescriptor_t oDesc, void *out, size_t weightSizeInBytes, + const void *w, size_t workSpaceSizeInBytes, void *workSpace, + size_t reserveSpaceSizeInBytes, void *reserveSpace) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnAttnDescriptor_t, int, const int *, const int *, + const int *, const int *, const cudnnSeqDataDescriptor_t, const void *, + const void *, const cudnnSeqDataDescriptor_t, const void *, + const cudnnSeqDataDescriptor_t, const void *, + const cudnnSeqDataDescriptor_t, void *, size_t, const void *, size_t, + void *, size_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnMultiHeadAttnForward"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, attnDesc, currIdx, loWinIdx, hiWinIdx, seqLengthArrayQRO, seqLengthArrayKV, qDesc, queries, residuals, kDesc, keys, vDesc, values, oDesc, out, weightSizeInBytes, w, workSpaceSizeInBytes, workSpace, reserveSpaceSizeInBytes, reserveSpace); + return func_ptr(handle, attnDesc, currIdx, loWinIdx, hiWinIdx, + seqLengthArrayQRO, seqLengthArrayKV, qDesc, queries, + residuals, kDesc, keys, vDesc, values, oDesc, out, + weightSizeInBytes, w, workSpaceSizeInBytes, workSpace, + reserveSpaceSizeInBytes, reserveSpace); } -cudnnStatus_t CUDNNWINAPI -cudnnMultiHeadAttnBackwardData(cudnnHandle_t handle, - const cudnnAttnDescriptor_t attnDesc, - const int *loWinIdx, - const int *hiWinIdx, - const int *seqLengthArrayDQDO, - const int *seqLengthArrayDKDV, - const cudnnSeqDataDescriptor_t doDesc, - const void *dout, - const cudnnSeqDataDescriptor_t dqDesc, - void *dqueries, - const void *queries, - const cudnnSeqDataDescriptor_t dkDesc, - void *dkeys, - const void *keys, - const cudnnSeqDataDescriptor_t dvDesc, - void *dvalues, - const void *values, - size_t weightSizeInBytes, - const void *w, - size_t workSpaceSizeInBytes, - void *workSpace, - size_t reserveSpaceSizeInBytes, - void *reserveSpace) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnAttnDescriptor_t, const int *, const int *, const int *, const int *, const cudnnSeqDataDescriptor_t, const void *, const cudnnSeqDataDescriptor_t, void *, const void *, const cudnnSeqDataDescriptor_t, void *, const void *, const cudnnSeqDataDescriptor_t, void *, const void *, size_t, const void *, size_t, void *, size_t, void *); +cudnnStatus_t CUDNNWINAPI cudnnMultiHeadAttnBackwardData( + cudnnHandle_t handle, const cudnnAttnDescriptor_t attnDesc, + const int *loWinIdx, const int *hiWinIdx, const int *seqLengthArrayDQDO, + const int *seqLengthArrayDKDV, const cudnnSeqDataDescriptor_t doDesc, + const void *dout, const cudnnSeqDataDescriptor_t dqDesc, void *dqueries, + const void *queries, const cudnnSeqDataDescriptor_t dkDesc, void *dkeys, + const void *keys, const cudnnSeqDataDescriptor_t dvDesc, void *dvalues, + const void *values, size_t weightSizeInBytes, const void *w, + size_t workSpaceSizeInBytes, void *workSpace, + size_t reserveSpaceSizeInBytes, void *reserveSpace) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnAttnDescriptor_t, const int *, const int *, + const int *, const int *, const cudnnSeqDataDescriptor_t, const void *, + const cudnnSeqDataDescriptor_t, void *, const void *, + const cudnnSeqDataDescriptor_t, void *, const void *, + const cudnnSeqDataDescriptor_t, void *, const void *, size_t, + const void *, size_t, void *, size_t, void *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnMultiHeadAttnBackwardData"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, attnDesc, loWinIdx, hiWinIdx, seqLengthArrayDQDO, seqLengthArrayDKDV, doDesc, dout, dqDesc, dqueries, queries, dkDesc, dkeys, keys, dvDesc, dvalues, values, weightSizeInBytes, w, workSpaceSizeInBytes, workSpace, reserveSpaceSizeInBytes, reserveSpace); + return func_ptr(handle, attnDesc, loWinIdx, hiWinIdx, seqLengthArrayDQDO, + seqLengthArrayDKDV, doDesc, dout, dqDesc, dqueries, queries, + dkDesc, dkeys, keys, dvDesc, dvalues, values, + weightSizeInBytes, w, workSpaceSizeInBytes, workSpace, + reserveSpaceSizeInBytes, reserveSpace); } -cudnnStatus_t CUDNNWINAPI -cudnnMultiHeadAttnBackwardWeights(cudnnHandle_t handle, - const cudnnAttnDescriptor_t attnDesc, - cudnnWgradMode_t addGrad, - const cudnnSeqDataDescriptor_t qDesc, - const void *queries, - const cudnnSeqDataDescriptor_t kDesc, - const void *keys, - const cudnnSeqDataDescriptor_t vDesc, - const void *values, - const cudnnSeqDataDescriptor_t doDesc, - const void *dout, - size_t weightSizeInBytes, - const void *w, - void *dw, - size_t workSpaceSizeInBytes, - void *workSpace, - size_t reserveSpaceSizeInBytes, - void *reserveSpace) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnAttnDescriptor_t, cudnnWgradMode_t, const cudnnSeqDataDescriptor_t, const void *, const cudnnSeqDataDescriptor_t, const void *, const cudnnSeqDataDescriptor_t, const void *, const cudnnSeqDataDescriptor_t, const void *, size_t, const void *, void *, size_t, void *, size_t, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnMultiHeadAttnBackwardWeights"); +cudnnStatus_t CUDNNWINAPI cudnnMultiHeadAttnBackwardWeights( + cudnnHandle_t handle, const cudnnAttnDescriptor_t attnDesc, + cudnnWgradMode_t addGrad, const cudnnSeqDataDescriptor_t qDesc, + const void *queries, const cudnnSeqDataDescriptor_t kDesc, const void *keys, + const cudnnSeqDataDescriptor_t vDesc, const void *values, + const cudnnSeqDataDescriptor_t doDesc, const void *dout, + size_t weightSizeInBytes, const void *w, void *dw, + size_t workSpaceSizeInBytes, void *workSpace, + size_t reserveSpaceSizeInBytes, void *reserveSpace) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnAttnDescriptor_t, cudnnWgradMode_t, + const cudnnSeqDataDescriptor_t, const void *, + const cudnnSeqDataDescriptor_t, const void *, + const cudnnSeqDataDescriptor_t, const void *, + const cudnnSeqDataDescriptor_t, const void *, size_t, const void *, + void *, size_t, void *, size_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnMultiHeadAttnBackwardWeights"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, attnDesc, addGrad, qDesc, queries, kDesc, keys, vDesc, values, doDesc, dout, weightSizeInBytes, w, dw, workSpaceSizeInBytes, workSpace, reserveSpaceSizeInBytes, reserveSpace); + return func_ptr(handle, attnDesc, addGrad, qDesc, queries, kDesc, keys, vDesc, + values, doDesc, dout, weightSizeInBytes, w, dw, + workSpaceSizeInBytes, workSpace, reserveSpaceSizeInBytes, + reserveSpace); } cudnnStatus_t CUDNNWINAPI cudnnCreateCTCLossDescriptor(cudnnCTCLossDescriptor_t *ctcLossDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetCTCLossDescriptor( + cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc, compType); } -cudnnStatus_t CUDNNWINAPI -cudnnSetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc, - cudnnDataType_t compType, - cudnnLossNormalizationMode_t normMode, - cudnnNanPropagation_t gradMode) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t, cudnnLossNormalizationMode_t, cudnnNanPropagation_t); +cudnnStatus_t CUDNNWINAPI cudnnSetCTCLossDescriptorEx( + cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t compType, + cudnnLossNormalizationMode_t normMode, cudnnNanPropagation_t gradMode) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnCTCLossDescriptor_t, cudnnDataType_t, cudnnLossNormalizationMode_t, + cudnnNanPropagation_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetCTCLossDescriptorEx"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc, compType, normMode, gradMode); } -cudnnStatus_t CUDNNWINAPI -cudnnGetCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetCTCLossDescriptor( + cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc, compType); } -cudnnStatus_t CUDNNWINAPI -cudnnGetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc, - cudnnDataType_t *compType, - cudnnLossNormalizationMode_t *normMode, - cudnnNanPropagation_t *gradMode) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t, cudnnDataType_t *, cudnnLossNormalizationMode_t *, cudnnNanPropagation_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetCTCLossDescriptorEx( + cudnnCTCLossDescriptor_t ctcLossDesc, cudnnDataType_t *compType, + cudnnLossNormalizationMode_t *normMode, cudnnNanPropagation_t *gradMode) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnCTCLossDescriptor_t, cudnnDataType_t *, + cudnnLossNormalizationMode_t *, cudnnNanPropagation_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCTCLossDescriptorEx"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc, compType, normMode, gradMode); @@ -2844,82 +2900,102 @@ cudnnGetCTCLossDescriptorEx(cudnnCTCLossDescriptor_t ctcLossDesc, cudnnStatus_t CUDNNWINAPI cudnnDestroyCTCLossDescriptor(cudnnCTCLossDescriptor_t ctcLossDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnCTCLossDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnCTCLossDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyCTCLossDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(ctcLossDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnCTCLoss( +cudnnStatus_t CUDNNWINAPI cudnnCTCLoss( cudnnHandle_t handle, const cudnnTensorDescriptor_t - probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the timing steps, N is the - mini batch size, A is the alphabet size) */ - const void *probs, /* probabilities after softmax, in GPU memory */ - const int *labels, /* labels, in CPU memory */ - const int *labelLengths, /* the length of each label, in CPU memory */ - const int *inputLengths, /* the lengths of timing steps in each batch, in CPU memory */ - void *costs, /* the returned costs of CTC, in GPU memory */ - const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the dimensions are T,N,A */ - const void *gradients, /* the returned CTC gradients, in GPU memory, to compute costs only, set it to NULL */ + probsDesc, /* Tensor descriptor for probabilities, the dimensions are + T,N,A (T is the timing steps, N is the + mini batch size, A is the alphabet size) */ + const void *probs, /* probabilities after softmax, in GPU memory */ + const int *labels, /* labels, in CPU memory */ + const int *labelLengths, /* the length of each label, in CPU memory */ + const int *inputLengths, /* the lengths of timing steps in each batch, in + CPU memory */ + void *costs, /* the returned costs of CTC, in GPU memory */ + const cudnnTensorDescriptor_t + gradientsDesc, /* Tensor descriptor for gradients, the dimensions are + T,N,A */ + const void *gradients, /* the returned CTC gradients, in GPU memory, to + compute costs only, set it to NULL */ cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ cudnnCTCLossDescriptor_t ctcLossDesc, - void *workspace, /* pointer to the workspace, in GPU memory */ + void *workspace, /* pointer to the workspace, in GPU memory */ size_t workSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const int *, const int *, const int *, void *, const cudnnTensorDescriptor_t, const void *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, void *, size_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, const void *, const int *, + const int *, const int *, void *, const cudnnTensorDescriptor_t, + const void *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, void *, + size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCTCLoss"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, probsDesc, probs, labels, labelLengths, inputLengths, costs, gradientsDesc, gradients, algo, ctcLossDesc, workspace, workSpaceSizeInBytes); + return func_ptr(handle, probsDesc, probs, labels, labelLengths, inputLengths, + costs, gradientsDesc, gradients, algo, ctcLossDesc, workspace, + workSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnGetCTCLossWorkspaceSize( +cudnnStatus_t CUDNNWINAPI cudnnGetCTCLossWorkspaceSize( cudnnHandle_t handle, - const cudnnTensorDescriptor_t probsDesc, /* Tensor descriptor for probabilities, the dimensions are T,N,A (T is the - timing steps, N is the mini batch size, A is the alphabet size) */ - const cudnnTensorDescriptor_t gradientsDesc, /* Tensor descriptor for gradients, the - dimensions are T,N,A. To compute costs - only, set it to NULL */ - const int *labels, /* labels, in CPU memory */ - const int *labelLengths, /* the length of each label, in CPU memory */ - const int *inputLengths, /* the lengths of timing steps in each batch, in CPU memory */ - cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ - cudnnCTCLossDescriptor_t ctcLossDesc, - size_t *sizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnTensorDescriptor_t, const cudnnTensorDescriptor_t, const int *, const int *, const int *, cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, size_t *); + const cudnnTensorDescriptor_t + probsDesc, /* Tensor descriptor for probabilities, the dimensions are + T,N,A (T is the + timing steps, N is the mini batch size, A is the alphabet + size) */ + const cudnnTensorDescriptor_t + gradientsDesc, /* Tensor descriptor for gradients, the + dimensions are T,N,A. To compute costs + only, set it to NULL */ + const int *labels, /* labels, in CPU memory */ + const int *labelLengths, /* the length of each label, in CPU memory */ + const int *inputLengths, /* the lengths of timing steps in each batch, in + CPU memory */ + cudnnCTCLossAlgo_t algo, /* algorithm selected, supported now 0 and 1 */ + cudnnCTCLossDescriptor_t ctcLossDesc, size_t *sizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, const cudnnTensorDescriptor_t, + const cudnnTensorDescriptor_t, const int *, const int *, const int *, + cudnnCTCLossAlgo_t, cudnnCTCLossDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCTCLossWorkspaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, probsDesc, gradientsDesc, labels, labelLengths, inputLengths, algo, ctcLossDesc, sizeInBytes); + return func_ptr(handle, probsDesc, gradientsDesc, labels, labelLengths, + inputLengths, algo, ctcLossDesc, sizeInBytes); } cudnnStatus_t CUDNNWINAPI cudnnCreateAlgorithmDescriptor(cudnnAlgorithmDescriptor_t *algoDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetAlgorithmDescriptor(cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t algorithm) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t, cudnnAlgorithm_t); +cudnnStatus_t CUDNNWINAPI cudnnSetAlgorithmDescriptor( + cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t algorithm) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t, + cudnnAlgorithm_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc, algorithm); } -cudnnStatus_t CUDNNWINAPI -cudnnGetAlgorithmDescriptor(const cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t *algorithm) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, cudnnAlgorithm_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetAlgorithmDescriptor( + const cudnnAlgorithmDescriptor_t algoDesc, cudnnAlgorithm_t *algorithm) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, + cudnnAlgorithm_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc, algorithm); } -cudnnStatus_t CUDNNWINAPI -cudnnCopyAlgorithmDescriptor(const cudnnAlgorithmDescriptor_t src, cudnnAlgorithmDescriptor_t dest) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, cudnnAlgorithmDescriptor_t); +cudnnStatus_t CUDNNWINAPI cudnnCopyAlgorithmDescriptor( + const cudnnAlgorithmDescriptor_t src, cudnnAlgorithmDescriptor_t dest) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(const cudnnAlgorithmDescriptor_t, + cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCopyAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(src, dest); @@ -2927,236 +3003,255 @@ cudnnCopyAlgorithmDescriptor(const cudnnAlgorithmDescriptor_t src, cudnnAlgorith cudnnStatus_t CUDNNWINAPI cudnnDestroyAlgorithmDescriptor(cudnnAlgorithmDescriptor_t algoDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyAlgorithmDescriptor"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnCreateAlgorithmPerformance(cudnnAlgorithmPerformance_t *algoPerf, int numberToCreate) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); +cudnnStatus_t CUDNNWINAPI cudnnCreateAlgorithmPerformance( + cudnnAlgorithmPerformance_t *algoPerf, int numberToCreate) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, numberToCreate); } -cudnnStatus_t CUDNNWINAPI -cudnnSetAlgorithmPerformance(cudnnAlgorithmPerformance_t algoPerf, - cudnnAlgorithmDescriptor_t algoDesc, - cudnnStatus_t status, - float time, - size_t memory) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmPerformance_t, cudnnAlgorithmDescriptor_t, cudnnStatus_t, float, size_t); +cudnnStatus_t CUDNNWINAPI cudnnSetAlgorithmPerformance( + cudnnAlgorithmPerformance_t algoPerf, cudnnAlgorithmDescriptor_t algoDesc, + cudnnStatus_t status, float time, size_t memory) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmPerformance_t, + cudnnAlgorithmDescriptor_t, + cudnnStatus_t, float, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, algoDesc, status, time, memory); } -cudnnStatus_t CUDNNWINAPI -cudnnGetAlgorithmPerformance(const cudnnAlgorithmPerformance_t algoPerf, - cudnnAlgorithmDescriptor_t *algoDesc, - cudnnStatus_t *status, - float *time, - size_t *memory) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnAlgorithmPerformance_t, cudnnAlgorithmDescriptor_t *, cudnnStatus_t *, float *, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetAlgorithmPerformance( + const cudnnAlgorithmPerformance_t algoPerf, + cudnnAlgorithmDescriptor_t *algoDesc, cudnnStatus_t *status, float *time, + size_t *memory) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnAlgorithmPerformance_t, cudnnAlgorithmDescriptor_t *, + cudnnStatus_t *, float *, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, algoDesc, status, time, memory); } -cudnnStatus_t CUDNNWINAPI -cudnnDestroyAlgorithmPerformance(cudnnAlgorithmPerformance_t *algoPerf, int numberToDestroy) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyAlgorithmPerformance"); +cudnnStatus_t CUDNNWINAPI cudnnDestroyAlgorithmPerformance( + cudnnAlgorithmPerformance_t *algoPerf, int numberToDestroy) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnAlgorithmPerformance_t *, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyAlgorithmPerformance"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(algoPerf, numberToDestroy); } -cudnnStatus_t CUDNNWINAPI -cudnnGetAlgorithmSpaceSize(cudnnHandle_t handle, cudnnAlgorithmDescriptor_t algoDesc, size_t *algoSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnAlgorithmDescriptor_t, size_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetAlgorithmSpaceSize( + cudnnHandle_t handle, cudnnAlgorithmDescriptor_t algoDesc, + size_t *algoSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnAlgorithmDescriptor_t, size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetAlgorithmSpaceSize"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algoDesc, algoSpaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI -cudnnSaveAlgorithm(cudnnHandle_t handle, - cudnnAlgorithmDescriptor_t algoDesc, - void *algoSpace, - size_t algoSpaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnAlgorithmDescriptor_t, void *, size_t); +cudnnSaveAlgorithm(cudnnHandle_t handle, cudnnAlgorithmDescriptor_t algoDesc, + void *algoSpace, size_t algoSpaceSizeInBytes) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnAlgorithmDescriptor_t, void *, size_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSaveAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algoDesc, algoSpace, algoSpaceSizeInBytes); } -cudnnStatus_t CUDNNWINAPI -cudnnRestoreAlgorithm(cudnnHandle_t handle, - void *algoSpace, - size_t algoSpaceSizeInBytes, - cudnnAlgorithmDescriptor_t algoDesc) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, void *, size_t, cudnnAlgorithmDescriptor_t); +cudnnStatus_t CUDNNWINAPI cudnnRestoreAlgorithm( + cudnnHandle_t handle, void *algoSpace, size_t algoSpaceSizeInBytes, + cudnnAlgorithmDescriptor_t algoDesc) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, void *, size_t, + cudnnAlgorithmDescriptor_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnRestoreAlgorithm"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, algoSpace, algoSpaceSizeInBytes, algoDesc); } -cudnnStatus_t CUDNNWINAPI -cudnnSetCallback(unsigned mask, void *udata, cudnnCallback_t fptr) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(unsigned int, void *, cudnnCallback_t); +cudnnStatus_t CUDNNWINAPI cudnnSetCallback(unsigned mask, void *udata, + cudnnCallback_t fptr) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(unsigned int, void *, cudnnCallback_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetCallback"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(mask, udata, fptr); } -cudnnStatus_t CUDNNWINAPI -cudnnGetCallback(unsigned *mask, void **udata, cudnnCallback_t *fptr) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(unsigned int *, void **, cudnnCallback_t *); +cudnnStatus_t CUDNNWINAPI cudnnGetCallback(unsigned *mask, void **udata, + cudnnCallback_t *fptr) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(unsigned int *, void **, cudnnCallback_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetCallback"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(mask, udata, fptr); } -cudnnStatus_t CUDNNWINAPI -cudnnCreateFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t *constPack, cudnnFusedOps_t ops) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsConstParamPack_t *, cudnnFusedOps_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateFusedOpsConstParamPack"); +cudnnStatus_t CUDNNWINAPI cudnnCreateFusedOpsConstParamPack( + cudnnFusedOpsConstParamPack_t *constPack, cudnnFusedOps_t ops) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFusedOpsConstParamPack_t *, + cudnnFusedOps_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateFusedOpsConstParamPack"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(constPack, ops); } cudnnStatus_t CUDNNWINAPI cudnnDestroyFusedOpsConstParamPack(cudnnFusedOpsConstParamPack_t constPack) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsConstParamPack_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyFusedOpsConstParamPack"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFusedOpsConstParamPack_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyFusedOpsConstParamPack"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(constPack); } -cudnnStatus_t CUDNNWINAPI -cudnnSetFusedOpsConstParamPackAttribute(cudnnFusedOpsConstParamPack_t constPack, - cudnnFusedOpsConstParamLabel_t paramLabel, - const void *param) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsConstParamPack_t, cudnnFusedOpsConstParamLabel_t, const void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFusedOpsConstParamPackAttribute"); +cudnnStatus_t CUDNNWINAPI cudnnSetFusedOpsConstParamPackAttribute( + cudnnFusedOpsConstParamPack_t constPack, + cudnnFusedOpsConstParamLabel_t paramLabel, const void *param) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFusedOpsConstParamPack_t, + cudnnFusedOpsConstParamLabel_t, + const void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSetFusedOpsConstParamPackAttribute"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(constPack, paramLabel, param); } -cudnnStatus_t CUDNNWINAPI -cudnnGetFusedOpsConstParamPackAttribute(const cudnnFusedOpsConstParamPack_t constPack, - cudnnFusedOpsConstParamLabel_t paramLabel, - void *param, - int *isNULL) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFusedOpsConstParamPack_t, cudnnFusedOpsConstParamLabel_t, void *, int *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFusedOpsConstParamPackAttribute"); +cudnnStatus_t CUDNNWINAPI cudnnGetFusedOpsConstParamPackAttribute( + const cudnnFusedOpsConstParamPack_t constPack, + cudnnFusedOpsConstParamLabel_t paramLabel, void *param, int *isNULL) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + const cudnnFusedOpsConstParamPack_t, cudnnFusedOpsConstParamLabel_t, + void *, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetFusedOpsConstParamPackAttribute"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(constPack, paramLabel, param, isNULL); } -cudnnStatus_t CUDNNWINAPI -cudnnCreateFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t *varPack, cudnnFusedOps_t ops) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsVariantParamPack_t *, cudnnFusedOps_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateFusedOpsVariantParamPack"); +cudnnStatus_t CUDNNWINAPI cudnnCreateFusedOpsVariantParamPack( + cudnnFusedOpsVariantParamPack_t *varPack, cudnnFusedOps_t ops) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnFusedOpsVariantParamPack_t *, cudnnFusedOps_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnCreateFusedOpsVariantParamPack"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(varPack, ops); } cudnnStatus_t CUDNNWINAPI cudnnDestroyFusedOpsVariantParamPack(cudnnFusedOpsVariantParamPack_t varPack) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsVariantParamPack_t); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyFusedOpsVariantParamPack"); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFusedOpsVariantParamPack_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnDestroyFusedOpsVariantParamPack"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(varPack); } -cudnnStatus_t CUDNNWINAPI -cudnnSetFusedOpsVariantParamPackAttribute(cudnnFusedOpsVariantParamPack_t varPack, - cudnnFusedOpsVariantParamLabel_t paramLabel, - void *ptr) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsVariantParamPack_t, cudnnFusedOpsVariantParamLabel_t, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetFusedOpsVariantParamPackAttribute"); +cudnnStatus_t CUDNNWINAPI cudnnSetFusedOpsVariantParamPackAttribute( + cudnnFusedOpsVariantParamPack_t varPack, + cudnnFusedOpsVariantParamLabel_t paramLabel, void *ptr) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnFusedOpsVariantParamPack_t, + cudnnFusedOpsVariantParamLabel_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnSetFusedOpsVariantParamPackAttribute"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(varPack, paramLabel, ptr); } -cudnnStatus_t CUDNNWINAPI -cudnnGetFusedOpsVariantParamPackAttribute(const cudnnFusedOpsVariantParamPack_t varPack, - cudnnFusedOpsVariantParamLabel_t paramLabel, - void *ptr) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(const cudnnFusedOpsVariantParamPack_t, cudnnFusedOpsVariantParamLabel_t, void *); - static auto func_ptr = LoadSymbol<FuncPtr>("cudnnGetFusedOpsVariantParamPackAttribute"); +cudnnStatus_t CUDNNWINAPI cudnnGetFusedOpsVariantParamPackAttribute( + const cudnnFusedOpsVariantParamPack_t varPack, + cudnnFusedOpsVariantParamLabel_t paramLabel, void *ptr) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(const cudnnFusedOpsVariantParamPack_t, + cudnnFusedOpsVariantParamLabel_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cudnnGetFusedOpsVariantParamPackAttribute"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(varPack, paramLabel, ptr); } -cudnnStatus_t CUDNNWINAPI -cudnnCreateFusedOpsPlan(cudnnFusedOpsPlan_t *plan, cudnnFusedOps_t ops) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsPlan_t *, cudnnFusedOps_t); +cudnnStatus_t CUDNNWINAPI cudnnCreateFusedOpsPlan(cudnnFusedOpsPlan_t *plan, + cudnnFusedOps_t ops) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnFusedOpsPlan_t *, cudnnFusedOps_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnCreateFusedOpsPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(plan, ops); } -cudnnStatus_t CUDNNWINAPI -cudnnDestroyFusedOpsPlan(cudnnFusedOpsPlan_t plan) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnFusedOpsPlan_t); +cudnnStatus_t CUDNNWINAPI cudnnDestroyFusedOpsPlan(cudnnFusedOpsPlan_t plan) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)(cudnnFusedOpsPlan_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnDestroyFusedOpsPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(plan); } cudnnStatus_t CUDNNWINAPI -cudnnMakeFusedOpsPlan(cudnnHandle_t handle, - cudnnFusedOpsPlan_t plan, +cudnnMakeFusedOpsPlan(cudnnHandle_t handle, cudnnFusedOpsPlan_t plan, const cudnnFusedOpsConstParamPack_t constPack, size_t *workspaceSizeInBytes) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnFusedOpsPlan_t, const cudnnFusedOpsConstParamPack_t, size_t *); + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnFusedOpsPlan_t, const cudnnFusedOpsConstParamPack_t, + size_t *); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnMakeFusedOpsPlan"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, plan, constPack, workspaceSizeInBytes); } cudnnStatus_t CUDNNWINAPI -cudnnFusedOpsExecute(cudnnHandle_t handle, const cudnnFusedOpsPlan_t plan, cudnnFusedOpsVariantParamPack_t varPack) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, const cudnnFusedOpsPlan_t, cudnnFusedOpsVariantParamPack_t); +cudnnFusedOpsExecute(cudnnHandle_t handle, const cudnnFusedOpsPlan_t plan, + cudnnFusedOpsVariantParamPack_t varPack) { + using FuncPtr = + cudnnStatus_t(CUDNNWINAPI *)(cudnnHandle_t, const cudnnFusedOpsPlan_t, + cudnnFusedOpsVariantParamPack_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnFusedOpsExecute"); if (!func_ptr) return GetSymbolNotFoundError(); return func_ptr(handle, plan, varPack); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNDescriptor_v6(cudnnHandle_t handle, - cudnnRNNDescriptor_t rnnDesc, - const int hiddenSize, - const int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnRNNAlgo_t algo, - cudnnDataType_t mathPrec) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v6( + cudnnHandle_t handle, cudnnRNNDescriptor_t rnnDesc, const int hiddenSize, + const int numLayers, cudnnDropoutDescriptor_t dropoutDesc, + cudnnRNNInputMode_t inputMode, cudnnDirectionMode_t direction, + cudnnRNNMode_t mode, cudnnRNNAlgo_t algo, cudnnDataType_t mathPrec) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnHandle_t, cudnnRNNDescriptor_t, const int, const int, + cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, + cudnnRNNMode_t, cudnnRNNAlgo_t, cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor_v6"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, algo, mathPrec); + return func_ptr(handle, rnnDesc, hiddenSize, numLayers, dropoutDesc, + inputMode, direction, mode, algo, mathPrec); } -cudnnStatus_t CUDNNWINAPI -cudnnSetRNNDescriptor_v5(cudnnRNNDescriptor_t rnnDesc, - int hiddenSize, - int numLayers, - cudnnDropoutDescriptor_t dropoutDesc, - cudnnRNNInputMode_t inputMode, - cudnnDirectionMode_t direction, - cudnnRNNMode_t mode, - cudnnDataType_t mathPrec) { - using FuncPtr = cudnnStatus_t (CUDNNWINAPI *)(cudnnRNNDescriptor_t, int, int, cudnnDropoutDescriptor_t, cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, cudnnDataType_t); +cudnnStatus_t CUDNNWINAPI cudnnSetRNNDescriptor_v5( + cudnnRNNDescriptor_t rnnDesc, int hiddenSize, int numLayers, + cudnnDropoutDescriptor_t dropoutDesc, cudnnRNNInputMode_t inputMode, + cudnnDirectionMode_t direction, cudnnRNNMode_t mode, + cudnnDataType_t mathPrec) { + using FuncPtr = cudnnStatus_t(CUDNNWINAPI *)( + cudnnRNNDescriptor_t, int, int, cudnnDropoutDescriptor_t, + cudnnRNNInputMode_t, cudnnDirectionMode_t, cudnnRNNMode_t, + cudnnDataType_t); static auto func_ptr = LoadSymbol<FuncPtr>("cudnnSetRNNDescriptor_v5"); if (!func_ptr) return GetSymbolNotFoundError(); - return func_ptr(rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, direction, mode, mathPrec); + return func_ptr(rnnDesc, hiddenSize, numLayers, dropoutDesc, inputMode, + direction, mode, mathPrec); } } // extern "C" diff --git a/tensorflow/stream_executor/cuda/cusolver_dense_11_0.inc b/tensorflow/stream_executor/cuda/cusolver_dense_11_0.inc new file mode 100644 index 00000000000..c4f32c84680 --- /dev/null +++ b/tensorflow/stream_executor/cuda/cusolver_dense_11_0.inc @@ -0,0 +1,4686 @@ +// Auto-generated, do not edit. + +extern "C" { + +cusolverStatus_t CUSOLVERAPI cusolverGetProperty(libraryPropertyType type, + int *value) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(libraryPropertyType, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverGetProperty"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(type, value); +} + +cusolverStatus_t CUSOLVERAPI cusolverGetVersion(int *version) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverGetVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(version); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCreate(cusolverDnHandle_t *handle) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDestroy(cusolverDnHandle_t handle) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSetStream(cusolverDnHandle_t handle, + cudaStream_t streamId) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSetStream"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, streamId); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnGetStream(cusolverDnHandle_t handle, + cudaStream_t *streamId) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cudaStream_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnGetStream"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, streamId); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnIRSParamsCreate(cusolverDnIRSParams_t *params_ptr) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnIRSParamsCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params_ptr); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnIRSParamsDestroy(cusolverDnIRSParams_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnIRSParamsDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetRefinementSolver( + cusolverDnIRSParams_t params, cusolverIRSRefinement_t refinement_solver) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, + cusolverIRSRefinement_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnIRSParamsSetRefinementSolver"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, refinement_solver); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetSolverMainPrecision( + cusolverDnIRSParams_t params, cusolverPrecType_t solver_main_precision) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, + cusolverPrecType_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnIRSParamsSetSolverMainPrecision"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, solver_main_precision); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetSolverLowestPrecision( + cusolverDnIRSParams_t params, cusolverPrecType_t solver_lowest_precision) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, + cusolverPrecType_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnIRSParamsSetSolverLowestPrecision"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, solver_lowest_precision); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetSolverPrecisions( + cusolverDnIRSParams_t params, cusolverPrecType_t solver_main_precision, + cusolverPrecType_t solver_lowest_precision) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnIRSParams_t, cusolverPrecType_t, cusolverPrecType_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnIRSParamsSetSolverPrecisions"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, solver_main_precision, solver_lowest_precision); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnIRSParamsSetTol(cusolverDnIRSParams_t params, double val) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, double); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnIRSParamsSetTol"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, val); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnIRSParamsSetTolInner(cusolverDnIRSParams_t params, double val) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, double); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnIRSParamsSetTolInner"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, val); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetMaxIters( + cusolverDnIRSParams_t params, cusolver_int_t maxiters) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, cusolver_int_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnIRSParamsSetMaxIters"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, maxiters); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsSetMaxItersInner( + cusolverDnIRSParams_t params, cusolver_int_t maxiters_inner) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, cusolver_int_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnIRSParamsSetMaxItersInner"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, maxiters_inner); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSParamsGetMaxIters( + cusolverDnIRSParams_t params, cusolver_int_t *maxiters) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnIRSParamsGetMaxIters"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, maxiters); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnIRSParamsEnableFallback(cusolverDnIRSParams_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnIRSParamsEnableFallback"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnIRSParamsDisableFallback(cusolverDnIRSParams_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSParams_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnIRSParamsDisableFallback"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnIRSInfosDestroy(cusolverDnIRSInfos_t infos) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSInfos_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnIRSInfosDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(infos); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnIRSInfosCreate(cusolverDnIRSInfos_t *infos_ptr) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSInfos_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnIRSInfosCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(infos_ptr); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSInfosGetNiters( + cusolverDnIRSInfos_t infos, cusolver_int_t *niters) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSInfos_t, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnIRSInfosGetNiters"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(infos, niters); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSInfosGetOuterNiters( + cusolverDnIRSInfos_t infos, cusolver_int_t *outer_niters) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSInfos_t, cusolver_int_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnIRSInfosGetOuterNiters"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(infos, outer_niters); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnIRSInfosRequestResidual(cusolverDnIRSInfos_t infos) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSInfos_t); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnIRSInfosRequestResidual"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(infos); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSInfosGetResidualHistory( + cusolverDnIRSInfos_t infos, void **residual_history) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSInfos_t, void **); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnIRSInfosGetResidualHistory"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(infos, residual_history); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSInfosGetMaxIters( + cusolverDnIRSInfos_t infos, cusolver_int_t *maxiters) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnIRSInfos_t, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnIRSInfosGetMaxIters"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(infos, maxiters); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZZgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZZgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZCgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZCgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZKgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZKgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZEgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZEgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZYgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZYgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCCgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuComplex *, + cusolver_int_t, cusolver_int_t *, cuComplex *, cusolver_int_t, + cuComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCCgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCEgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuComplex *, + cusolver_int_t, cusolver_int_t *, cuComplex *, cusolver_int_t, + cuComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCEgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCKgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuComplex *, + cusolver_int_t, cusolver_int_t *, cuComplex *, cusolver_int_t, + cuComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCKgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCYgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuComplex *, + cusolver_int_t, cusolver_int_t *, cuComplex *, cusolver_int_t, + cuComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCYgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDDgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDDgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDSgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDSgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDHgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDHgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDBgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDBgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDXgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDXgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSSgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, float *dA, + cusolver_int_t ldda, cusolver_int_t *dipiv, float *dB, cusolver_int_t lddb, + float *dX, cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, float *, + cusolver_int_t, cusolver_int_t *, float *, cusolver_int_t, float *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSSgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSHgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, float *dA, + cusolver_int_t ldda, cusolver_int_t *dipiv, float *dB, cusolver_int_t lddb, + float *dX, cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, float *, + cusolver_int_t, cusolver_int_t *, float *, cusolver_int_t, float *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSHgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSBgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, float *dA, + cusolver_int_t ldda, cusolver_int_t *dipiv, float *dB, cusolver_int_t lddb, + float *dX, cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, float *, + cusolver_int_t, cusolver_int_t *, float *, cusolver_int_t, float *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSBgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSXgesv( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, float *dA, + cusolver_int_t ldda, cusolver_int_t *dipiv, float *dB, cusolver_int_t lddb, + float *dX, cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, float *, + cusolver_int_t, cusolver_int_t *, float *, cusolver_int_t, float *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSXgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZZgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZZgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZCgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZCgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZKgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZKgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZEgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZEgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZYgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuDoubleComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuDoubleComplex *, + cusolver_int_t, cusolver_int_t *, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZYgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCCgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuComplex *, + cusolver_int_t, cusolver_int_t *, cuComplex *, cusolver_int_t, + cuComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCCgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCKgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuComplex *, + cusolver_int_t, cusolver_int_t *, cuComplex *, cusolver_int_t, + cuComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCKgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCEgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuComplex *, + cusolver_int_t, cusolver_int_t *, cuComplex *, cusolver_int_t, + cuComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCEgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCYgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + cuComplex *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cuComplex *, + cusolver_int_t, cusolver_int_t *, cuComplex *, cusolver_int_t, + cuComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCYgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDDgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDDgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDSgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDSgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDHgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDHgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDBgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDBgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDXgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, + double *dA, cusolver_int_t ldda, cusolver_int_t *dipiv, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, double *, + cusolver_int_t, cusolver_int_t *, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDXgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSSgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, float *dA, + cusolver_int_t ldda, cusolver_int_t *dipiv, float *dB, cusolver_int_t lddb, + float *dX, cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, float *, + cusolver_int_t, cusolver_int_t *, float *, cusolver_int_t, float *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSSgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSHgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, float *dA, + cusolver_int_t ldda, cusolver_int_t *dipiv, float *dB, cusolver_int_t lddb, + float *dX, cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, float *, + cusolver_int_t, cusolver_int_t *, float *, cusolver_int_t, float *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSHgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSBgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, float *dA, + cusolver_int_t ldda, cusolver_int_t *dipiv, float *dB, cusolver_int_t lddb, + float *dX, cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, float *, + cusolver_int_t, cusolver_int_t *, float *, cusolver_int_t, float *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSBgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSXgesv_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t n, cusolver_int_t nrhs, float *dA, + cusolver_int_t ldda, cusolver_int_t *dipiv, float *dB, cusolver_int_t lddb, + float *dX, cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, float *, + cusolver_int_t, cusolver_int_t *, float *, cusolver_int_t, float *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSXgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, nrhs, dA, ldda, dipiv, dB, lddb, dX, lddx, + dWorkspace, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZZgels(cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuDoubleComplex *dA, cusolver_int_t ldda, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZZgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZCgels(cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuDoubleComplex *dA, cusolver_int_t ldda, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZCgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZKgels(cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuDoubleComplex *dA, cusolver_int_t ldda, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZKgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZEgels(cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuDoubleComplex *dA, cusolver_int_t ldda, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZEgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZYgels(cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuDoubleComplex *dA, cusolver_int_t ldda, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZYgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCCgels( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuComplex *dA, cusolver_int_t ldda, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuComplex *, cusolver_int_t, cuComplex *, cusolver_int_t, cuComplex *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCCgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCKgels( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuComplex *dA, cusolver_int_t ldda, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuComplex *, cusolver_int_t, cuComplex *, cusolver_int_t, cuComplex *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCKgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCEgels( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuComplex *dA, cusolver_int_t ldda, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuComplex *, cusolver_int_t, cuComplex *, cusolver_int_t, cuComplex *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCEgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCYgels( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuComplex *dA, cusolver_int_t ldda, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuComplex *, cusolver_int_t, cuComplex *, cusolver_int_t, cuComplex *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCYgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDDgels( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, double *dA, cusolver_int_t ldda, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + double *, cusolver_int_t, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDDgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDSgels( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, double *dA, cusolver_int_t ldda, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + double *, cusolver_int_t, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDSgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDHgels( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, double *dA, cusolver_int_t ldda, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + double *, cusolver_int_t, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDHgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDBgels( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, double *dA, cusolver_int_t ldda, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + double *, cusolver_int_t, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDBgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDXgels( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, double *dA, cusolver_int_t ldda, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + double *, cusolver_int_t, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDXgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSSgels( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, float *dA, cusolver_int_t ldda, float *dB, + cusolver_int_t lddb, float *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + float *, cusolver_int_t, float *, cusolver_int_t, float *, cusolver_int_t, + void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSSgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSHgels( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, float *dA, cusolver_int_t ldda, float *dB, + cusolver_int_t lddb, float *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + float *, cusolver_int_t, float *, cusolver_int_t, float *, cusolver_int_t, + void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSHgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSBgels( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, float *dA, cusolver_int_t ldda, float *dB, + cusolver_int_t lddb, float *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + float *, cusolver_int_t, float *, cusolver_int_t, float *, cusolver_int_t, + void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSBgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSXgels( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, float *dA, cusolver_int_t ldda, float *dB, + cusolver_int_t lddb, float *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *iter, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + float *, cusolver_int_t, float *, cusolver_int_t, float *, cusolver_int_t, + void *, size_t, cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSXgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes, iter, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZZgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuDoubleComplex *dA, cusolver_int_t ldda, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZZgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZCgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuDoubleComplex *dA, cusolver_int_t ldda, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZCgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZKgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuDoubleComplex *dA, cusolver_int_t ldda, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZKgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZEgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuDoubleComplex *dA, cusolver_int_t ldda, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZEgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZYgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuDoubleComplex *dA, cusolver_int_t ldda, + cuDoubleComplex *dB, cusolver_int_t lddb, cuDoubleComplex *dX, + cusolver_int_t lddx, void *dWorkspace, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, cuDoubleComplex *, cusolver_int_t, + cuDoubleComplex *, cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZYgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCCgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuComplex *dA, cusolver_int_t ldda, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuComplex *, cusolver_int_t, cuComplex *, cusolver_int_t, cuComplex *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCCgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCKgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuComplex *dA, cusolver_int_t ldda, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuComplex *, cusolver_int_t, cuComplex *, cusolver_int_t, cuComplex *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCKgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCEgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuComplex *dA, cusolver_int_t ldda, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuComplex *, cusolver_int_t, cuComplex *, cusolver_int_t, cuComplex *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCEgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCYgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, cuComplex *dA, cusolver_int_t ldda, cuComplex *dB, + cusolver_int_t lddb, cuComplex *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + cuComplex *, cusolver_int_t, cuComplex *, cusolver_int_t, cuComplex *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCYgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDDgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, double *dA, cusolver_int_t ldda, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + double *, cusolver_int_t, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDDgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDSgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, double *dA, cusolver_int_t ldda, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + double *, cusolver_int_t, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDSgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDHgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, double *dA, cusolver_int_t ldda, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + double *, cusolver_int_t, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDHgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDBgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, double *dA, cusolver_int_t ldda, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + double *, cusolver_int_t, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDBgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDXgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, double *dA, cusolver_int_t ldda, double *dB, + cusolver_int_t lddb, double *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + double *, cusolver_int_t, double *, cusolver_int_t, double *, + cusolver_int_t, void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDXgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSSgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, float *dA, cusolver_int_t ldda, float *dB, + cusolver_int_t lddb, float *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + float *, cusolver_int_t, float *, cusolver_int_t, float *, cusolver_int_t, + void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSSgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSHgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, float *dA, cusolver_int_t ldda, float *dB, + cusolver_int_t lddb, float *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + float *, cusolver_int_t, float *, cusolver_int_t, float *, cusolver_int_t, + void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSHgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSBgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, float *dA, cusolver_int_t ldda, float *dB, + cusolver_int_t lddb, float *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + float *, cusolver_int_t, float *, cusolver_int_t, float *, cusolver_int_t, + void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSBgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSXgels_bufferSize( + cusolverDnHandle_t handle, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, float *dA, cusolver_int_t ldda, float *dB, + cusolver_int_t lddb, float *dX, cusolver_int_t lddx, void *dWorkspace, + size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolver_int_t, cusolver_int_t, cusolver_int_t, + float *, cusolver_int_t, float *, cusolver_int_t, float *, cusolver_int_t, + void *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSXgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nrhs, dA, ldda, dB, lddb, dX, lddx, dWorkspace, + lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSXgesv( + cusolverDnHandle_t handle, cusolverDnIRSParams_t gesv_irs_params, + cusolverDnIRSInfos_t gesv_irs_infos, cusolver_int_t n, cusolver_int_t nrhs, + void *dA, cusolver_int_t ldda, void *dB, cusolver_int_t lddb, void *dX, + cusolver_int_t lddx, void *dWorkspace, size_t lwork_bytes, + cusolver_int_t *niters, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnIRSParams_t, cusolverDnIRSInfos_t, + cusolver_int_t, cusolver_int_t, void *, cusolver_int_t, void *, + cusolver_int_t, void *, cusolver_int_t, void *, size_t, cusolver_int_t *, + cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnIRSXgesv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, gesv_irs_params, gesv_irs_infos, n, nrhs, dA, ldda, + dB, lddb, dX, lddx, dWorkspace, lwork_bytes, niters, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSXgesv_bufferSize( + cusolverDnHandle_t handle, cusolverDnIRSParams_t params, cusolver_int_t n, + cusolver_int_t nrhs, size_t *lwork_bytes) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cusolverDnIRSParams_t, + cusolver_int_t, cusolver_int_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnIRSXgesv_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, params, n, nrhs, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSXgels( + cusolverDnHandle_t handle, cusolverDnIRSParams_t gels_irs_params, + cusolverDnIRSInfos_t gels_irs_infos, cusolver_int_t m, cusolver_int_t n, + cusolver_int_t nrhs, void *dA, cusolver_int_t ldda, void *dB, + cusolver_int_t lddb, void *dX, cusolver_int_t lddx, void *dWorkspace, + size_t lwork_bytes, cusolver_int_t *niters, cusolver_int_t *d_info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnIRSParams_t, cusolverDnIRSInfos_t, + cusolver_int_t, cusolver_int_t, cusolver_int_t, void *, cusolver_int_t, + void *, cusolver_int_t, void *, cusolver_int_t, void *, size_t, + cusolver_int_t *, cusolver_int_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnIRSXgels"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, gels_irs_params, gels_irs_infos, m, n, nrhs, dA, ldda, + dB, lddb, dX, lddx, dWorkspace, lwork_bytes, niters, d_info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnIRSXgels_bufferSize( + cusolverDnHandle_t handle, cusolverDnIRSParams_t params, cusolver_int_t m, + cusolver_int_t n, cusolver_int_t nrhs, size_t *lwork_bytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnIRSParams_t, cusolver_int_t, cusolver_int_t, + cusolver_int_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnIRSXgels_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, params, m, n, nrhs, lwork_bytes); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnSpotrf_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, float *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSpotrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnDpotrf_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, double *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDpotrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCpotrf_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, cuComplex *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCpotrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZpotrf_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, cuDoubleComplex *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZpotrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSpotrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, + float *Workspace, int Lwork, + int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + float *, int, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSpotrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDpotrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + double *A, int lda, + double *Workspace, int Lwork, + int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + double *, int, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDpotrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCpotrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, + cuComplex *Workspace, int Lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCpotrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZpotrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + cuDoubleComplex *Workspace, + int Lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZpotrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSpotrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, const float *A, int lda, + float *B, int ldb, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const float *, int, + float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSpotrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDpotrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, const double *A, + int lda, double *B, int ldb, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const double *, int, + double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDpotrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCpotrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, const cuComplex *A, + int lda, cuComplex *B, int ldb, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const cuComplex *, int, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCpotrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZpotrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, + const cuDoubleComplex *A, int lda, + cuDoubleComplex *B, int ldb, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const cuDoubleComplex *, + int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZpotrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSpotrfBatched(cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, float *Aarray[], + int lda, int *infoArray, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSpotrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, Aarray, lda, infoArray, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDpotrfBatched(cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, double *Aarray[], + int lda, int *infoArray, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDpotrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, Aarray, lda, infoArray, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCpotrfBatched(cusolverDnHandle_t handle, + cublasFillMode_t uplo, + int n, cuComplex *Aarray[], + int lda, int *infoArray, + int batchSize) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + cuComplex *[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCpotrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, Aarray, lda, infoArray, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZpotrfBatched( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + cuDoubleComplex *Aarray[], int lda, int *infoArray, int batchSize) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + cuDoubleComplex *[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZpotrfBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, Aarray, lda, infoArray, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSpotrsBatched( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + int nrhs, /* only support rhs = 1*/ + float *A[], int lda, float *B[], int ldb, int *d_info, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, float *[], int, float *[], + int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSpotrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, d_info, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDpotrsBatched( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + int nrhs, /* only support rhs = 1*/ + double *A[], int lda, double *B[], int ldb, int *d_info, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, double *[], int, + double *[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDpotrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, d_info, batchSize); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCpotrsBatched(cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + int nrhs, /* only support rhs = 1*/ + cuComplex *A[], int lda, cuComplex *B[], int ldb, + int *d_info, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, cuComplex *[], int, + cuComplex *[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCpotrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, d_info, batchSize); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZpotrsBatched(cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + int nrhs, /* only support rhs = 1*/ + cuDoubleComplex *A[], int lda, cuDoubleComplex *B[], + int ldb, int *d_info, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, cuDoubleComplex *[], int, + cuDoubleComplex *[], int, int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZpotrsBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, B, ldb, d_info, batchSize); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnSpotri_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, float *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSpotri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnDpotri_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, double *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDpotri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCpotri_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, cuComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCpotri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZpotri_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, cuDoubleComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZpotri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSpotri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, float *work, + int lwork, int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + float *, int, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSpotri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDpotri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + double *A, int lda, double *work, + int lwork, int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + double *, int, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDpotri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCpotri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, + cuComplex *work, int lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCpotri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZpotri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + cuDoubleComplex *work, int lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZpotri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnStrtri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, cublasDiagType_t diag, + int n, float *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, float *, int, + int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnStrtri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDtrtri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, cublasDiagType_t diag, + int n, double *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, double *, + int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDtrtri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCtrtri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, cublasDiagType_t diag, + int n, cuComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCtrtri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZtrtri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, cublasDiagType_t diag, + int n, cuDoubleComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZtrtri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnStrtri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, + cublasDiagType_t diag, int n, + float *A, int lda, float *work, + int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, float *, int, + float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnStrtri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDtrtri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, + cublasDiagType_t diag, int n, + double *A, int lda, double *work, + int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, double *, + int, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDtrtri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCtrtri( + cusolverDnHandle_t handle, cublasFillMode_t uplo, cublasDiagType_t diag, + int n, cuComplex *A, int lda, cuComplex *work, int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, cuComplex *, + int, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCtrtri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZtrtri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, + cublasDiagType_t diag, int n, + cuDoubleComplex *A, int lda, + cuDoubleComplex *work, int lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, cublasDiagType_t, int, + cuDoubleComplex *, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZtrtri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, diag, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnSlauum_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, float *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSlauum_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnDlauum_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, double *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDlauum_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnClauum_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, cuComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnClauum_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZlauum_bufferSize(cusolverDnHandle_t handle, cublasFillMode_t uplo, + int n, cuDoubleComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZlauum_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSlauum(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, float *work, + int lwork, int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + float *, int, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSlauum"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDlauum(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + double *A, int lda, double *work, + int lwork, int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + double *, int, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDlauum"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnClauum(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, + cuComplex *work, int lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnClauum"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZlauum(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + cuDoubleComplex *work, int lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZlauum"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, work, lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgetrf_bufferSize( + cusolverDnHandle_t handle, int m, int n, float *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSgetrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgetrf_bufferSize( + cusolverDnHandle_t handle, int m, int n, double *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDgetrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCgetrf_bufferSize(cusolverDnHandle_t handle, int m, int n, + cuComplex *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCgetrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZgetrf_bufferSize(cusolverDnHandle_t handle, int m, int n, + cuDoubleComplex *A, int lda, int *Lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZgetrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgetrf(cusolverDnHandle_t handle, int m, + int n, float *A, int lda, + float *Workspace, int *devIpiv, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, float *, int, float *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSgetrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Workspace, devIpiv, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgetrf(cusolverDnHandle_t handle, int m, + int n, double *A, int lda, + double *Workspace, int *devIpiv, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, double *, int, double *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDgetrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Workspace, devIpiv, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgetrf(cusolverDnHandle_t handle, int m, + int n, cuComplex *A, int lda, + cuComplex *Workspace, + int *devIpiv, int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, cuComplex *, + int, cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCgetrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Workspace, devIpiv, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgetrf(cusolverDnHandle_t handle, int m, + int n, cuDoubleComplex *A, + int lda, + cuDoubleComplex *Workspace, + int *devIpiv, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuDoubleComplex *, int, cuDoubleComplex *, + int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZgetrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, Workspace, devIpiv, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSlaswp(cusolverDnHandle_t handle, int n, + float *A, int lda, int k1, int k2, + const int *devIpiv, int incx) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, float *, int, int, int, const int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSlaswp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, k1, k2, devIpiv, incx); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDlaswp(cusolverDnHandle_t handle, int n, + double *A, int lda, int k1, + int k2, const int *devIpiv, + int incx) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, double *, int, int, int, const int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDlaswp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, k1, k2, devIpiv, incx); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnClaswp(cusolverDnHandle_t handle, int n, + cuComplex *A, int lda, int k1, + int k2, const int *devIpiv, + int incx) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, cuComplex *, int, int, int, const int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnClaswp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, k1, k2, devIpiv, incx); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZlaswp(cusolverDnHandle_t handle, int n, + cuDoubleComplex *A, int lda, + int k1, int k2, + const int *devIpiv, int incx) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, + cuDoubleComplex *, int, int, + int, const int *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZlaswp"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, k1, k2, devIpiv, incx); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgetrs(cusolverDnHandle_t handle, + cublasOperation_t trans, int n, + int nrhs, const float *A, int lda, + const int *devIpiv, float *B, + int ldb, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasOperation_t, int, int, const float *, int, + const int *, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSgetrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgetrs(cusolverDnHandle_t handle, + cublasOperation_t trans, int n, + int nrhs, const double *A, + int lda, const int *devIpiv, + double *B, int ldb, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasOperation_t, int, int, const double *, int, + const int *, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDgetrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgetrs(cusolverDnHandle_t handle, + cublasOperation_t trans, int n, + int nrhs, const cuComplex *A, + int lda, const int *devIpiv, + cuComplex *B, int ldb, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasOperation_t, int, int, const cuComplex *, int, + const int *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCgetrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgetrs( + cusolverDnHandle_t handle, cublasOperation_t trans, int n, int nrhs, + const cuDoubleComplex *A, int lda, const int *devIpiv, cuDoubleComplex *B, + int ldb, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasOperation_t, int, int, const cuDoubleComplex *, + int, const int *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZgetrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, trans, n, nrhs, A, lda, devIpiv, B, ldb, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgeqrf_bufferSize( + cusolverDnHandle_t handle, int m, int n, float *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSgeqrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgeqrf_bufferSize( + cusolverDnHandle_t handle, int m, int n, double *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDgeqrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCgeqrf_bufferSize(cusolverDnHandle_t handle, int m, int n, + cuComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCgeqrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZgeqrf_bufferSize(cusolverDnHandle_t handle, int m, int n, + cuDoubleComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZgeqrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgeqrf(cusolverDnHandle_t handle, int m, + int n, float *A, int lda, + float *TAU, float *Workspace, + int Lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, float *, int, float *, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSgeqrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, TAU, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgeqrf(cusolverDnHandle_t handle, int m, + int n, double *A, int lda, + double *TAU, double *Workspace, + int Lwork, int *devInfo) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, double *, + int, double *, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDgeqrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, TAU, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgeqrf(cusolverDnHandle_t handle, int m, + int n, cuComplex *A, int lda, + cuComplex *TAU, + cuComplex *Workspace, int Lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + cuComplex *, int, cuComplex *, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCgeqrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, TAU, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgeqrf(cusolverDnHandle_t handle, int m, + int n, cuDoubleComplex *A, + int lda, cuDoubleComplex *TAU, + cuDoubleComplex *Workspace, + int Lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuDoubleComplex *, int, cuDoubleComplex *, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZgeqrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, TAU, Workspace, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgqr_bufferSize( + cusolverDnHandle_t handle, int m, int n, int k, const float *A, int lda, + const float *tau, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int, + const float *, int, const float *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSorgqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgqr_bufferSize( + cusolverDnHandle_t handle, int m, int n, int k, const double *A, int lda, + const double *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + int, const double *, int, + const double *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDorgqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungqr_bufferSize( + cusolverDnHandle_t handle, int m, int n, int k, const cuComplex *A, int lda, + const cuComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, + int, const cuComplex *, int, + const cuComplex *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCungqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZungqr_bufferSize( + cusolverDnHandle_t handle, int m, int n, int k, const cuDoubleComplex *A, + int lda, const cuDoubleComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZungqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgqr(cusolverDnHandle_t handle, int m, + int n, int k, float *A, int lda, + const float *tau, float *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, int, float *, int, const float *, float *, + int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSorgqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgqr(cusolverDnHandle_t handle, int m, + int n, int k, double *A, int lda, + const double *tau, double *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, int, double *, int, const double *, + double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDorgqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungqr(cusolverDnHandle_t handle, int m, + int n, int k, cuComplex *A, + int lda, const cuComplex *tau, + cuComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, int, cuComplex *, int, const cuComplex *, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCungqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZungqr( + cusolverDnHandle_t handle, int m, int n, int k, cuDoubleComplex *A, int lda, + const cuDoubleComplex *tau, cuDoubleComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, int, cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZungqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSormqr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const float *A, int lda, const float *tau, + const float *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const float *, int, const float *, const float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSormqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDormqr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const double *A, int lda, const double *tau, + const double *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const double *, int, const double *, const double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDormqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCunmqr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const cuComplex *A, int lda, const cuComplex *tau, + const cuComplex *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const cuComplex *, int, const cuComplex *, const cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCunmqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZunmqr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *tau, const cuDoubleComplex *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZunmqr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSormqr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const float *A, int lda, const float *tau, float *C, + int ldc, float *work, int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const float *, int, const float *, float *, int, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSormqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, + lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDormqr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const double *A, int lda, const double *tau, double *C, + int ldc, double *work, int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const double *, int, const double *, double *, int, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDormqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, + lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCunmqr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const cuComplex *A, int lda, const cuComplex *tau, + cuComplex *C, int ldc, cuComplex *work, int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const cuComplex *, int, const cuComplex *, cuComplex *, int, cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCunmqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, + lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZunmqr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, + int m, int n, int k, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *tau, cuDoubleComplex *C, int ldc, + cuDoubleComplex *work, int lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasOperation_t, int, int, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, cuDoubleComplex *, + int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZunmqr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, + lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytrf_bufferSize( + cusolverDnHandle_t handle, int n, float *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, + float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsytrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytrf_bufferSize( + cusolverDnHandle_t handle, int n, double *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, + double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsytrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCsytrf_bufferSize( + cusolverDnHandle_t handle, int n, cuComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCsytrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZsytrf_bufferSize( + cusolverDnHandle_t handle, int n, cuDoubleComplex *A, int lda, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZsytrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, A, lda, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, int *ipiv, + float *work, int lwork, + int *info) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + float *, int, int *, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsytrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + double *A, int lda, int *ipiv, + double *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, int *, double *, + int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsytrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCsytrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, int *ipiv, + cuComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, int *, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCsytrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZsytrf(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + int *ipiv, cuDoubleComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, int *, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZsytrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytrs_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, + const float *A, int lda, const int *ipiv, float *B, int ldb, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const float *, int, + const int *, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsytrs_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytrs_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, + const double *A, int lda, const int *ipiv, double *B, int ldb, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const double *, int, + const int *, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsytrs_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCsytrs_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, + const cuComplex *A, int lda, const int *ipiv, cuComplex *B, int ldb, + int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const cuComplex *, int, + const int *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCsytrs_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZsytrs_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, + const cuDoubleComplex *A, int lda, const int *ipiv, cuDoubleComplex *B, + int ldb, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const cuDoubleComplex *, + int, const int *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZsytrs_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, const float *A, int lda, + const int *ipiv, float *B, + int ldb, float *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const float *, int, + const int *, float *, int, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsytrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytrs(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + int nrhs, const double *A, + int lda, const int *ipiv, + double *B, int ldb, double *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const double *, int, + const int *, double *, int, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsytrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCsytrs(cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + int nrhs, const cuComplex *A, int lda, const int *ipiv, + cuComplex *B, int ldb, cuComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const cuComplex *, int, + const int *, cuComplex *, int, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCsytrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZsytrs( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, + const cuDoubleComplex *A, int lda, const int *ipiv, cuDoubleComplex *B, + int ldb, cuDoubleComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, int, const cuDoubleComplex *, + int, const int *, cuDoubleComplex *, int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZsytrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, nrhs, A, lda, ipiv, B, ldb, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, float *A, int lda, + const int *ipiv, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + float *, int, const int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsytri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, double *A, int lda, + const int *ipiv, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + double *, int, const int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsytri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCsytri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, cuComplex *A, + int lda, const int *ipiv, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + cuComplex *, int, const int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCsytri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZsytri_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, cuDoubleComplex *A, + int lda, const int *ipiv, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + const int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZsytri_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, + const int *ipiv, float *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *, int, const int *, + float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsytri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + double *A, int lda, + const int *ipiv, double *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, const int *, + double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsytri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCsytri(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, + const int *ipiv, cuComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, const int *, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCsytri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZsytri( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, cuDoubleComplex *A, + int lda, const int *ipiv, cuDoubleComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + const int *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZsytri"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, ipiv, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgebrd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *Lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSgebrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgebrd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *Lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDgebrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgebrd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *Lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCgebrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgebrd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *Lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZgebrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, Lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgebrd(cusolverDnHandle_t handle, int m, + int n, float *A, int lda, + float *D, float *E, float *TAUQ, + float *TAUP, float *Work, + int Lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, float *, int, float *, float *, float *, + float *, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSgebrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, D, E, TAUQ, TAUP, Work, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgebrd(cusolverDnHandle_t handle, int m, + int n, double *A, int lda, + double *D, double *E, + double *TAUQ, double *TAUP, + double *Work, int Lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, double *, int, double *, double *, double *, + double *, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDgebrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, D, E, TAUQ, TAUP, Work, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgebrd(cusolverDnHandle_t handle, int m, + int n, cuComplex *A, int lda, + float *D, float *E, + cuComplex *TAUQ, cuComplex *TAUP, + cuComplex *Work, int Lwork, + int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuComplex *, int, float *, float *, + cuComplex *, cuComplex *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCgebrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, D, E, TAUQ, TAUP, Work, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgebrd( + cusolverDnHandle_t handle, int m, int n, cuDoubleComplex *A, int lda, + double *D, double *E, cuDoubleComplex *TAUQ, cuDoubleComplex *TAUP, + cuDoubleComplex *Work, int Lwork, int *devInfo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, int, int, cuDoubleComplex *, int, double *, double *, + cuDoubleComplex *, cuDoubleComplex *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZgebrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, D, E, TAUQ, TAUP, Work, Lwork, devInfo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgbr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, int m, int n, int k, + const float *A, int lda, const float *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, const float *, int, + const float *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSorgbr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgbr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, int m, int n, int k, + const double *A, int lda, const double *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, const double *, int, + const double *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDorgbr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungbr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, int m, int n, int k, + const cuComplex *A, int lda, const cuComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, const cuComplex *, + int, const cuComplex *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCungbr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZungbr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, int m, int n, int k, + const cuDoubleComplex *A, int lda, const cuDoubleComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, + const cuDoubleComplex *, int, const cuDoubleComplex *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZungbr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgbr(cusolverDnHandle_t handle, + cublasSideMode_t side, int m, + int n, int k, float *A, int lda, + const float *tau, float *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, float *, int, + const float *, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSorgbr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgbr(cusolverDnHandle_t handle, + cublasSideMode_t side, int m, + int n, int k, double *A, int lda, + const double *tau, double *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, double *, int, + const double *, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDorgbr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungbr(cusolverDnHandle_t handle, + cublasSideMode_t side, int m, + int n, int k, cuComplex *A, + int lda, const cuComplex *tau, + cuComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, cuComplex *, int, + const cuComplex *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCungbr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZungbr(cusolverDnHandle_t handle, cublasSideMode_t side, int m, int n, + int k, cuDoubleComplex *A, int lda, const cuDoubleComplex *tau, + cuDoubleComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, int, int, int, cuDoubleComplex *, + int, const cuDoubleComplex *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZungbr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, m, n, k, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytrd_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const float *A, + int lda, const float *d, const float *e, const float *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const float *, int, + const float *, const float *, const float *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsytrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytrd_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const double *A, + int lda, const double *d, const double *e, const double *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const double *, int, + const double *, const double *, const double *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsytrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChetrd_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const cuComplex *A, + int lda, const float *d, const float *e, const cuComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const cuComplex *, int, + const float *, const float *, const cuComplex *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnChetrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhetrd_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *A, int lda, const double *d, const double *e, + const cuDoubleComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, int, + const double *, const double *, const cuDoubleComplex *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZhetrd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsytrd(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, float *d, + float *e, float *tau, float *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *, int, float *, float *, + float *, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsytrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsytrd( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, double *A, int lda, + double *d, double *e, double *tau, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, double *, + double *, double *, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsytrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChetrd(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, float *d, + float *e, cuComplex *tau, + cuComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, float *, + float *, cuComplex *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnChetrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhetrd( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, cuDoubleComplex *A, + int lda, double *d, double *e, cuDoubleComplex *tau, cuDoubleComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + double *, double *, cuDoubleComplex *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZhetrd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, d, e, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgtr_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const float *A, + int lda, const float *tau, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, cublasFillMode_t, int, + const float *, int, const float *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSorgtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgtr_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const double *A, + int lda, const double *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const double *, int, + const double *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDorgtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungtr_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, const cuComplex *A, + int lda, const cuComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const cuComplex *, int, + const cuComplex *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCungtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZungtr_bufferSize( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, + const cuDoubleComplex *A, int lda, const cuDoubleComplex *tau, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZungtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSorgtr(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + float *A, int lda, + const float *tau, float *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, float *, int, const float *, + float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSorgtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDorgtr(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + double *A, int lda, + const double *tau, double *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, double *, int, const double *, + double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDorgtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCungtr( + cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, cuComplex *A, + int lda, const cuComplex *tau, cuComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuComplex *, int, + const cuComplex *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCungtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZungtr(cusolverDnHandle_t handle, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + const cuDoubleComplex *tau, + cuDoubleComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasFillMode_t, int, cuDoubleComplex *, int, + const cuDoubleComplex *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZungtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, uplo, n, A, lda, tau, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSormtr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, const float *A, int lda, + const float *tau, const float *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, const float *, int, const float *, const float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSormtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDormtr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, const double *A, int lda, + const double *tau, const double *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, const double *, int, const double *, const double *, int, + int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDormtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCunmtr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, const cuComplex *A, int lda, + const cuComplex *tau, const cuComplex *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, const cuComplex *, int, const cuComplex *, const cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCunmtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZunmtr_bufferSize( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *tau, const cuDoubleComplex *C, int ldc, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, const cuDoubleComplex *, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZunmtr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSormtr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, float *A, int lda, float *tau, + float *C, int ldc, float *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, float *, int, float *, float *, int, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSormtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, work, + lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDormtr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, double *A, int lda, double *tau, + double *C, int ldc, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, double *, int, double *, double *, int, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDormtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, work, + lwork, info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCunmtr(cusolverDnHandle_t handle, cublasSideMode_t side, + cublasFillMode_t uplo, cublasOperation_t trans, int m, int n, + cuComplex *A, int lda, cuComplex *tau, cuComplex *C, int ldc, + cuComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, cuComplex *, int, cuComplex *, cuComplex *, int, cuComplex *, + int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCunmtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, work, + lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZunmtr( + cusolverDnHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, + cublasOperation_t trans, int m, int n, cuDoubleComplex *A, int lda, + cuDoubleComplex *tau, cuDoubleComplex *C, int ldc, cuDoubleComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cublasSideMode_t, cublasFillMode_t, cublasOperation_t, + int, int, cuDoubleComplex *, int, cuDoubleComplex *, cuDoubleComplex *, + int, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZunmtr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, side, uplo, trans, m, n, A, lda, tau, C, ldc, work, + lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSgesvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDgesvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCgesvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvd_bufferSize( + cusolverDnHandle_t handle, int m, int n, int *lwork) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZgesvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvd( + cusolverDnHandle_t handle, signed char jobu, signed char jobvt, int m, + int n, float *A, int lda, float *S, float *U, int ldu, float *VT, int ldvt, + float *work, int lwork, float *rwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, signed char, signed char, int, int, float *, int, + float *, float *, int, float *, int, float *, int, float *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSgesvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT, ldvt, work, + lwork, rwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvd( + cusolverDnHandle_t handle, signed char jobu, signed char jobvt, int m, + int n, double *A, int lda, double *S, double *U, int ldu, double *VT, + int ldvt, double *work, int lwork, double *rwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, signed char, signed char, int, int, double *, int, + double *, double *, int, double *, int, double *, int, double *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDgesvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT, ldvt, work, + lwork, rwork, info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCgesvd(cusolverDnHandle_t handle, signed char jobu, signed char jobvt, + int m, int n, cuComplex *A, int lda, float *S, cuComplex *U, + int ldu, cuComplex *VT, int ldvt, cuComplex *work, int lwork, + float *rwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, signed char, signed char, int, int, cuComplex *, int, + float *, cuComplex *, int, cuComplex *, int, cuComplex *, int, float *, + int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCgesvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT, ldvt, work, + lwork, rwork, info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZgesvd(cusolverDnHandle_t handle, signed char jobu, signed char jobvt, + int m, int n, cuDoubleComplex *A, int lda, double *S, + cuDoubleComplex *U, int ldu, cuDoubleComplex *VT, int ldvt, + cuDoubleComplex *work, int lwork, double *rwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, signed char, signed char, int, int, cuDoubleComplex *, + int, double *, cuDoubleComplex *, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, double *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZgesvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT, ldvt, work, + lwork, rwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevd_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const float *A, int lda, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const float *, int, const float *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsyevd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevd_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const double *A, int lda, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const double *, int, const double *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsyevd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevd_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuComplex *A, int lda, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuComplex *, int, const float *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCheevd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevd_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuDoubleComplex *A, int lda, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuDoubleComplex *, int, const double *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZheevd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevd( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, float *A, int lda, float *W, float *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, float *, + int, float *, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsyevd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevd( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, double *A, int lda, double *W, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, double *, + int, double *, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsyevd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevd(cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, float *W, + cuComplex *work, int lwork, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, cuComplex *, + int, float *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCheevd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevd(cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, + double *W, cuDoubleComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + cuDoubleComplex *, int, double *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZheevd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cusolverEigRange_t range, + cublasFillMode_t uplo, int n, const float *A, int lda, float vl, float vu, + int il, int iu, int *meig, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, int, const float *, int, float, float, int, int, int *, + const float *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsyevdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cusolverEigRange_t range, + cublasFillMode_t uplo, int n, const double *A, int lda, double vl, + double vu, int il, int iu, int *meig, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, int, const double *, int, double, double, int, int, + int *, const double *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsyevdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cusolverEigRange_t range, + cublasFillMode_t uplo, int n, const cuComplex *A, int lda, float vl, + float vu, int il, int iu, int *meig, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, int, const cuComplex *, int, float, float, int, int, + int *, const float *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCheevdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cusolverEigRange_t range, + cublasFillMode_t uplo, int n, const cuDoubleComplex *A, int lda, double vl, + double vu, int il, int iu, int *meig, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, int, const cuDoubleComplex *, int, double, double, int, + int, int *, const double *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZheevdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevdx( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cusolverEigRange_t range, + cublasFillMode_t uplo, int n, float *A, int lda, float vl, float vu, int il, + int iu, int *meig, float *W, float *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, int, float *, int, float, float, int, int, int *, + float *, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsyevdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevdx( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cusolverEigRange_t range, + cublasFillMode_t uplo, int n, double *A, int lda, double vl, double vu, + int il, int iu, int *meig, double *W, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, int, double *, int, double, double, int, int, int *, + double *, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsyevdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCheevdx(cusolverDnHandle_t handle, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, + cuComplex *A, int lda, float vl, float vu, int il, int iu, + int *meig, float *W, cuComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, int, cuComplex *, int, float, float, int, int, int *, + float *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCheevdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevdx( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cusolverEigRange_t range, + cublasFillMode_t uplo, int n, cuDoubleComplex *A, int lda, double vl, + double vu, int il, int iu, int *meig, double *W, cuDoubleComplex *work, + int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cusolverEigRange_t, + cublasFillMode_t, int, cuDoubleComplex *, int, double, double, int, int, + int *, double *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZheevdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, range, uplo, n, A, lda, vl, vu, il, iu, meig, W, + work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, const float *A, + int lda, const float *B, int ldb, float vl, float vu, int il, int iu, + int *meig, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, const float *, int, + const float *, int, float, float, int, int, int *, const float *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsygvdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, const double *A, + int lda, const double *B, int ldb, double vl, double vu, int il, int iu, + int *meig, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, const double *, int, + const double *, int, double, double, int, int, int *, const double *, + int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsygvdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, const cuComplex *A, + int lda, const cuComplex *B, int ldb, float vl, float vu, int il, int iu, + int *meig, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, const cuComplex *, int, + const cuComplex *, int, float, float, int, int, int *, const float *, + int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnChegvdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhegvdx_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, + const cuDoubleComplex *A, int lda, const cuDoubleComplex *B, int ldb, + double vl, double vu, int il, int iu, int *meig, const double *W, + int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, double, double, int, int, int *, + const double *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZhegvdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvdx( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, float *A, int lda, + float *B, int ldb, float vl, float vu, int il, int iu, int *meig, float *W, + float *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, float *, int, float *, int, + float, float, int, int, int *, float *, float *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsygvdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvdx( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, double *A, int lda, + double *B, int ldb, double vl, double vu, int il, int iu, int *meig, + double *W, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, double *, int, double *, int, + double, double, int, int, int *, double *, double *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsygvdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvdx( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, cuComplex *A, + int lda, cuComplex *B, int ldb, float vl, float vu, int il, int iu, + int *meig, float *W, cuComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, cuComplex *, int, cuComplex *, + int, float, float, int, int, int *, float *, cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnChegvdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhegvdx( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cusolverEigRange_t range, cublasFillMode_t uplo, int n, cuDoubleComplex *A, + int lda, cuDoubleComplex *B, int ldb, double vl, double vu, int il, int iu, + int *meig, double *W, cuDoubleComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, double, double, int, int, int *, double *, + cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZhegvdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, range, uplo, n, A, lda, B, ldb, vl, vu, + il, iu, meig, W, work, lwork, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvd_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const float *A, int lda, const float *B, + int ldb, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const float *, int, const float *, int, + const float *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsygvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvd_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const double *A, int lda, const double *B, + int ldb, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const double *, int, const double *, int, + const double *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsygvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvd_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const cuComplex *A, int lda, + const cuComplex *B, int ldb, const float *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const cuComplex *, int, const cuComplex *, int, + const float *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnChegvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhegvd_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *B, int ldb, const double *W, int *lwork) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const double *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZhegvd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvd( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, float *A, int lda, float *B, int ldb, + float *W, float *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, float *, int, float *, int, float *, float *, int, + int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsygvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvd( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, double *A, int lda, double *B, int ldb, + double *W, double *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, double *, int, double *, int, double *, double *, + int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsygvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvd( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, cuComplex *A, int lda, cuComplex *B, int ldb, + float *W, cuComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, cuComplex *, int, cuComplex *, int, float *, + cuComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnChegvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnZhegvd(cusolverDnHandle_t handle, cusolverEigType_t itype, + cusolverEigMode_t jobz, cublasFillMode_t uplo, int n, + cuDoubleComplex *A, int lda, cuDoubleComplex *B, int ldb, + double *W, cuDoubleComplex *work, int lwork, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, + double *, cuDoubleComplex *, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZhegvd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCreateSyevjInfo(syevjInfo_t *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(syevjInfo_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCreateSyevjInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDestroySyevjInfo(syevjInfo_t info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDestroySyevjInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjSetTolerance(syevjInfo_t info, + double tolerance) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(syevjInfo_t, double); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnXsyevjSetTolerance"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, tolerance); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjSetMaxSweeps(syevjInfo_t info, + int max_sweeps) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(syevjInfo_t, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnXsyevjSetMaxSweeps"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, max_sweeps); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjSetSortEig(syevjInfo_t info, + int sort_eig) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(syevjInfo_t, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnXsyevjSetSortEig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, sort_eig); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjGetResidual( + cusolverDnHandle_t handle, syevjInfo_t info, double *residual) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, + syevjInfo_t, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnXsyevjGetResidual"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, residual); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXsyevjGetSweeps( + cusolverDnHandle_t handle, syevjInfo_t info, int *executed_sweeps) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, syevjInfo_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnXsyevjGetSweeps"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, executed_sweeps); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const float *A, int lda, const float *W, int *lwork, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const float *, int, const float *, int *, syevjInfo_t, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnSsyevjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const double *A, int lda, const double *W, int *lwork, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const double *, int, const double *, int *, syevjInfo_t, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnDsyevjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuComplex *A, int lda, const float *W, int *lwork, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuComplex *, int, const float *, int *, syevjInfo_t, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnCheevjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuDoubleComplex *A, int lda, const double *W, int *lwork, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuDoubleComplex *, int, const double *, int *, syevjInfo_t, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnZheevjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, float *A, int lda, float *W, float *work, int lwork, int *info, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, float *, + int, float *, float *, int, int *, syevjInfo_t, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsyevjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, double *A, int lda, double *W, double *work, int lwork, int *info, + syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, double *, + int, double *, double *, int, int *, syevjInfo_t, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsyevjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, cuComplex *A, int lda, float *W, cuComplex *work, int lwork, + int *info, syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, cuComplex *, + int, float *, cuComplex *, int, int *, syevjInfo_t, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCheevjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, cuDoubleComplex *A, int lda, double *W, cuDoubleComplex *work, + int lwork, int *info, syevjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + cuDoubleComplex *, int, double *, cuDoubleComplex *, int, int *, + syevjInfo_t, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZheevjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const float *A, int lda, const float *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const float *, int, const float *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsyevj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const double *A, int lda, const double *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const double *, int, const double *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsyevj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuComplex *A, int lda, const float *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuComplex *, int, const float *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCheevj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, const cuDoubleComplex *A, int lda, const double *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + const cuDoubleComplex *, int, const double *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZheevj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, lwork, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsyevj(cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, + float *A, int lda, float *W, + float *work, int lwork, int *info, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, float *, + int, float *, float *, int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsyevj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsyevj(cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, + double *A, int lda, double *W, + double *work, int lwork, + int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, double *, + int, double *, double *, int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsyevj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCheevj(cusolverDnHandle_t handle, + cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, + cuComplex *A, int lda, float *W, + cuComplex *work, int lwork, + int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, cuComplex *, + int, float *, cuComplex *, int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCheevj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZheevj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, + int n, cuDoubleComplex *A, int lda, double *W, cuDoubleComplex *work, + int lwork, int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, cublasFillMode_t, int, + cuDoubleComplex *, int, double *, cuDoubleComplex *, int, int *, + syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZheevj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, uplo, n, A, lda, W, work, lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvj_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const float *A, int lda, const float *B, + int ldb, const float *W, int *lwork, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const float *, int, const float *, int, + const float *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsygvj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvj_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const double *A, int lda, const double *B, + int ldb, const double *W, int *lwork, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const double *, int, const double *, int, + const double *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsygvj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvj_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const cuComplex *A, int lda, + const cuComplex *B, int ldb, const float *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const cuComplex *, int, const cuComplex *, int, + const float *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnChegvj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhegvj_bufferSize( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *B, int ldb, const double *W, int *lwork, + syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, const double *, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZhegvj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSsygvj( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, float *A, int lda, float *B, int ldb, + float *W, float *work, int lwork, int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, float *, int, float *, int, float *, float *, int, + int *, syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSsygvj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDsygvj( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, double *A, int lda, double *B, int ldb, + double *W, double *work, int lwork, int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, double *, int, double *, int, double *, double *, + int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDsygvj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnChegvj( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, cuComplex *A, int lda, cuComplex *B, int ldb, + float *W, cuComplex *work, int lwork, int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, cuComplex *, int, cuComplex *, int, float *, + cuComplex *, int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnChegvj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZhegvj( + cusolverDnHandle_t handle, cusolverEigType_t itype, cusolverEigMode_t jobz, + cublasFillMode_t uplo, int n, cuDoubleComplex *A, int lda, + cuDoubleComplex *B, int ldb, double *W, cuDoubleComplex *work, int lwork, + int *info, syevjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigType_t, cusolverEigMode_t, + cublasFillMode_t, int, cuDoubleComplex *, int, cuDoubleComplex *, int, + double *, cuDoubleComplex *, int, int *, syevjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZhegvj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, itype, jobz, uplo, n, A, lda, B, ldb, W, work, lwork, + info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCreateGesvdjInfo(gesvdjInfo_t *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(gesvdjInfo_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCreateGesvdjInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDestroyGesvdjInfo(gesvdjInfo_t info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(gesvdjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDestroyGesvdjInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjSetTolerance(gesvdjInfo_t info, + double tolerance) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(gesvdjInfo_t, double); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnXgesvdjSetTolerance"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, tolerance); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjSetMaxSweeps(gesvdjInfo_t info, + int max_sweeps) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnXgesvdjSetMaxSweeps"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, max_sweeps); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjSetSortEig(gesvdjInfo_t info, + int sort_svd) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnXgesvdjSetSortEig"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, sort_svd); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjGetResidual( + cusolverDnHandle_t handle, gesvdjInfo_t info, double *residual) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, + gesvdjInfo_t, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnXgesvdjGetResidual"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, residual); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnXgesvdjGetSweeps( + cusolverDnHandle_t handle, gesvdjInfo_t info, int *executed_sweeps) { + using FuncPtr = + cusolverStatus_t(CUSOLVERAPI *)(cusolverDnHandle_t, gesvdjInfo_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnXgesvdjGetSweeps"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, executed_sweeps); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + const float *A, int lda, const float *S, const float *U, int ldu, + const float *V, int ldv, int *lwork, gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, const float *, int, + const float *, const float *, int, const float *, int, int *, + gesvdjInfo_t, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnSgesvdjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, lwork, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + const double *A, int lda, const double *S, const double *U, int ldu, + const double *V, int ldv, int *lwork, gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, const double *, int, + const double *, const double *, int, const double *, int, int *, + gesvdjInfo_t, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnDgesvdjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, lwork, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + const cuComplex *A, int lda, const float *S, const cuComplex *U, int ldu, + const cuComplex *V, int ldv, int *lwork, gesvdjInfo_t params, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, const cuComplex *, int, + const float *, const cuComplex *, int, const cuComplex *, int, int *, + gesvdjInfo_t, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnCgesvdjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, lwork, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdjBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + const cuDoubleComplex *A, int lda, const double *S, + const cuDoubleComplex *U, int ldu, const cuDoubleComplex *V, int ldv, + int *lwork, gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, const cuDoubleComplex *, + int, const double *, const cuDoubleComplex *, int, + const cuDoubleComplex *, int, int *, gesvdjInfo_t, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnZgesvdjBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, lwork, params, + batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, float *A, + int lda, float *S, float *U, int ldu, float *V, int ldv, float *work, + int lwork, int *info, gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, float *, int, float *, + float *, int, float *, int, float *, int, int *, gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSgesvdjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, + info, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, double *A, + int lda, double *S, double *U, int ldu, double *V, int ldv, double *work, + int lwork, int *info, gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, double *, int, double *, + double *, int, double *, int, double *, int, int *, gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDgesvdjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, + info, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + cuComplex *A, int lda, float *S, cuComplex *U, int ldu, cuComplex *V, + int ldv, cuComplex *work, int lwork, int *info, gesvdjInfo_t params, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, cuComplex *, int, + float *, cuComplex *, int, cuComplex *, int, cuComplex *, int, int *, + gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCgesvdjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, + info, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, + cuDoubleComplex *A, int lda, double *S, cuDoubleComplex *U, int ldu, + cuDoubleComplex *V, int ldv, cuDoubleComplex *work, int lwork, int *info, + gesvdjInfo_t params, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, cuDoubleComplex *, int, + double *, cuDoubleComplex *, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, int *, gesvdjInfo_t, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZgesvdjBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, + info, params, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + const float *A, int lda, const float *S, const float *U, int ldu, + const float *V, int ldv, int *lwork, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const float *, int, + const float *, const float *, int, const float *, int, int *, + gesvdjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSgesvdj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + const double *A, int lda, const double *S, const double *U, int ldu, + const double *V, int ldv, int *lwork, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const double *, int, + const double *, const double *, int, const double *, int, int *, + gesvdjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDgesvdj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + const cuComplex *A, int lda, const float *S, const cuComplex *U, int ldu, + const cuComplex *V, int ldv, int *lwork, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const cuComplex *, + int, const float *, const cuComplex *, int, const cuComplex *, int, int *, + gesvdjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCgesvdj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdj_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + const cuDoubleComplex *A, int lda, const double *S, + const cuDoubleComplex *U, int ldu, const cuDoubleComplex *V, int ldv, + int *lwork, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, + const cuDoubleComplex *, int, const double *, const cuDoubleComplex *, + int, const cuDoubleComplex *, int, int *, gesvdjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZgesvdj_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, lwork, + params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + float *A, int lda, float *S, float *U, int ldu, float *V, int ldv, + float *work, int lwork, int *info, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, float *, int, + float *, float *, int, float *, int, float *, int, int *, gesvdjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSgesvdj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work, + lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + double *A, int lda, double *S, double *U, int ldu, double *V, int ldv, + double *work, int lwork, int *info, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, double *, int, + double *, double *, int, double *, int, double *, int, int *, + gesvdjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDgesvdj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work, + lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + cuComplex *A, int lda, float *S, cuComplex *U, int ldu, cuComplex *V, + int ldv, cuComplex *work, int lwork, int *info, gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, cuComplex *, int, + float *, cuComplex *, int, cuComplex *, int, cuComplex *, int, int *, + gesvdjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCgesvdj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work, + lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, + cuDoubleComplex *A, int lda, double *S, cuDoubleComplex *U, int ldu, + cuDoubleComplex *V, int ldv, cuDoubleComplex *work, int lwork, int *info, + gesvdjInfo_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, cuDoubleComplex *, + int, double *, cuDoubleComplex *, int, cuDoubleComplex *, int, + cuDoubleComplex *, int, int *, gesvdjInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZgesvdj"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work, + lwork, info, params); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdaStridedBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const float *d_A, int lda, long long int strideA, const float *d_S, + long long int strideS, const float *d_U, int ldu, long long int strideU, + const float *d_V, int ldv, long long int strideV, int *lwork, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const float *, int, + long long, const float *, long long, const float *, int, long long, + const float *, int, long long, int *, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnSgesvdaStridedBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, lwork, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdaStridedBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const double *d_A, int lda, long long int strideA, const double *d_S, + long long int strideS, const double *d_U, int ldu, long long int strideU, + const double *d_V, int ldv, long long int strideV, int *lwork, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const double *, int, + long long, const double *, long long, const double *, int, long long, + const double *, int, long long, int *, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnDgesvdaStridedBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, lwork, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdaStridedBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const cuComplex *d_A, int lda, long long int strideA, const float *d_S, + long long int strideS, const cuComplex *d_U, int ldu, long long int strideU, + const cuComplex *d_V, int ldv, long long int strideV, int *lwork, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const cuComplex *, + int, long long, const float *, long long, const cuComplex *, int, + long long, const cuComplex *, int, long long, int *, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnCgesvdaStridedBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, lwork, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdaStridedBatched_bufferSize( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const cuDoubleComplex *d_A, int lda, long long int strideA, + const double *d_S, long long int strideS, const cuDoubleComplex *d_U, + int ldu, long long int strideU, const cuDoubleComplex *d_V, int ldv, + long long int strideV, int *lwork, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, + const cuDoubleComplex *, int, long long, const double *, long long, + const cuDoubleComplex *, int, long long, const cuDoubleComplex *, int, + long long, int *, int); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusolverDnZgesvdaStridedBatched_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, lwork, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSgesvdaStridedBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const float *d_A, int lda, long long int strideA, float *d_S, + long long int strideS, float *d_U, int ldu, long long int strideU, + float *d_V, int ldv, long long int strideV, float *d_work, int lwork, + int *d_info, double *h_R_nrmF, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const float *, int, + long long, float *, long long, float *, int, long long, float *, int, + long long, float *, int, int *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSgesvdaStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, d_work, lwork, d_info, + h_R_nrmF, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnDgesvdaStridedBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const double *d_A, int lda, long long int strideA, double *d_S, + long long int strideS, double *d_U, int ldu, long long int strideU, + double *d_V, int ldv, long long int strideV, double *d_work, int lwork, + int *d_info, double *h_R_nrmF, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const double *, int, + long long, double *, long long, double *, int, long long, double *, int, + long long, double *, int, int *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDgesvdaStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, d_work, lwork, d_info, + h_R_nrmF, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnCgesvdaStridedBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const cuComplex *d_A, int lda, long long int strideA, float *d_S, + long long int strideS, cuComplex *d_U, int ldu, long long int strideU, + cuComplex *d_V, int ldv, long long int strideV, cuComplex *d_work, + int lwork, int *d_info, double *h_R_nrmF, int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, const cuComplex *, + int, long long, float *, long long, cuComplex *, int, long long, + cuComplex *, int, long long, cuComplex *, int, int *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCgesvdaStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, d_work, lwork, d_info, + h_R_nrmF, batchSize); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnZgesvdaStridedBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, + const cuDoubleComplex *d_A, int lda, long long int strideA, double *d_S, + long long int strideS, cuDoubleComplex *d_U, int ldu, long long int strideU, + cuDoubleComplex *d_V, int ldv, long long int strideV, + cuDoubleComplex *d_work, int lwork, int *d_info, double *h_R_nrmF, + int batchSize) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverEigMode_t, int, int, int, + const cuDoubleComplex *, int, long long, double *, long long, + cuDoubleComplex *, int, long long, cuDoubleComplex *, int, long long, + cuDoubleComplex *, int, int *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnZgesvdaStridedBatched"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, jobz, rank, m, n, d_A, lda, strideA, d_S, strideS, + d_U, ldu, strideU, d_V, ldv, strideV, d_work, lwork, d_info, + h_R_nrmF, batchSize); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnCreateParams(cusolverDnParams_t *params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnParams_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnCreateParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnDestroyParams(cusolverDnParams_t params) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)(cusolverDnParams_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnDestroyParams"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnSetAdvOptions(cusolverDnParams_t params, + cusolverDnFunction_t function, cusolverAlgMode_t algo) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnParams_t, cusolverDnFunction_t, cusolverAlgMode_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSetAdvOptions"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(params, function, algo); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnPotrf_bufferSize( + cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, + int64_t n, cudaDataType dataTypeA, const void *A, int64_t lda, + cudaDataType computeType, size_t *workspaceInBytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnParams_t, cublasFillMode_t, int64_t, + cudaDataType, const void *, int64_t, cudaDataType, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnPotrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, params, uplo, n, dataTypeA, A, lda, computeType, + workspaceInBytes); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnPotrf(cusolverDnHandle_t handle, cusolverDnParams_t params, + cublasFillMode_t uplo, int64_t n, cudaDataType dataTypeA, + void *A, int64_t lda, cudaDataType computeType, void *pBuffer, + size_t workspaceInBytes, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnParams_t, cublasFillMode_t, int64_t, + cudaDataType, void *, int64_t, cudaDataType, void *, size_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnPotrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, params, uplo, n, dataTypeA, A, lda, computeType, + pBuffer, workspaceInBytes, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnPotrs( + cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, + int64_t n, int64_t nrhs, cudaDataType dataTypeA, const void *A, int64_t lda, + cudaDataType dataTypeB, void *B, int64_t ldb, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnParams_t, cublasFillMode_t, int64_t, + int64_t, cudaDataType, const void *, int64_t, cudaDataType, void *, + int64_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnPotrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, params, uplo, n, nrhs, dataTypeA, A, lda, dataTypeB, + B, ldb, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnGeqrf_bufferSize( + cusolverDnHandle_t handle, cusolverDnParams_t params, int64_t m, int64_t n, + cudaDataType dataTypeA, const void *A, int64_t lda, + cudaDataType dataTypeTau, const void *tau, cudaDataType computeType, + size_t *workspaceInBytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnParams_t, int64_t, int64_t, cudaDataType, + const void *, int64_t, cudaDataType, const void *, cudaDataType, + size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnGeqrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, params, m, n, dataTypeA, A, lda, dataTypeTau, tau, + computeType, workspaceInBytes); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnGeqrf(cusolverDnHandle_t handle, cusolverDnParams_t params, int64_t m, + int64_t n, cudaDataType dataTypeA, void *A, int64_t lda, + cudaDataType dataTypeTau, void *tau, cudaDataType computeType, + void *pBuffer, size_t workspaceInBytes, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnParams_t, int64_t, int64_t, cudaDataType, + void *, int64_t, cudaDataType, void *, cudaDataType, void *, size_t, + int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnGeqrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, params, m, n, dataTypeA, A, lda, dataTypeTau, tau, + computeType, pBuffer, workspaceInBytes, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnGetrf_bufferSize( + cusolverDnHandle_t handle, cusolverDnParams_t params, int64_t m, int64_t n, + cudaDataType dataTypeA, const void *A, int64_t lda, + cudaDataType computeType, size_t *workspaceInBytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnParams_t, int64_t, int64_t, cudaDataType, + const void *, int64_t, cudaDataType, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnGetrf_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, params, m, n, dataTypeA, A, lda, computeType, + workspaceInBytes); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnGetrf(cusolverDnHandle_t handle, cusolverDnParams_t params, int64_t m, + int64_t n, cudaDataType dataTypeA, void *A, int64_t lda, + int64_t *ipiv, cudaDataType computeType, void *pBuffer, + size_t workspaceInBytes, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnParams_t, int64_t, int64_t, cudaDataType, + void *, int64_t, int64_t *, cudaDataType, void *, size_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnGetrf"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, params, m, n, dataTypeA, A, lda, ipiv, computeType, + pBuffer, workspaceInBytes, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnGetrs( + cusolverDnHandle_t handle, cusolverDnParams_t params, + cublasOperation_t trans, int64_t n, int64_t nrhs, cudaDataType dataTypeA, + const void *A, int64_t lda, const int64_t *ipiv, cudaDataType dataTypeB, + void *B, int64_t ldb, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnParams_t, cublasOperation_t, int64_t, + int64_t, cudaDataType, const void *, int64_t, const int64_t *, + cudaDataType, void *, int64_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnGetrs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, params, trans, n, nrhs, dataTypeA, A, lda, ipiv, + dataTypeB, B, ldb, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSyevd_bufferSize( + cusolverDnHandle_t handle, cusolverDnParams_t params, + cusolverEigMode_t jobz, cublasFillMode_t uplo, int64_t n, + cudaDataType dataTypeA, const void *A, int64_t lda, cudaDataType dataTypeW, + const void *W, cudaDataType computeType, size_t *workspaceInBytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnParams_t, cusolverEigMode_t, + cublasFillMode_t, int64_t, cudaDataType, const void *, int64_t, + cudaDataType, const void *, cudaDataType, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSyevd_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, params, jobz, uplo, n, dataTypeA, A, lda, dataTypeW, + W, computeType, workspaceInBytes); +} + +cusolverStatus_t CUSOLVERAPI +cusolverDnSyevd(cusolverDnHandle_t handle, cusolverDnParams_t params, + cusolverEigMode_t jobz, cublasFillMode_t uplo, int64_t n, + cudaDataType dataTypeA, void *A, int64_t lda, + cudaDataType dataTypeW, void *W, cudaDataType computeType, + void *pBuffer, size_t workspaceInBytes, int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnParams_t, cusolverEigMode_t, + cublasFillMode_t, int64_t, cudaDataType, void *, int64_t, cudaDataType, + void *, cudaDataType, void *, size_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSyevd"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, params, jobz, uplo, n, dataTypeA, A, lda, dataTypeW, + W, computeType, pBuffer, workspaceInBytes, info); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSyevdx_bufferSize( + cusolverDnHandle_t handle, cusolverDnParams_t params, + cusolverEigMode_t jobz, cusolverEigRange_t range, cublasFillMode_t uplo, + int64_t n, cudaDataType dataTypeA, const void *A, int64_t lda, void *vl, + void *vu, int64_t il, int64_t iu, int64_t *h_meig, cudaDataType dataTypeW, + const void *W, cudaDataType computeType, size_t *workspaceInBytes) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnParams_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int64_t, cudaDataType, const void *, + int64_t, void *, void *, int64_t, int64_t, int64_t *, cudaDataType, + const void *, cudaDataType, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSyevdx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, params, jobz, range, uplo, n, dataTypeA, A, lda, vl, + vu, il, iu, h_meig, dataTypeW, W, computeType, + workspaceInBytes); +} + +cusolverStatus_t CUSOLVERAPI cusolverDnSyevdx( + cusolverDnHandle_t handle, cusolverDnParams_t params, + cusolverEigMode_t jobz, cusolverEigRange_t range, cublasFillMode_t uplo, + int64_t n, cudaDataType dataTypeA, void *A, int64_t lda, void *vl, void *vu, + int64_t il, int64_t iu, int64_t *meig64, cudaDataType dataTypeW, void *W, + cudaDataType computeType, void *pBuffer, size_t workspaceInBytes, + int *info) { + using FuncPtr = cusolverStatus_t(CUSOLVERAPI *)( + cusolverDnHandle_t, cusolverDnParams_t, cusolverEigMode_t, + cusolverEigRange_t, cublasFillMode_t, int64_t, cudaDataType, void *, + int64_t, void *, void *, int64_t, int64_t, int64_t *, cudaDataType, + void *, cudaDataType, void *, size_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusolverDnSyevdx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, params, jobz, range, uplo, n, dataTypeA, A, lda, vl, + vu, il, iu, meig64, dataTypeW, W, computeType, pBuffer, + workspaceInBytes, info); +} + +} // extern "C" diff --git a/tensorflow/stream_executor/cuda/cusparse_10_1.inc b/tensorflow/stream_executor/cuda/cusparse_10_1.inc index 3b7f3815829..e94aa081b8c 100644 --- a/tensorflow/stream_executor/cuda/cusparse_10_1.inc +++ b/tensorflow/stream_executor/cuda/cusparse_10_1.inc @@ -8225,6 +8225,6 @@ cusparseStatus_t CUSPARSEAPI cusparseConstrainedGeMM_bufferSize( bufferSize); } -#endif // _WIN32 +#endif // _WIN32 } // extern "C" diff --git a/tensorflow/stream_executor/cuda/cusparse_10_2.inc b/tensorflow/stream_executor/cuda/cusparse_10_2.inc index 3b7f3815829..e94aa081b8c 100644 --- a/tensorflow/stream_executor/cuda/cusparse_10_2.inc +++ b/tensorflow/stream_executor/cuda/cusparse_10_2.inc @@ -8225,6 +8225,6 @@ cusparseStatus_t CUSPARSEAPI cusparseConstrainedGeMM_bufferSize( bufferSize); } -#endif // _WIN32 +#endif // _WIN32 } // extern "C" diff --git a/tensorflow/stream_executor/cuda/cusparse_11_0.inc b/tensorflow/stream_executor/cuda/cusparse_11_0.inc new file mode 100644 index 00000000000..31eb65c24ec --- /dev/null +++ b/tensorflow/stream_executor/cuda/cusparse_11_0.inc @@ -0,0 +1,6584 @@ +// Auto-generated, do not edit. + +#define CUSPARSE_DEPRECATED(new_func) + +extern "C" { + +cusparseStatus_t CUSPARSEAPI cusparseCreate(cusparseHandle_t *handle) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreate"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroy(cusparseHandle_t handle) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle); +} + +cusparseStatus_t CUSPARSEAPI cusparseGetVersion(cusparseHandle_t handle, + int *version) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseGetVersion"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, version); +} + +cusparseStatus_t CUSPARSEAPI cusparseGetProperty(libraryPropertyType type, + int *value) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(libraryPropertyType, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseGetProperty"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(type, value); +} + +const char *CUSPARSEAPI cusparseGetErrorName(cusparseStatus_t status) { + using FuncPtr = const char *(CUSPARSEAPI *)(cusparseStatus_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseGetErrorName"); + if (!func_ptr) return "cusparseGetErrorName symbol not found."; + return func_ptr(status); +} + +const char *CUSPARSEAPI cusparseGetErrorString(cusparseStatus_t status) { + using FuncPtr = const char *(CUSPARSEAPI *)(cusparseStatus_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseGetErrorString"); + if (!func_ptr) return "cusparseGetErrorString symbol not found."; + return func_ptr(status); +} + +cusparseStatus_t CUSPARSEAPI cusparseSetStream(cusparseHandle_t handle, + cudaStream_t streamId) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, cudaStream_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSetStream"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, streamId); +} + +cusparseStatus_t CUSPARSEAPI cusparseGetStream(cusparseHandle_t handle, + cudaStream_t *streamId) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, cudaStream_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseGetStream"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, streamId); +} + +cusparseStatus_t CUSPARSEAPI +cusparseGetPointerMode(cusparseHandle_t handle, cusparsePointerMode_t *mode) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, + cusparsePointerMode_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseGetPointerMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSetPointerMode(cusparseHandle_t handle, cusparsePointerMode_t mode) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, cusparsePointerMode_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSetPointerMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mode); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCreateMatDescr(cusparseMatDescr_t *descrA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseMatDescr_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateMatDescr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(descrA); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDestroyMatDescr(cusparseMatDescr_t descrA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseMatDescr_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroyMatDescr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(descrA); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCopyMatDescr(cusparseMatDescr_t dest, const cusparseMatDescr_t src) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseMatDescr_t, + const cusparseMatDescr_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCopyMatDescr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dest, src); +} + +cusparseStatus_t CUSPARSEAPI cusparseSetMatType(cusparseMatDescr_t descrA, + cusparseMatrixType_t type) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseMatDescr_t, cusparseMatrixType_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSetMatType"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(descrA, type); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSetMatFillMode(cusparseMatDescr_t descrA, cusparseFillMode_t fillMode) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseMatDescr_t, cusparseFillMode_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSetMatFillMode"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(descrA, fillMode); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSetMatDiagType(cusparseMatDescr_t descrA, cusparseDiagType_t diagType) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseMatDescr_t, cusparseDiagType_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSetMatDiagType"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(descrA, diagType); +} + +cusparseStatus_t CUSPARSEAPI cusparseSetMatIndexBase(cusparseMatDescr_t descrA, + cusparseIndexBase_t base) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseMatDescr_t, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSetMatIndexBase"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(descrA, base); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCsrsv2Info(csrsv2Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrsv2Info_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateCsrsv2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyCsrsv2Info(csrsv2Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrsv2Info_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroyCsrsv2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCsric02Info(csric02Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csric02Info_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateCsric02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyCsric02Info(csric02Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csric02Info_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroyCsric02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateBsric02Info(bsric02Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(bsric02Info_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateBsric02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyBsric02Info(bsric02Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(bsric02Info_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroyBsric02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCsrilu02Info(csrilu02Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrilu02Info_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateCsrilu02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyCsrilu02Info(csrilu02Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrilu02Info_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroyCsrilu02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateBsrilu02Info(bsrilu02Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(bsrilu02Info_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateBsrilu02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyBsrilu02Info(bsrilu02Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(bsrilu02Info_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroyBsrilu02Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateBsrsv2Info(bsrsv2Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(bsrsv2Info_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateBsrsv2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyBsrsv2Info(bsrsv2Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(bsrsv2Info_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroyBsrsv2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateBsrsm2Info(bsrsm2Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(bsrsm2Info_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateBsrsm2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyBsrsm2Info(bsrsm2Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(bsrsm2Info_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroyBsrsm2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCsru2csrInfo(csru2csrInfo_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csru2csrInfo_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateCsru2csrInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyCsru2csrInfo(csru2csrInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csru2csrInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroyCsru2csrInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCreateColorInfo(cusparseColorInfo_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseColorInfo_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateColorInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDestroyColorInfo(cusparseColorInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseColorInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroyColorInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseSetColorAlgs(cusparseColorInfo_t info, + cusparseColorAlg_t alg) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseColorInfo_t, cusparseColorAlg_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSetColorAlgs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, alg); +} + +cusparseStatus_t CUSPARSEAPI cusparseGetColorAlgs(cusparseColorInfo_t info, + cusparseColorAlg_t *alg) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseColorInfo_t, + cusparseColorAlg_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseGetColorAlgs"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info, alg); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreatePruneInfo(pruneInfo_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(pruneInfo_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreatePruneInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyPruneInfo(pruneInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(pruneInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroyPruneInfo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseSaxpyi(cusparseHandle_t handle, int nnz, + const float *alpha, + const float *xVal, const int *xInd, + float *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const float *, const float *, const int *, float *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSaxpyi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, alpha, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseDaxpyi(cusparseHandle_t handle, int nnz, + const double *alpha, + const double *xVal, const int *xInd, + double *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const double *, const double *, const int *, + double *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDaxpyi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, alpha, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseCaxpyi(cusparseHandle_t handle, int nnz, + const cuComplex *alpha, + const cuComplex *xVal, + const int *xInd, cuComplex *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuComplex *, const cuComplex *, const int *, + cuComplex *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCaxpyi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, alpha, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseZaxpyi(cusparseHandle_t handle, int nnz, + const cuDoubleComplex *alpha, + const cuDoubleComplex *xVal, + const int *xInd, cuDoubleComplex *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuDoubleComplex *, const cuDoubleComplex *, + const int *, cuDoubleComplex *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZaxpyi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, alpha, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgthr(cusparseHandle_t handle, int nnz, + const float *y, float *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const float *, float *, const int *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgthr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgthr(cusparseHandle_t handle, int nnz, + const double *y, double *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const double *, double *, const int *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgthr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgthr(cusparseHandle_t handle, int nnz, + const cuComplex *y, cuComplex *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuComplex *, cuComplex *, const int *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgthr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgthr(cusparseHandle_t handle, int nnz, + const cuDoubleComplex *y, + cuDoubleComplex *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuDoubleComplex *, cuDoubleComplex *, + const int *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgthr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgthrz(cusparseHandle_t handle, int nnz, + float *y, float *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int, float *, float *, + const int *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgthrz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgthrz(cusparseHandle_t handle, int nnz, + double *y, double *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int, double *, double *, + const int *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgthrz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgthrz(cusparseHandle_t handle, int nnz, + cuComplex *y, cuComplex *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cuComplex *, cuComplex *, const int *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgthrz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgthrz(cusparseHandle_t handle, int nnz, + cuDoubleComplex *y, + cuDoubleComplex *xVal, + const int *xInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cuDoubleComplex *, cuDoubleComplex *, const int *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgthrz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, y, xVal, xInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseSsctr(cusparseHandle_t handle, int nnz, + const float *xVal, const int *xInd, + float *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int, + const float *, const int *, + float *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSsctr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseDsctr(cusparseHandle_t handle, int nnz, + const double *xVal, const int *xInd, + double *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const double *, const int *, double *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDsctr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsctr(cusparseHandle_t handle, int nnz, + const cuComplex *xVal, + const int *xInd, cuComplex *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuComplex *, const int *, cuComplex *, + cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCsctr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseZsctr(cusparseHandle_t handle, int nnz, + const cuDoubleComplex *xVal, + const int *xInd, cuDoubleComplex *y, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuDoubleComplex *, const int *, + cuDoubleComplex *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZsctr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseSroti(cusparseHandle_t handle, int nnz, + float *xVal, const int *xInd, + float *y, const float *c, + const float *s, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, float *, const int *, float *, const float *, + const float *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSroti"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, c, s, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseDroti(cusparseHandle_t handle, int nnz, + double *xVal, const int *xInd, + double *y, const double *c, + const double *s, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, double *, const int *, double *, const double *, + const double *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDroti"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, nnz, xVal, xInd, y, c, s, idxBase); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSgemvi(cusparseHandle_t handle, cusparseOperation_t transA, int m, + int n, const float *alpha, const float *A, int lda, int nnz, + const float *xVal, const int *xInd, const float *beta, float *y, + cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const float *, + const float *, int, int, const float *, const int *, const float *, + float *, cusparseIndexBase_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgemvi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, alpha, A, lda, nnz, xVal, xInd, beta, y, + idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSgemvi_bufferSize(cusparseHandle_t handle, cusparseOperation_t transA, + int m, int n, int nnz, int *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgemvi_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, nnz, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDgemvi(cusparseHandle_t handle, cusparseOperation_t transA, int m, + int n, const double *alpha, const double *A, int lda, int nnz, + const double *xVal, const int *xInd, const double *beta, + double *y, cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const double *, + const double *, int, int, const double *, const int *, const double *, + double *, cusparseIndexBase_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgemvi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, alpha, A, lda, nnz, xVal, xInd, beta, y, + idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDgemvi_bufferSize(cusparseHandle_t handle, cusparseOperation_t transA, + int m, int n, int nnz, int *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgemvi_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, nnz, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgemvi( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int n, + const cuComplex *alpha, const cuComplex *A, int lda, int nnz, + const cuComplex *xVal, const int *xInd, const cuComplex *beta, cuComplex *y, + cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cuComplex *, + const cuComplex *, int, int, const cuComplex *, const int *, + const cuComplex *, cuComplex *, cusparseIndexBase_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgemvi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, alpha, A, lda, nnz, xVal, xInd, beta, y, + idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCgemvi_bufferSize(cusparseHandle_t handle, cusparseOperation_t transA, + int m, int n, int nnz, int *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgemvi_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, nnz, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgemvi( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int n, + const cuDoubleComplex *alpha, const cuDoubleComplex *A, int lda, int nnz, + const cuDoubleComplex *xVal, const int *xInd, const cuDoubleComplex *beta, + cuDoubleComplex *y, cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, int, const cuDoubleComplex *, const int *, + const cuDoubleComplex *, cuDoubleComplex *, cusparseIndexBase_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgemvi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, alpha, A, lda, nnz, xVal, xInd, beta, y, + idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI +cusparseZgemvi_bufferSize(cusparseHandle_t handle, cusparseOperation_t transA, + int m, int n, int nnz, int *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgemvi_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, n, nnz, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsrmvEx_bufferSize( + cusparseHandle_t handle, cusparseAlgMode_t alg, cusparseOperation_t transA, + int m, int n, int nnz, const void *alpha, cudaDataType alphatype, + const cusparseMatDescr_t descrA, const void *csrValA, + cudaDataType csrValAtype, const int *csrRowPtrA, const int *csrColIndA, + const void *x, cudaDataType xtype, const void *beta, cudaDataType betatype, + void *y, cudaDataType ytype, cudaDataType executiontype, + size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseAlgMode_t, cusparseOperation_t, int, int, int, + const void *, cudaDataType, const cusparseMatDescr_t, const void *, + cudaDataType, const int *, const int *, const void *, cudaDataType, + const void *, cudaDataType, void *, cudaDataType, cudaDataType, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCsrmvEx_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, alg, transA, m, n, nnz, alpha, alphatype, descrA, + csrValA, csrValAtype, csrRowPtrA, csrColIndA, x, xtype, beta, + betatype, y, ytype, executiontype, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsrmvEx( + cusparseHandle_t handle, cusparseAlgMode_t alg, cusparseOperation_t transA, + int m, int n, int nnz, const void *alpha, cudaDataType alphatype, + const cusparseMatDescr_t descrA, const void *csrValA, + cudaDataType csrValAtype, const int *csrRowPtrA, const int *csrColIndA, + const void *x, cudaDataType xtype, const void *beta, cudaDataType betatype, + void *y, cudaDataType ytype, cudaDataType executiontype, void *buffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseAlgMode_t, cusparseOperation_t, int, int, int, + const void *, cudaDataType, const cusparseMatDescr_t, const void *, + cudaDataType, const int *, const int *, const void *, cudaDataType, + const void *, cudaDataType, void *, cudaDataType, cudaDataType, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCsrmvEx"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, alg, transA, m, n, nnz, alpha, alphatype, descrA, + csrValA, csrValAtype, csrRowPtrA, csrColIndA, x, xtype, beta, + betatype, y, ytype, executiontype, buffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrmv( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nb, int nnzb, const float *alpha, + const cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + const float *x, const float *beta, float *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + const float *, const cusparseMatDescr_t, const float *, const int *, + const int *, int, const float *, const float *, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockDim, + x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrmv( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nb, int nnzb, const double *alpha, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + const double *x, const double *beta, double *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + const double *, const cusparseMatDescr_t, const double *, const int *, + const int *, int, const double *, const double *, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockDim, + x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCbsrmv(cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nb, int nnzb, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *bsrSortedValA, const int *bsrSortedRowPtrA, + const int *bsrSortedColIndA, int blockDim, const cuComplex *x, + const cuComplex *beta, cuComplex *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + const cuComplex *, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, int, const cuComplex *, const cuComplex *, + cuComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockDim, + x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrmv( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nb, int nnzb, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *bsrSortedValA, const int *bsrSortedRowPtrA, + const int *bsrSortedColIndA, int blockDim, const cuDoubleComplex *x, + const cuDoubleComplex *beta, cuDoubleComplex *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, + const cuDoubleComplex *, const cuDoubleComplex *, cuDoubleComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockDim, + x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSbsrxmv(cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int sizeOfMask, int mb, int nb, + int nnzb, const float *alpha, const cusparseMatDescr_t descrA, + const float *bsrSortedValA, const int *bsrSortedMaskPtrA, + const int *bsrSortedRowPtrA, const int *bsrSortedEndPtrA, + const int *bsrSortedColIndA, int blockDim, const float *x, + const float *beta, float *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + int, const float *, const cusparseMatDescr_t, const float *, const int *, + const int *, const int *, const int *, int, const float *, const float *, + float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrxmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, sizeOfMask, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedMaskPtrA, bsrSortedRowPtrA, + bsrSortedEndPtrA, bsrSortedColIndA, blockDim, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDbsrxmv(cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int sizeOfMask, int mb, int nb, + int nnzb, const double *alpha, const cusparseMatDescr_t descrA, + const double *bsrSortedValA, const int *bsrSortedMaskPtrA, + const int *bsrSortedRowPtrA, const int *bsrSortedEndPtrA, + const int *bsrSortedColIndA, int blockDim, const double *x, + const double *beta, double *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + int, const double *, const cusparseMatDescr_t, const double *, + const int *, const int *, const int *, const int *, int, const double *, + const double *, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrxmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, sizeOfMask, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedMaskPtrA, bsrSortedRowPtrA, + bsrSortedEndPtrA, bsrSortedColIndA, blockDim, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrxmv( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int sizeOfMask, int mb, int nb, int nnzb, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *bsrSortedValA, const int *bsrSortedMaskPtrA, + const int *bsrSortedRowPtrA, const int *bsrSortedEndPtrA, + const int *bsrSortedColIndA, int blockDim, const cuComplex *x, + const cuComplex *beta, cuComplex *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + int, const cuComplex *, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, const int *, const int *, int, + const cuComplex *, const cuComplex *, cuComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrxmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, sizeOfMask, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedMaskPtrA, bsrSortedRowPtrA, + bsrSortedEndPtrA, bsrSortedColIndA, blockDim, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrxmv( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int sizeOfMask, int mb, int nb, int nnzb, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *bsrSortedValA, const int *bsrSortedMaskPtrA, + const int *bsrSortedRowPtrA, const int *bsrSortedEndPtrA, + const int *bsrSortedColIndA, int blockDim, const cuDoubleComplex *x, + const cuDoubleComplex *beta, cuDoubleComplex *y) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, int, + int, const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, const int *, + const int *, int, const cuDoubleComplex *, const cuDoubleComplex *, + cuDoubleComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrxmv"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, sizeOfMask, mb, nb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedMaskPtrA, bsrSortedRowPtrA, + bsrSortedEndPtrA, bsrSortedColIndA, blockDim, x, beta, y); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsrsv2_zeroPivot(cusparseHandle_t handle, + csrsv2Info_t info, + int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, csrsv2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcsrsv2_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsv2_bufferSize( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, csrsv2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsv2_bufferSize( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, csrsv2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsv2_bufferSize( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, csrsv2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsv2_bufferSize( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, csrsv2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, csrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, csrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, csrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, csrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsv2_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, csrsv2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsv2_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, csrsv2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsv2_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, csrsv2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsv2_analysis( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cusparseMatDescr_t descrA, const cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, csrsv2Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, csrsv2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsv2_solve( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const float *alpha, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrsv2Info_t info, const float *f, float *x, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const float *, + const cusparseMatDescr_t, const float *, const int *, const int *, + csrsv2Info_t, const float *, float *, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, f, x, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsv2_solve( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const double *alpha, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrsv2Info_t info, const double *f, double *x, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const double *, + const cusparseMatDescr_t, const double *, const int *, const int *, + csrsv2Info_t, const double *, double *, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, f, x, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsv2_solve( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrsv2Info_t info, const cuComplex *f, + cuComplex *x, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cuComplex *, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + csrsv2Info_t, const cuComplex *, cuComplex *, cusparseSolvePolicy_t, + void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, f, x, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsv2_solve( + cusparseHandle_t handle, cusparseOperation_t transA, int m, int nnz, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrsv2Info_t info, const cuDoubleComplex *f, + cuDoubleComplex *x, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, int, int, const cuDoubleComplex *, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, csrsv2Info_t, const cuDoubleComplex *, cuDoubleComplex *, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, transA, m, nnz, alpha, descrA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, info, f, x, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXbsrsv2_zeroPivot(cusparseHandle_t handle, + bsrsv2Info_t info, + int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, bsrsv2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXbsrsv2_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsv2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, float *, const int *, const int *, int, + bsrsv2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsv2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, double *, const int *, const int *, int, + bsrsv2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrsv2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, cuComplex *, const int *, const int *, int, + bsrsv2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsv2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, cuDoubleComplex *, const int *, const int *, + int, bsrsv2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrsv2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockSize, + bsrsv2Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, float *, const int *, const int *, int, + bsrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockSize, info, + pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockSize, + bsrsv2Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, double *, const int *, const int *, int, + bsrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockSize, info, + pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockSize, + bsrsv2Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, cuComplex *, const int *, const int *, int, + bsrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockSize, info, + pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsv2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockSize, + bsrsv2Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, cuDoubleComplex *, const int *, const int *, + int, bsrsv2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockSize, info, + pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsv2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, const float *, const int *, const int *, int, + bsrsv2Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsv2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, const double *, const int *, const int *, int, + bsrsv2Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrsv2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + int, bsrsv2Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsv2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, int, bsrsv2Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrsv2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, policy, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsv2_solve( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, const float *alpha, + const cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, const float *f, float *x, cusparseSolvePolicy_t policy, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const float *, const cusparseMatDescr_t, const float *, const int *, + const int *, int, bsrsv2Info_t, const float *, float *, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, alpha, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, f, x, + policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsv2_solve( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, const double *alpha, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, const double *f, double *x, cusparseSolvePolicy_t policy, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const double *, const cusparseMatDescr_t, const double *, const int *, + const int *, int, bsrsv2Info_t, const double *, double *, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, alpha, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, f, x, + policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrsv2_solve( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, const cuComplex *alpha, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, const cuComplex *f, cuComplex *x, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cuComplex *, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, int, bsrsv2Info_t, const cuComplex *, + cuComplex *, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, alpha, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, f, x, + policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsv2_solve( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, int mb, int nnzb, const cuDoubleComplex *alpha, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + bsrsv2Info_t info, const cuDoubleComplex *f, cuDoubleComplex *x, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, int, int, + const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, bsrsv2Info_t, + const cuDoubleComplex *, cuDoubleComplex *, cusparseSolvePolicy_t, + void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrsv2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, mb, nnzb, alpha, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, blockDim, info, f, x, + policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrmm( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int kb, int nnzb, const float *alpha, const cusparseMatDescr_t descrA, + const float *bsrSortedValA, const int *bsrSortedRowPtrA, + const int *bsrSortedColIndA, const int blockSize, const float *B, + const int ldb, const float *beta, float *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, int, const float *, + const cusparseMatDescr_t, const float *, const int *, const int *, + const int, const float *, const int, const float *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, kb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockSize, + B, ldb, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrmm( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int kb, int nnzb, const double *alpha, const cusparseMatDescr_t descrA, + const double *bsrSortedValA, const int *bsrSortedRowPtrA, + const int *bsrSortedColIndA, const int blockSize, const double *B, + const int ldb, const double *beta, double *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, int, const double *, + const cusparseMatDescr_t, const double *, const int *, const int *, + const int, const double *, const int, const double *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, kb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockSize, + B, ldb, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrmm( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int kb, int nnzb, const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *bsrSortedValA, const int *bsrSortedRowPtrA, + const int *bsrSortedColIndA, const int blockSize, const cuComplex *B, + const int ldb, const cuComplex *beta, cuComplex *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, int, const cuComplex *, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + const int, const cuComplex *, const int, const cuComplex *, cuComplex *, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, kb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockSize, + B, ldb, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrmm( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int kb, int nnzb, const cuDoubleComplex *alpha, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, + const int blockSize, const cuDoubleComplex *B, const int ldb, + const cuDoubleComplex *beta, cuDoubleComplex *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, int, const cuDoubleComplex *, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, const int, const cuDoubleComplex *, const int, + const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrmm"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, kb, nnzb, alpha, descrA, + bsrSortedValA, bsrSortedRowPtrA, bsrSortedColIndA, blockSize, + B, ldb, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgemmi( + cusparseHandle_t handle, int m, int n, int k, int nnz, const float *alpha, + const float *A, int lda, const float *cscValB, const int *cscColPtrB, + const int *cscRowIndB, const float *beta, float *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, int, const float *, const float *, int, + const float *, const int *, const int *, const float *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgemmi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, + cscRowIndB, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgemmi( + cusparseHandle_t handle, int m, int n, int k, int nnz, const double *alpha, + const double *A, int lda, const double *cscValB, const int *cscColPtrB, + const int *cscRowIndB, const double *beta, double *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, int, const double *, const double *, int, + const double *, const int *, const int *, const double *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgemmi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, + cscRowIndB, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgemmi( + cusparseHandle_t handle, int m, int n, int k, int nnz, + const cuComplex *alpha, const cuComplex *A, int lda, + const cuComplex *cscValB, const int *cscColPtrB, const int *cscRowIndB, + const cuComplex *beta, cuComplex *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, int, const cuComplex *, + const cuComplex *, int, const cuComplex *, const int *, const int *, + const cuComplex *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgemmi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, + cscRowIndB, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI +cusparseZgemmi(cusparseHandle_t handle, int m, int n, int k, int nnz, + const cuDoubleComplex *alpha, const cuDoubleComplex *A, int lda, + const cuDoubleComplex *cscValB, const int *cscColPtrB, + const int *cscRowIndB, const cuDoubleComplex *beta, + cuDoubleComplex *C, int ldc) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, int, const cuDoubleComplex *, const int *, + const int *, const cuDoubleComplex *, cuDoubleComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgemmi"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, nnz, alpha, A, lda, cscValB, cscColPtrB, + cscRowIndB, beta, C, ldc); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCsrsm2Info(csrsm2Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrsm2Info_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateCsrsm2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDestroyCsrsm2Info(csrsm2Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrsm2Info_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroyCsrsm2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsrsm2_zeroPivot(cusparseHandle_t handle, + csrsm2Info_t info, + int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, csrsm2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcsrsm2_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsm2_bufferSizeExt( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, const float *alpha, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, const float *B, + int ldb, csrsm2Info_t info, cusparseSolvePolicy_t policy, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const float *, const cusparseMatDescr_t, const float *, const int *, + const int *, const float *, int, csrsm2Info_t, cusparseSolvePolicy_t, + size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrsm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsm2_bufferSizeExt( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, const double *alpha, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, const double *B, + int ldb, csrsm2Info_t info, cusparseSolvePolicy_t policy, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const double *, const cusparseMatDescr_t, const double *, + const int *, const int *, const double *, int, csrsm2Info_t, + cusparseSolvePolicy_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrsm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsm2_bufferSizeExt( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuComplex *B, int ldb, csrsm2Info_t info, + cusparseSolvePolicy_t policy, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const cuComplex *, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, const cuComplex *, int, csrsm2Info_t, + cusparseSolvePolicy_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrsm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsm2_bufferSizeExt( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuDoubleComplex *B, int ldb, + csrsm2Info_t info, cusparseSolvePolicy_t policy, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, + const cuDoubleComplex *, int, csrsm2Info_t, cusparseSolvePolicy_t, + size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrsm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsm2_analysis( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, const float *alpha, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, const float *B, + int ldb, csrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const float *, const cusparseMatDescr_t, const float *, const int *, + const int *, const float *, int, csrsm2Info_t, cusparseSolvePolicy_t, + void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrsm2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsm2_analysis( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, const double *alpha, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, const double *B, + int ldb, csrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const double *, const cusparseMatDescr_t, const double *, + const int *, const int *, const double *, int, csrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrsm2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsm2_analysis( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuComplex *B, int ldb, csrsm2Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const cuComplex *, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, const cuComplex *, int, csrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrsm2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsm2_analysis( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuDoubleComplex *B, int ldb, + csrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, + const cuDoubleComplex *, int, csrsm2Info_t, cusparseSolvePolicy_t, + void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrsm2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrsm2_solve( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, const float *alpha, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, float *B, int ldb, + csrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const float *, const cusparseMatDescr_t, const float *, const int *, + const int *, float *, int, csrsm2Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrsm2_solve( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, const double *alpha, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, double *B, + int ldb, csrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const double *, const cusparseMatDescr_t, const double *, + const int *, const int *, double *, int, csrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrsm2_solve( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, + const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cuComplex *B, int ldb, csrsm2Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const cuComplex *, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, cuComplex *, int, csrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrsm2_solve( + cusparseHandle_t handle, int algo, cusparseOperation_t transA, + cusparseOperation_t transB, int m, int nrhs, int nnz, + const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cuDoubleComplex *B, int ldb, csrsm2Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, cusparseOperation_t, cusparseOperation_t, int, int, + int, const cuDoubleComplex *, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, cuDoubleComplex *, int, + csrsm2Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, transA, transB, m, nrhs, nnz, alpha, descrA, + csrSortedValA, csrSortedRowPtrA, csrSortedColIndA, B, ldb, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXbsrsm2_zeroPivot(cusparseHandle_t handle, + bsrsm2Info_t info, + int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, bsrsm2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXbsrsm2_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsm2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, int, bsrsm2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrsm2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsm2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, int, bsrsm2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrsm2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrsm2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, int, bsrsm2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrsm2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsm2_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsrsm2Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrsm2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsm2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, int, bsrsm2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrsm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsm2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, int, bsrsm2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrsm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrsm2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, int, bsrsm2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrsm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsm2_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transB, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsrsm2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrsm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transB, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsm2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, const float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, int, bsrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrsm2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsm2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, const double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, int, bsrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrsm2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrsm2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, const cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrsm2Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, int, bsrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrsm2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsm2_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cusparseMatDescr_t descrA, + const cuDoubleComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int blockSize, bsrsm2Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, bsrsm2Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrsm2_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrsm2_solve( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const float *alpha, const cusparseMatDescr_t descrA, + const float *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int blockSize, bsrsm2Info_t info, + const float *B, int ldb, float *X, int ldx, cusparseSolvePolicy_t policy, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const float *, + const cusparseMatDescr_t, const float *, const int *, const int *, int, + bsrsm2Info_t, const float *, int, float *, int, cusparseSolvePolicy_t, + void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, alpha, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, B, ldb, X, ldx, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrsm2_solve( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const double *alpha, const cusparseMatDescr_t descrA, + const double *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int blockSize, bsrsm2Info_t info, + const double *B, int ldb, double *X, int ldx, cusparseSolvePolicy_t policy, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const double *, + const cusparseMatDescr_t, const double *, const int *, const int *, int, + bsrsm2Info_t, const double *, int, double *, int, cusparseSolvePolicy_t, + void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, alpha, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, B, ldb, X, ldx, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrsm2_solve( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cuComplex *alpha, const cusparseMatDescr_t descrA, + const cuComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int blockSize, bsrsm2Info_t info, + const cuComplex *B, int ldb, cuComplex *X, int ldx, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cuComplex *, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + int, bsrsm2Info_t, const cuComplex *, int, cuComplex *, int, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, alpha, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, B, ldb, X, ldx, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrsm2_solve( + cusparseHandle_t handle, cusparseDirection_t dirA, + cusparseOperation_t transA, cusparseOperation_t transXY, int mb, int n, + int nnzb, const cuDoubleComplex *alpha, const cusparseMatDescr_t descrA, + const cuDoubleComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int blockSize, bsrsm2Info_t info, + const cuDoubleComplex *B, int ldb, cuDoubleComplex *X, int ldx, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, cusparseOperation_t, + cusparseOperation_t, int, int, int, const cuDoubleComplex *, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, int, bsrsm2Info_t, const cuDoubleComplex *, int, + cuDoubleComplex *, int, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrsm2_solve"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, transA, transXY, mb, n, nnzb, alpha, descrA, + bsrSortedVal, bsrSortedRowPtr, bsrSortedColInd, blockSize, + info, B, ldb, X, ldx, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrilu02_numericBoost( + cusparseHandle_t handle, csrilu02Info_t info, int enable_boost, double *tol, + float *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, csrilu02Info_t, int, double *, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrilu02_numericBoost( + cusparseHandle_t handle, csrilu02Info_t info, int enable_boost, double *tol, + double *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, csrilu02Info_t, int, double *, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrilu02_numericBoost( + cusparseHandle_t handle, csrilu02Info_t info, int enable_boost, double *tol, + cuComplex *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, csrilu02Info_t, int, double *, cuComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrilu02_numericBoost( + cusparseHandle_t handle, csrilu02Info_t info, int enable_boost, double *tol, + cuDoubleComplex *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, csrilu02Info_t, int, double *, cuDoubleComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsrilu02_zeroPivot( + cusparseHandle_t handle, csrilu02Info_t info, int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, csrilu02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcsrilu02_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrilu02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, csrilu02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrilu02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, csrilu02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrilu02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, csrilu02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrilu02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuDoubleComplex *, + const int *, const int *, csrilu02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrilu02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + float *csrSortedVal, const int *csrSortedRowPtr, const int *csrSortedColInd, + csrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, csrilu02Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrilu02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + double *csrSortedVal, const int *csrSortedRowPtr, + const int *csrSortedColInd, csrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, csrilu02Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrilu02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuComplex *csrSortedVal, const int *csrSortedRowPtr, + const int *csrSortedColInd, csrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, csrilu02Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrilu02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuDoubleComplex *csrSortedVal, const int *csrSortedRowPtr, + const int *csrSortedColInd, csrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuDoubleComplex *, + const int *, const int *, csrilu02Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrilu02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, csrilu02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrilu02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrilu02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, csrilu02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrilu02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrilu02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, csrilu02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrilu02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrilu02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, csrilu02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrilu02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrilu02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + float *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, csrilu02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrilu02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrilu02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + double *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, csrilu02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrilu02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrilu02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuComplex *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, csrilu02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrilu02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrilu02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuDoubleComplex *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csrilu02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuDoubleComplex *, + const int *, const int *, csrilu02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrilu02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrilu02_numericBoost( + cusparseHandle_t handle, bsrilu02Info_t info, int enable_boost, double *tol, + float *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, bsrilu02Info_t, int, double *, float *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrilu02_numericBoost( + cusparseHandle_t handle, bsrilu02Info_t info, int enable_boost, double *tol, + double *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, bsrilu02Info_t, int, double *, double *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrilu02_numericBoost( + cusparseHandle_t handle, bsrilu02Info_t info, int enable_boost, double *tol, + cuComplex *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, bsrilu02Info_t, int, double *, cuComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrilu02_numericBoost( + cusparseHandle_t handle, bsrilu02Info_t info, int enable_boost, double *tol, + cuDoubleComplex *boost_val) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, bsrilu02Info_t, int, double *, cuDoubleComplex *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrilu02_numericBoost"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, enable_boost, tol, boost_val); +} + +cusparseStatus_t CUSPARSEAPI cusparseXbsrilu02_zeroPivot( + cusparseHandle_t handle, bsrilu02Info_t info, int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, bsrilu02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXbsrilu02_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrilu02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, int, bsrilu02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrilu02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, int, bsrilu02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrilu02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, int, bsrilu02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrilu02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsrilu02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrilu02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrilu02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, int, bsrilu02Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrilu02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, int, bsrilu02Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrilu02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, int, bsrilu02Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrilu02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsrilu02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsrilu02Info_t, + size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrilu02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrilu02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, int, bsrilu02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrilu02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrilu02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, int, bsrilu02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrilu02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrilu02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, int, bsrilu02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrilu02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrilu02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsrilu02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrilu02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsrilu02( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, int, bsrilu02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsrilu02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsrilu02( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, int, bsrilu02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsrilu02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsrilu02( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, int, bsrilu02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsrilu02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsrilu02( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsrilu02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsrilu02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsrilu02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsric02_zeroPivot(cusparseHandle_t handle, + csric02Info_t info, + int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, csric02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcsric02_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsric02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, csric02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsric02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsric02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, csric02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsric02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsric02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, csric02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsric02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsric02_bufferSize( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuDoubleComplex *, + const int *, const int *, csric02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsric02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsric02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + float *csrSortedVal, const int *csrSortedRowPtr, const int *csrSortedColInd, + csric02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, csric02Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsric02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + double *csrSortedVal, const int *csrSortedRowPtr, + const int *csrSortedColInd, csric02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, csric02Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsric02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuComplex *csrSortedVal, const int *csrSortedRowPtr, + const int *csrSortedColInd, csric02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, csric02Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsric02_bufferSizeExt( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuDoubleComplex *csrSortedVal, const int *csrSortedRowPtr, + const int *csrSortedColInd, csric02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuDoubleComplex *, + const int *, const int *, csric02Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedVal, csrSortedRowPtr, + csrSortedColInd, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsric02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, csric02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsric02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, csric02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsric02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, csric02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsric02_analysis( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, csric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsric02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + float *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, float *, + const int *, const int *, csric02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsric02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + double *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, double *, + const int *, const int *, csric02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsric02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuComplex *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, const int *, csric02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsric02( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + cuDoubleComplex *csrSortedValA_valM, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, csric02Info_t info, + cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, cuDoubleComplex *, + const int *, const int *, csric02Info_t, cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA_valM, csrSortedRowPtrA, + csrSortedColIndA, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXbsric02_zeroPivot(cusparseHandle_t handle, + bsric02Info_t info, + int *position) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, bsric02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXbsric02_zeroPivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, info, position); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsric02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, int, bsric02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsric02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsric02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, int, bsric02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsric02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsric02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, int, bsric02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsric02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsric02_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsric02Info_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsric02_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsric02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsric02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, int, bsric02Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsric02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsric02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, int, bsric02Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsric02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsric02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, int, bsric02Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsric02_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockSize, + bsric02Info_t info, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsric02Info_t, + size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsric02_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockSize, info, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsric02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pInputBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pInputBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsric02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pInputBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pInputBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsric02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pInputBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pInputBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsric02_analysis( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pInputBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsric02_analysis"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pInputBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsric02( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, float *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + float *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsric02( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, double *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + double *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsric02( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuComplex *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsric02( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nnzb, + const cusparseMatDescr_t descrA, cuDoubleComplex *bsrSortedVal, + const int *bsrSortedRowPtr, const int *bsrSortedColInd, int blockDim, + bsric02Info_t info, cusparseSolvePolicy_t policy, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, const int *, int, bsric02Info_t, + cusparseSolvePolicy_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsric02"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nnzb, descrA, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, blockDim, info, policy, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsv2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const float *dl, const float *d, + const float *du, const float *B, int ldb, size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const float *, const float *, + const float *, int, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgtsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsv2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const double *dl, const double *d, + const double *du, const double *B, int ldb, size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const double *, + const double *, const double *, int, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgtsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsv2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const cuComplex *dl, + const cuComplex *d, const cuComplex *du, const cuComplex *B, int ldb, + size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cuComplex *, + const cuComplex *, const cuComplex *, int, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgtsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsv2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const cuDoubleComplex *dl, + const cuDoubleComplex *d, const cuDoubleComplex *du, + const cuDoubleComplex *B, int ldb, size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, const cuDoubleComplex *, + int, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgtsv2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsv2(cusparseHandle_t handle, int m, + int n, const float *dl, + const float *d, const float *du, + float *B, int ldb, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const float *, const float *, + float *, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgtsv2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsv2(cusparseHandle_t handle, int m, + int n, const double *dl, + const double *d, const double *du, + double *B, int ldb, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const double *, + const double *, double *, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgtsv2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsv2(cusparseHandle_t handle, int m, + int n, const cuComplex *dl, + const cuComplex *d, + const cuComplex *du, cuComplex *B, + int ldb, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cuComplex *, + const cuComplex *, cuComplex *, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgtsv2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsv2(cusparseHandle_t handle, int m, + int n, const cuDoubleComplex *dl, + const cuDoubleComplex *d, + const cuDoubleComplex *du, + cuDoubleComplex *B, int ldb, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, cuDoubleComplex *, int, + void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgtsv2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsv2_nopivot_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const float *dl, const float *d, + const float *du, const float *B, int ldb, size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const float *, const float *, + const float *, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseSgtsv2_nopivot_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsv2_nopivot_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const double *dl, const double *d, + const double *du, const double *B, int ldb, size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const double *, + const double *, const double *, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseDgtsv2_nopivot_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsv2_nopivot_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const cuComplex *dl, + const cuComplex *d, const cuComplex *du, const cuComplex *B, int ldb, + size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cuComplex *, + const cuComplex *, const cuComplex *, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseCgtsv2_nopivot_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsv2_nopivot_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const cuDoubleComplex *dl, + const cuDoubleComplex *d, const cuDoubleComplex *du, + const cuDoubleComplex *B, int ldb, size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, const cuDoubleComplex *, + int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseZgtsv2_nopivot_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsv2_nopivot( + cusparseHandle_t handle, int m, int n, const float *dl, const float *d, + const float *du, float *B, int ldb, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const float *, const float *, + float *, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgtsv2_nopivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsv2_nopivot( + cusparseHandle_t handle, int m, int n, const double *dl, const double *d, + const double *du, double *B, int ldb, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const double *, + const double *, double *, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgtsv2_nopivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsv2_nopivot( + cusparseHandle_t handle, int m, int n, const cuComplex *dl, + const cuComplex *d, const cuComplex *du, cuComplex *B, int ldb, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cuComplex *, + const cuComplex *, cuComplex *, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgtsv2_nopivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsv2_nopivot( + cusparseHandle_t handle, int m, int n, const cuDoubleComplex *dl, + const cuDoubleComplex *d, const cuDoubleComplex *du, cuDoubleComplex *B, + int ldb, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, cuDoubleComplex *, int, + void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgtsv2_nopivot"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, dl, d, du, B, ldb, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsv2StridedBatch_bufferSizeExt( + cusparseHandle_t handle, int m, const float *dl, const float *d, + const float *du, const float *x, int batchCount, int batchStride, + size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const float *, const float *, const float *, + const float *, int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseSgtsv2StridedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, + bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsv2StridedBatch_bufferSizeExt( + cusparseHandle_t handle, int m, const double *dl, const double *d, + const double *du, const double *x, int batchCount, int batchStride, + size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const double *, const double *, const double *, + const double *, int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseDgtsv2StridedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, + bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsv2StridedBatch_bufferSizeExt( + cusparseHandle_t handle, int m, const cuComplex *dl, const cuComplex *d, + const cuComplex *du, const cuComplex *x, int batchCount, int batchStride, + size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuComplex *, const cuComplex *, + const cuComplex *, const cuComplex *, int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseCgtsv2StridedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, + bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsv2StridedBatch_bufferSizeExt( + cusparseHandle_t handle, int m, const cuDoubleComplex *dl, + const cuDoubleComplex *d, const cuDoubleComplex *du, + const cuDoubleComplex *x, int batchCount, int batchStride, + size_t *bufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuDoubleComplex *, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseZgtsv2StridedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, + bufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsv2StridedBatch( + cusparseHandle_t handle, int m, const float *dl, const float *d, + const float *du, float *x, int batchCount, int batchStride, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const float *, const float *, const float *, + float *, int, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgtsv2StridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDgtsv2StridedBatch(cusparseHandle_t handle, int m, const double *dl, + const double *d, const double *du, double *x, + int batchCount, int batchStride, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const double *, const double *, const double *, + double *, int, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgtsv2StridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsv2StridedBatch( + cusparseHandle_t handle, int m, const cuComplex *dl, const cuComplex *d, + const cuComplex *du, cuComplex *x, int batchCount, int batchStride, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuComplex *, const cuComplex *, + const cuComplex *, cuComplex *, int, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgtsv2StridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsv2StridedBatch( + cusparseHandle_t handle, int m, const cuDoubleComplex *dl, + const cuDoubleComplex *d, const cuDoubleComplex *du, cuDoubleComplex *x, + int batchCount, int batchStride, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cuDoubleComplex *, const cuDoubleComplex *, + const cuDoubleComplex *, cuDoubleComplex *, int, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgtsv2StridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, dl, d, du, x, batchCount, batchStride, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const float *dl, const float *d, + const float *du, const float *x, int batchCount, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const float *, const float *, + const float *, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseSgtsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const double *dl, const double *d, + const double *du, const double *x, int batchCount, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const double *, + const double *, const double *, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseDgtsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const cuComplex *dl, + const cuComplex *d, const cuComplex *du, const cuComplex *x, int batchCount, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cuComplex *, + const cuComplex *, const cuComplex *, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseCgtsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const cuDoubleComplex *dl, + const cuDoubleComplex *d, const cuDoubleComplex *du, + const cuDoubleComplex *x, int batchCount, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, const cuDoubleComplex *, + int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseZgtsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgtsvInterleavedBatch( + cusparseHandle_t handle, int algo, int m, float *dl, float *d, float *du, + float *x, int batchCount, void *pBuffer) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int, int, float *, + float *, float *, float *, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgtsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgtsvInterleavedBatch( + cusparseHandle_t handle, int algo, int m, double *dl, double *d, double *du, + double *x, int batchCount, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int, int, + double *, double *, double *, + double *, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgtsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgtsvInterleavedBatch( + cusparseHandle_t handle, int algo, int m, cuComplex *dl, cuComplex *d, + cuComplex *du, cuComplex *x, int batchCount, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, cuComplex *, cuComplex *, cuComplex *, + cuComplex *, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgtsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgtsvInterleavedBatch( + cusparseHandle_t handle, int algo, int m, cuDoubleComplex *dl, + cuDoubleComplex *d, cuDoubleComplex *du, cuDoubleComplex *x, int batchCount, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, cuDoubleComplex *, cuDoubleComplex *, + cuDoubleComplex *, cuDoubleComplex *, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgtsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, dl, d, du, x, batchCount, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgpsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const float *ds, const float *dl, + const float *d, const float *du, const float *dw, const float *x, + int batchCount, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const float *, const float *, + const float *, const float *, const float *, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseSgpsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgpsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const double *ds, + const double *dl, const double *d, const double *du, const double *dw, + const double *x, int batchCount, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const double *, + const double *, const double *, const double *, const double *, int, + size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseDgpsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgpsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const cuComplex *ds, + const cuComplex *dl, const cuComplex *d, const cuComplex *du, + const cuComplex *dw, const cuComplex *x, int batchCount, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cuComplex *, + const cuComplex *, const cuComplex *, const cuComplex *, + const cuComplex *, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseCgpsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgpsvInterleavedBatch_bufferSizeExt( + cusparseHandle_t handle, int algo, int m, const cuDoubleComplex *ds, + const cuDoubleComplex *dl, const cuDoubleComplex *d, + const cuDoubleComplex *du, const cuDoubleComplex *dw, + const cuDoubleComplex *x, int batchCount, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, const cuDoubleComplex *, + const cuDoubleComplex *, const cuDoubleComplex *, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseZgpsvInterleavedBatch_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgpsvInterleavedBatch( + cusparseHandle_t handle, int algo, int m, float *ds, float *dl, float *d, + float *du, float *dw, float *x, int batchCount, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, float *, float *, float *, float *, float *, + float *, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgpsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgpsvInterleavedBatch( + cusparseHandle_t handle, int algo, int m, double *ds, double *dl, double *d, + double *du, double *dw, double *x, int batchCount, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, double *, double *, double *, double *, + double *, double *, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgpsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgpsvInterleavedBatch( + cusparseHandle_t handle, int algo, int m, cuComplex *ds, cuComplex *dl, + cuComplex *d, cuComplex *du, cuComplex *dw, cuComplex *x, int batchCount, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, cuComplex *, cuComplex *, cuComplex *, + cuComplex *, cuComplex *, cuComplex *, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgpsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgpsvInterleavedBatch( + cusparseHandle_t handle, int algo, int m, cuDoubleComplex *ds, + cuDoubleComplex *dl, cuDoubleComplex *d, cuDoubleComplex *du, + cuDoubleComplex *dw, cuDoubleComplex *x, int batchCount, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, cuDoubleComplex *, cuDoubleComplex *, + cuDoubleComplex *, cuDoubleComplex *, cuDoubleComplex *, + cuDoubleComplex *, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgpsvInterleavedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, algo, m, ds, dl, d, du, dw, x, batchCount, pBuffer); +} + +CUSPARSE_DEPRECATED(cusparseSpGEMM) +cusparseStatus_t CUSPARSEAPI cusparseCreateCsrgemm2Info(csrgemm2Info_t *info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrgemm2Info_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateCsrgemm2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +CUSPARSE_DEPRECATED(cusparseSpGEMM) +cusparseStatus_t CUSPARSEAPI cusparseDestroyCsrgemm2Info(csrgemm2Info_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(csrgemm2Info_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroyCsrgemm2Info"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(info); +} + +CUSPARSE_DEPRECATED(cusparseSpGEMM) +cusparseStatus_t CUSPARSEAPI cusparseScsrgemm2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int k, const float *alpha, + const cusparseMatDescr_t descrA, int nnzA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, const float *beta, + const cusparseMatDescr_t descrD, int nnzD, const int *csrSortedRowPtrD, + const int *csrSortedColIndD, csrgemm2Info_t info, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const float *, const cusparseMatDescr_t, + int, const int *, const int *, const cusparseMatDescr_t, int, const int *, + const int *, const float *, const cusparseMatDescr_t, int, const int *, + const int *, csrgemm2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrgemm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, + csrSortedColIndB, beta, descrD, nnzD, csrSortedRowPtrD, + csrSortedColIndD, info, pBufferSizeInBytes); +} + +CUSPARSE_DEPRECATED(cusparseSpGEMM) +cusparseStatus_t CUSPARSEAPI cusparseDcsrgemm2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int k, const double *alpha, + const cusparseMatDescr_t descrA, int nnzA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const double *beta, const cusparseMatDescr_t descrD, int nnzD, + const int *csrSortedRowPtrD, const int *csrSortedColIndD, + csrgemm2Info_t info, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const double *, const cusparseMatDescr_t, + int, const int *, const int *, const cusparseMatDescr_t, int, const int *, + const int *, const double *, const cusparseMatDescr_t, int, const int *, + const int *, csrgemm2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrgemm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, + csrSortedColIndB, beta, descrD, nnzD, csrSortedRowPtrD, + csrSortedColIndD, info, pBufferSizeInBytes); +} + +CUSPARSE_DEPRECATED(cusparseSpGEMM) +cusparseStatus_t CUSPARSEAPI cusparseCcsrgemm2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int k, const cuComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cuComplex *beta, const cusparseMatDescr_t descrD, int nnzD, + const int *csrSortedRowPtrD, const int *csrSortedColIndD, + csrgemm2Info_t info, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuComplex *, + const cusparseMatDescr_t, int, const int *, const int *, + const cusparseMatDescr_t, int, const int *, const int *, + const cuComplex *, const cusparseMatDescr_t, int, const int *, + const int *, csrgemm2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrgemm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, + csrSortedColIndB, beta, descrD, nnzD, csrSortedRowPtrD, + csrSortedColIndD, info, pBufferSizeInBytes); +} + +CUSPARSE_DEPRECATED(cusparseSpGEMM) +cusparseStatus_t CUSPARSEAPI cusparseZcsrgemm2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int k, const cuDoubleComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cuDoubleComplex *beta, const cusparseMatDescr_t descrD, int nnzD, + const int *csrSortedRowPtrD, const int *csrSortedColIndD, + csrgemm2Info_t info, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuDoubleComplex *, + const cusparseMatDescr_t, int, const int *, const int *, + const cusparseMatDescr_t, int, const int *, const int *, + const cuDoubleComplex *, const cusparseMatDescr_t, int, const int *, + const int *, csrgemm2Info_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrgemm2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, + csrSortedColIndB, beta, descrD, nnzD, csrSortedRowPtrD, + csrSortedColIndD, info, pBufferSizeInBytes); +} + +CUSPARSE_DEPRECATED(cusparseSpGEMM) +cusparseStatus_t CUSPARSEAPI cusparseXcsrgemm2Nnz( + cusparseHandle_t handle, int m, int n, int k, + const cusparseMatDescr_t descrA, int nnzA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cusparseMatDescr_t descrD, int nnzD, const int *csrSortedRowPtrD, + const int *csrSortedColIndD, const cusparseMatDescr_t descrC, + int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, const csrgemm2Info_t info, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, int, + const int *, const int *, const cusparseMatDescr_t, int, const int *, + const int *, const cusparseMatDescr_t, int, const int *, const int *, + const cusparseMatDescr_t, int *, int *, const csrgemm2Info_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcsrgemm2Nnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, + csrSortedColIndB, descrD, nnzD, csrSortedRowPtrD, + csrSortedColIndD, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr, info, pBuffer); +} + +CUSPARSE_DEPRECATED(cusparseSpGEMM) +cusparseStatus_t CUSPARSEAPI cusparseScsrgemm2( + cusparseHandle_t handle, int m, int n, int k, const float *alpha, + const cusparseMatDescr_t descrA, int nnzA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrB, int nnzB, const float *csrSortedValB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, const float *beta, + const cusparseMatDescr_t descrD, int nnzD, const float *csrSortedValD, + const int *csrSortedRowPtrD, const int *csrSortedColIndD, + const cusparseMatDescr_t descrC, float *csrSortedValC, + const int *csrSortedRowPtrC, int *csrSortedColIndC, + const csrgemm2Info_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const float *, const cusparseMatDescr_t, + int, const float *, const int *, const int *, const cusparseMatDescr_t, + int, const float *, const int *, const int *, const float *, + const cusparseMatDescr_t, int, const float *, const int *, const int *, + const cusparseMatDescr_t, float *, const int *, int *, + const csrgemm2Info_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrgemm2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, beta, + descrD, nnzD, csrSortedValD, csrSortedRowPtrD, + csrSortedColIndD, descrC, csrSortedValC, csrSortedRowPtrC, + csrSortedColIndC, info, pBuffer); +} + +CUSPARSE_DEPRECATED(cusparseSpGEMM) +cusparseStatus_t CUSPARSEAPI cusparseDcsrgemm2( + cusparseHandle_t handle, int m, int n, int k, const double *alpha, + const cusparseMatDescr_t descrA, int nnzA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrB, int nnzB, const double *csrSortedValB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const double *beta, const cusparseMatDescr_t descrD, int nnzD, + const double *csrSortedValD, const int *csrSortedRowPtrD, + const int *csrSortedColIndD, const cusparseMatDescr_t descrC, + double *csrSortedValC, const int *csrSortedRowPtrC, int *csrSortedColIndC, + const csrgemm2Info_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const double *, const cusparseMatDescr_t, + int, const double *, const int *, const int *, const cusparseMatDescr_t, + int, const double *, const int *, const int *, const double *, + const cusparseMatDescr_t, int, const double *, const int *, const int *, + const cusparseMatDescr_t, double *, const int *, int *, + const csrgemm2Info_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrgemm2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, beta, + descrD, nnzD, csrSortedValD, csrSortedRowPtrD, + csrSortedColIndD, descrC, csrSortedValC, csrSortedRowPtrC, + csrSortedColIndC, info, pBuffer); +} + +CUSPARSE_DEPRECATED(cusparseSpGEMM) +cusparseStatus_t CUSPARSEAPI cusparseCcsrgemm2( + cusparseHandle_t handle, int m, int n, int k, const cuComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrB, int nnzB, const cuComplex *csrSortedValB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cuComplex *beta, const cusparseMatDescr_t descrD, int nnzD, + const cuComplex *csrSortedValD, const int *csrSortedRowPtrD, + const int *csrSortedColIndD, const cusparseMatDescr_t descrC, + cuComplex *csrSortedValC, const int *csrSortedRowPtrC, + int *csrSortedColIndC, const csrgemm2Info_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuComplex *, + const cusparseMatDescr_t, int, const cuComplex *, const int *, + const int *, const cusparseMatDescr_t, int, const cuComplex *, + const int *, const int *, const cuComplex *, const cusparseMatDescr_t, + int, const cuComplex *, const int *, const int *, + const cusparseMatDescr_t, cuComplex *, const int *, int *, + const csrgemm2Info_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrgemm2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, beta, + descrD, nnzD, csrSortedValD, csrSortedRowPtrD, + csrSortedColIndD, descrC, csrSortedValC, csrSortedRowPtrC, + csrSortedColIndC, info, pBuffer); +} + +CUSPARSE_DEPRECATED(cusparseSpGEMM) +cusparseStatus_t CUSPARSEAPI cusparseZcsrgemm2( + cusparseHandle_t handle, int m, int n, int k, const cuDoubleComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, + const cuDoubleComplex *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cuDoubleComplex *beta, + const cusparseMatDescr_t descrD, int nnzD, + const cuDoubleComplex *csrSortedValD, const int *csrSortedRowPtrD, + const int *csrSortedColIndD, const cusparseMatDescr_t descrC, + cuDoubleComplex *csrSortedValC, const int *csrSortedRowPtrC, + int *csrSortedColIndC, const csrgemm2Info_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuDoubleComplex *, + const cusparseMatDescr_t, int, const cuDoubleComplex *, const int *, + const int *, const cusparseMatDescr_t, int, const cuDoubleComplex *, + const int *, const int *, const cuDoubleComplex *, + const cusparseMatDescr_t, int, const cuDoubleComplex *, const int *, + const int *, const cusparseMatDescr_t, cuDoubleComplex *, const int *, + int *, const csrgemm2Info_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrgemm2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, k, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, beta, + descrD, nnzD, csrSortedValD, csrSortedRowPtrD, + csrSortedColIndD, descrC, csrSortedValC, csrSortedRowPtrC, + csrSortedColIndC, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrgeam2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const float *alpha, + const cusparseMatDescr_t descrA, int nnzA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, const float *beta, + const cusparseMatDescr_t descrB, int nnzB, const float *csrSortedValB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cusparseMatDescr_t descrC, const float *csrSortedValC, + const int *csrSortedRowPtrC, const int *csrSortedColIndC, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const cusparseMatDescr_t, int, + const float *, const int *, const int *, const float *, + const cusparseMatDescr_t, int, const float *, const int *, const int *, + const cusparseMatDescr_t, const float *, const int *, const int *, + size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrgeam2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrgeam2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const double *alpha, + const cusparseMatDescr_t descrA, int nnzA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const double *beta, const cusparseMatDescr_t descrB, int nnzB, + const double *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + const double *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const cusparseMatDescr_t, int, + const double *, const int *, const int *, const double *, + const cusparseMatDescr_t, int, const double *, const int *, const int *, + const cusparseMatDescr_t, const double *, const int *, const int *, + size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrgeam2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrgeam2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const cuComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cuComplex *beta, const cusparseMatDescr_t descrB, int nnzB, + const cuComplex *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + const cuComplex *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cusparseMatDescr_t, + int, const cuComplex *, const int *, const int *, const cuComplex *, + const cusparseMatDescr_t, int, const cuComplex *, const int *, + const int *, const cusparseMatDescr_t, const cuComplex *, const int *, + const int *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrgeam2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrgeam2_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const cuDoubleComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuDoubleComplex *beta, + const cusparseMatDescr_t descrB, int nnzB, + const cuDoubleComplex *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + const cuDoubleComplex *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cusparseMatDescr_t, int, const cuDoubleComplex *, const int *, + const int *, const cuDoubleComplex *, const cusparseMatDescr_t, int, + const cuDoubleComplex *, const int *, const int *, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrgeam2_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsrgeam2Nnz( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + int nnzA, const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrB, int nnzB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *workspace) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, int, const int *, + const int *, const cusparseMatDescr_t, int, const int *, const int *, + const cusparseMatDescr_t, int *, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcsrgeam2Nnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, nnzA, csrSortedRowPtrA, + csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB, + csrSortedColIndB, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr, workspace); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrgeam2( + cusparseHandle_t handle, int m, int n, const float *alpha, + const cusparseMatDescr_t descrA, int nnzA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, const float *beta, + const cusparseMatDescr_t descrB, int nnzB, const float *csrSortedValB, + const int *csrSortedRowPtrB, const int *csrSortedColIndB, + const cusparseMatDescr_t descrC, float *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, const cusparseMatDescr_t, int, + const float *, const int *, const int *, const float *, + const cusparseMatDescr_t, int, const float *, const int *, const int *, + const cusparseMatDescr_t, float *, int *, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrgeam2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrgeam2( + cusparseHandle_t handle, int m, int n, const double *alpha, + const cusparseMatDescr_t descrA, int nnzA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const double *beta, const cusparseMatDescr_t descrB, int nnzB, + const double *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + double *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, const cusparseMatDescr_t, int, + const double *, const int *, const int *, const double *, + const cusparseMatDescr_t, int, const double *, const int *, const int *, + const cusparseMatDescr_t, double *, int *, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrgeam2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrgeam2( + cusparseHandle_t handle, int m, int n, const cuComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cuComplex *beta, const cusparseMatDescr_t descrB, int nnzB, + const cuComplex *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + cuComplex *csrSortedValC, int *csrSortedRowPtrC, int *csrSortedColIndC, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuComplex *, const cusparseMatDescr_t, + int, const cuComplex *, const int *, const int *, const cuComplex *, + const cusparseMatDescr_t, int, const cuComplex *, const int *, + const int *, const cusparseMatDescr_t, cuComplex *, int *, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrgeam2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrgeam2( + cusparseHandle_t handle, int m, int n, const cuDoubleComplex *alpha, + const cusparseMatDescr_t descrA, int nnzA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cuDoubleComplex *beta, + const cusparseMatDescr_t descrB, int nnzB, + const cuDoubleComplex *csrSortedValB, const int *csrSortedRowPtrB, + const int *csrSortedColIndB, const cusparseMatDescr_t descrC, + cuDoubleComplex *csrSortedValC, int *csrSortedRowPtrC, + int *csrSortedColIndC, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cuDoubleComplex *, + const cusparseMatDescr_t, int, const cuDoubleComplex *, const int *, + const int *, const cuDoubleComplex *, const cusparseMatDescr_t, int, + const cuDoubleComplex *, const int *, const int *, + const cusparseMatDescr_t, cuDoubleComplex *, int *, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrgeam2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, alpha, descrA, nnzA, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA, beta, descrB, nnzB, + csrSortedValB, csrSortedRowPtrB, csrSortedColIndB, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsrcolor( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const float *fractionToColor, int *ncolors, + int *coloring, int *reordering, const cusparseColorInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, const float *, int *, int *, int *, + const cusparseColorInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsrcolor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, fractionToColor, ncolors, coloring, + reordering, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsrcolor( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const double *fractionToColor, int *ncolors, + int *coloring, int *reordering, const cusparseColorInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, const double *, int *, int *, int *, + const cusparseColorInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsrcolor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, fractionToColor, ncolors, coloring, + reordering, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsrcolor( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const float *fractionToColor, int *ncolors, + int *coloring, int *reordering, const cusparseColorInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, const float *, int *, int *, int *, + const cusparseColorInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsrcolor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, fractionToColor, ncolors, coloring, + reordering, info); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsrcolor( + cusparseHandle_t handle, int m, int nnz, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const double *fractionToColor, int *ncolors, + int *coloring, int *reordering, const cusparseColorInfo_t info) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, const double *, int *, + int *, int *, const cusparseColorInfo_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsrcolor"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, nnz, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, fractionToColor, ncolors, coloring, + reordering, info); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSnnz(cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const float *A, int lda, + int *nnzPerRowCol, int *nnzTotalDevHostPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const float *, int, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSnnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, A, lda, nnzPerRowCol, + nnzTotalDevHostPtr); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDnnz(cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const double *A, int lda, + int *nnzPerRowCol, int *nnzTotalDevHostPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, int, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDnnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, A, lda, nnzPerRowCol, + nnzTotalDevHostPtr); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCnnz(cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuComplex *A, int lda, + int *nnzPerRowCol, int *nnzTotalDevHostPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, int, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCnnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, A, lda, nnzPerRowCol, + nnzTotalDevHostPtr); +} + +cusparseStatus_t CUSPARSEAPI +cusparseZnnz(cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuDoubleComplex *A, int lda, + int *nnzPerRowCol, int *nnzTotalDevHostPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, int, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZnnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, A, lda, nnzPerRowCol, + nnzTotalDevHostPtr); +} + +cusparseStatus_t CUSPARSEAPI cusparseSnnz_compress( + cusparseHandle_t handle, int m, const cusparseMatDescr_t descr, + const float *csrSortedValA, const int *csrSortedRowPtrA, int *nnzPerRow, + int *nnzC, float tol) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cusparseMatDescr_t, const float *, + const int *, int *, int *, float); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSnnz_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, descr, csrSortedValA, csrSortedRowPtrA, nnzPerRow, + nnzC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseDnnz_compress( + cusparseHandle_t handle, int m, const cusparseMatDescr_t descr, + const double *csrSortedValA, const int *csrSortedRowPtrA, int *nnzPerRow, + int *nnzC, double tol) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cusparseMatDescr_t, const double *, + const int *, int *, int *, double); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDnnz_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, descr, csrSortedValA, csrSortedRowPtrA, nnzPerRow, + nnzC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseCnnz_compress( + cusparseHandle_t handle, int m, const cusparseMatDescr_t descr, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, int *nnzPerRow, + int *nnzC, cuComplex tol) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cusparseMatDescr_t, const cuComplex *, + const int *, int *, int *, cuComplex); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCnnz_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, descr, csrSortedValA, csrSortedRowPtrA, nnzPerRow, + nnzC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseZnnz_compress( + cusparseHandle_t handle, int m, const cusparseMatDescr_t descr, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + int *nnzPerRow, int *nnzC, cuDoubleComplex tol) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, const cusparseMatDescr_t, const cuDoubleComplex *, + const int *, int *, int *, cuDoubleComplex); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZnnz_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, descr, csrSortedValA, csrSortedRowPtrA, nnzPerRow, + nnzC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsr2csr_compress( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedColIndA, + const int *csrSortedRowPtrA, int nnzA, const int *nnzPerRow, + float *csrSortedValC, int *csrSortedColIndC, int *csrSortedRowPtrC, + float tol) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, int, const int *, float *, int *, int *, float); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsr2csr_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedColIndA, + csrSortedRowPtrA, nnzA, nnzPerRow, csrSortedValC, + csrSortedColIndC, csrSortedRowPtrC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2csr_compress( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedColIndA, + const int *csrSortedRowPtrA, int nnzA, const int *nnzPerRow, + double *csrSortedValC, int *csrSortedColIndC, int *csrSortedRowPtrC, + double tol) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, int, const int *, double *, int *, int *, + double); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsr2csr_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedColIndA, + csrSortedRowPtrA, nnzA, nnzPerRow, csrSortedValC, + csrSortedColIndC, csrSortedRowPtrC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsr2csr_compress( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedColIndA, + const int *csrSortedRowPtrA, int nnzA, const int *nnzPerRow, + cuComplex *csrSortedValC, int *csrSortedColIndC, int *csrSortedRowPtrC, + cuComplex tol) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, int, const int *, cuComplex *, int *, int *, + cuComplex); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsr2csr_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedColIndA, + csrSortedRowPtrA, nnzA, nnzPerRow, csrSortedValC, + csrSortedColIndC, csrSortedRowPtrC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2csr_compress( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedColIndA, + const int *csrSortedRowPtrA, int nnzA, const int *nnzPerRow, + cuDoubleComplex *csrSortedValC, int *csrSortedColIndC, + int *csrSortedRowPtrC, cuDoubleComplex tol) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, const int *, + cuDoubleComplex *, int *, int *, cuDoubleComplex); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsr2csr_compress"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedColIndA, + csrSortedRowPtrA, nnzA, nnzPerRow, csrSortedValC, + csrSortedColIndC, csrSortedRowPtrC, tol); +} + +cusparseStatus_t CUSPARSEAPI cusparseSdense2csr( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const float *A, int lda, const int *nnzPerRow, float *csrSortedValA, + int *csrSortedRowPtrA, int *csrSortedColIndA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, int, + const int *, float *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSdense2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerRow, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA); +} + +cusparseStatus_t CUSPARSEAPI cusparseDdense2csr( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const double *A, int lda, const int *nnzPerRow, double *csrSortedValA, + int *csrSortedRowPtrA, int *csrSortedColIndA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, int, + const int *, double *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDdense2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerRow, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA); +} + +cusparseStatus_t CUSPARSEAPI cusparseCdense2csr( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuComplex *A, int lda, const int *nnzPerRow, cuComplex *csrSortedValA, + int *csrSortedRowPtrA, int *csrSortedColIndA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + int, const int *, cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCdense2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerRow, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA); +} + +cusparseStatus_t CUSPARSEAPI cusparseZdense2csr( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuDoubleComplex *A, int lda, const int *nnzPerRow, + cuDoubleComplex *csrSortedValA, int *csrSortedRowPtrA, + int *csrSortedColIndA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, int, const int *, cuDoubleComplex *, int *, + int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZdense2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerRow, csrSortedValA, + csrSortedRowPtrA, csrSortedColIndA); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsr2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const float *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, float *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsr2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const double *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, double *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsr2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsr2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cuComplex *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsr2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuDoubleComplex *csrSortedValA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, cuDoubleComplex *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, cuDoubleComplex *, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsr2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseSdense2csc( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const float *A, int lda, const int *nnzPerCol, float *cscSortedValA, + int *cscSortedRowIndA, int *cscSortedColPtrA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, int, + const int *, float *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSdense2csc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerCol, cscSortedValA, + cscSortedRowIndA, cscSortedColPtrA); +} + +cusparseStatus_t CUSPARSEAPI cusparseDdense2csc( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const double *A, int lda, const int *nnzPerCol, double *cscSortedValA, + int *cscSortedRowIndA, int *cscSortedColPtrA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, int, + const int *, double *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDdense2csc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerCol, cscSortedValA, + cscSortedRowIndA, cscSortedColPtrA); +} + +cusparseStatus_t CUSPARSEAPI cusparseCdense2csc( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuComplex *A, int lda, const int *nnzPerCol, cuComplex *cscSortedValA, + int *cscSortedRowIndA, int *cscSortedColPtrA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + int, const int *, cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCdense2csc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerCol, cscSortedValA, + cscSortedRowIndA, cscSortedColPtrA); +} + +cusparseStatus_t CUSPARSEAPI cusparseZdense2csc( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuDoubleComplex *A, int lda, const int *nnzPerCol, + cuDoubleComplex *cscSortedValA, int *cscSortedRowIndA, + int *cscSortedColPtrA) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, int, const int *, cuDoubleComplex *, int *, + int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZdense2csc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, A, lda, nnzPerCol, cscSortedValA, + cscSortedRowIndA, cscSortedColPtrA); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsc2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const float *cscSortedValA, const int *cscSortedRowIndA, + const int *cscSortedColPtrA, float *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, float *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsc2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, cscSortedValA, cscSortedRowIndA, + cscSortedColPtrA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsc2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const double *cscSortedValA, const int *cscSortedRowIndA, + const int *cscSortedColPtrA, double *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, double *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsc2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, cscSortedValA, cscSortedRowIndA, + cscSortedColPtrA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsc2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuComplex *cscSortedValA, const int *cscSortedRowIndA, + const int *cscSortedColPtrA, cuComplex *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, const cuComplex *, + const int *, const int *, cuComplex *, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsc2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, cscSortedValA, cscSortedRowIndA, + cscSortedColPtrA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsc2dense( + cusparseHandle_t handle, int m, int n, const cusparseMatDescr_t descrA, + const cuDoubleComplex *cscSortedValA, const int *cscSortedRowIndA, + const int *cscSortedColPtrA, cuDoubleComplex *A, int lda) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, cuDoubleComplex *, + int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsc2dense"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, descrA, cscSortedValA, cscSortedRowIndA, + cscSortedColPtrA, A, lda); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcoo2csr(cusparseHandle_t handle, + const int *cooRowInd, int nnz, + int m, int *csrSortedRowPtr, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const int *, int, int, int *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcoo2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, cooRowInd, nnz, m, csrSortedRowPtr, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsr2coo(cusparseHandle_t handle, + const int *csrSortedRowPtr, + int nnz, int m, int *cooRowInd, + cusparseIndexBase_t idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, const int *, int, int, int *, cusparseIndexBase_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcsr2coo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, csrSortedRowPtr, nnz, m, cooRowInd, idxBase); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsr2bsrNnz( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, int blockDim, const cusparseMatDescr_t descrC, + int *bsrSortedRowPtrC, int *nnzTotalDevHostPtr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const int *, const int *, int, const cusparseMatDescr_t, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcsr2bsrNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedRowPtrA, + csrSortedColIndA, blockDim, descrC, bsrSortedRowPtrC, + nnzTotalDevHostPtr); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsr2bsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int blockDim, + const cusparseMatDescr_t descrC, float *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, int, const cusparseMatDescr_t, + float *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsr2bsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, blockDim, descrC, bsrSortedValC, + bsrSortedRowPtrC, bsrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2bsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int blockDim, + const cusparseMatDescr_t descrC, double *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, int, const cusparseMatDescr_t, + double *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsr2bsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, blockDim, descrC, bsrSortedValC, + bsrSortedRowPtrC, bsrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsr2bsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int blockDim, + const cusparseMatDescr_t descrC, cuComplex *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, int, + const cusparseMatDescr_t, cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsr2bsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, blockDim, descrC, bsrSortedValC, + bsrSortedRowPtrC, bsrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2bsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int blockDim, + const cusparseMatDescr_t descrC, cuDoubleComplex *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, + const cusparseMatDescr_t, cuDoubleComplex *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsr2bsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, blockDim, descrC, bsrSortedValC, + bsrSortedRowPtrC, bsrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseSbsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + const cusparseMatDescr_t descrC, float *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, int, const cusparseMatDescr_t, + float *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSbsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, blockDim, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseDbsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + const cusparseMatDescr_t descrC, double *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, int, const cusparseMatDescr_t, + double *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDbsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, blockDim, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseCbsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + const cusparseMatDescr_t descrC, cuComplex *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, int, + const cusparseMatDescr_t, cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCbsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, blockDim, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseZbsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int blockDim, + const cusparseMatDescr_t descrC, cuDoubleComplex *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, + const cusparseMatDescr_t, cuDoubleComplex *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZbsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, blockDim, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgebsr2gebsc_bufferSize( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const float *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const float *, const int *, const int *, + int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgebsr2gebsc_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgebsr2gebsc_bufferSize( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const double *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const double *, const int *, const int *, + int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgebsr2gebsc_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgebsr2gebsc_bufferSize( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const cuComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuComplex *, const int *, + const int *, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgebsr2gebsc_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgebsr2gebsc_bufferSize( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const cuDoubleComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuDoubleComplex *, const int *, + const int *, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgebsr2gebsc_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgebsr2gebsc_bufferSizeExt( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const float *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const float *, const int *, const int *, + int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseSgebsr2gebsc_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgebsr2gebsc_bufferSizeExt( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const double *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const double *, const int *, const int *, + int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseDgebsr2gebsc_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgebsr2gebsc_bufferSizeExt( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const cuComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuComplex *, const int *, + const int *, int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseCgebsr2gebsc_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgebsr2gebsc_bufferSizeExt( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const cuDoubleComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuDoubleComplex *, const int *, + const int *, int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseZgebsr2gebsc_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgebsr2gebsc( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const float *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, float *bscVal, + int *bscRowInd, int *bscColPtr, cusparseAction_t copyValues, + cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const float *, const int *, const int *, + int, int, float *, int *, int *, cusparseAction_t, cusparseIndexBase_t, + void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgebsr2gebsc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, bscVal, bscRowInd, + bscColPtr, copyValues, idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgebsr2gebsc( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const double *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + double *bscVal, int *bscRowInd, int *bscColPtr, cusparseAction_t copyValues, + cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const double *, const int *, const int *, + int, int, double *, int *, int *, cusparseAction_t, cusparseIndexBase_t, + void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgebsr2gebsc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, bscVal, bscRowInd, + bscColPtr, copyValues, idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgebsr2gebsc( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const cuComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + cuComplex *bscVal, int *bscRowInd, int *bscColPtr, + cusparseAction_t copyValues, cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuComplex *, const int *, + const int *, int, int, cuComplex *, int *, int *, cusparseAction_t, + cusparseIndexBase_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgebsr2gebsc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, bscVal, bscRowInd, + bscColPtr, copyValues, idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgebsr2gebsc( + cusparseHandle_t handle, int mb, int nb, int nnzb, + const cuDoubleComplex *bsrSortedVal, const int *bsrSortedRowPtr, + const int *bsrSortedColInd, int rowBlockDim, int colBlockDim, + cuDoubleComplex *bscVal, int *bscRowInd, int *bscColPtr, + cusparseAction_t copyValues, cusparseIndexBase_t idxBase, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cuDoubleComplex *, const int *, + const int *, int, int, cuDoubleComplex *, int *, int *, cusparseAction_t, + cusparseIndexBase_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgebsr2gebsc"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, mb, nb, nnzb, bsrSortedVal, bsrSortedRowPtr, + bsrSortedColInd, rowBlockDim, colBlockDim, bscVal, bscRowInd, + bscColPtr, copyValues, idxBase, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXgebsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const int *bsrSortedRowPtrA, + const int *bsrSortedColIndA, int rowBlockDim, int colBlockDim, + const cusparseMatDescr_t descrC, int *csrSortedRowPtrC, + int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const int *, const int *, int, int, const cusparseMatDescr_t, int *, + int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXgebsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedRowPtrA, + bsrSortedColIndA, rowBlockDim, colBlockDim, descrC, + csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgebsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDim, + int colBlockDim, const cusparseMatDescr_t descrC, float *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, int, int, + const cusparseMatDescr_t, float *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgebsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, rowBlockDim, colBlockDim, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgebsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDim, + int colBlockDim, const cusparseMatDescr_t descrC, double *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, int, int, + const cusparseMatDescr_t, double *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgebsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, rowBlockDim, colBlockDim, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgebsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDim, + int colBlockDim, const cusparseMatDescr_t descrC, cuComplex *csrSortedValC, + int *csrSortedRowPtrC, int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, int, int, + const cusparseMatDescr_t, cuComplex *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgebsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, rowBlockDim, colBlockDim, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgebsr2csr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDim, + int colBlockDim, const cusparseMatDescr_t descrC, + cuDoubleComplex *csrSortedValC, int *csrSortedRowPtrC, + int *csrSortedColIndC) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, int, + const cusparseMatDescr_t, cuDoubleComplex *, int *, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgebsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, descrA, bsrSortedValA, bsrSortedRowPtrA, + bsrSortedColIndA, rowBlockDim, colBlockDim, descrC, + csrSortedValC, csrSortedRowPtrC, csrSortedColIndC); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseScsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseDcsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseCcsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, int rowBlockDim, + int colBlockDim, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseZcsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, rowBlockDim, colBlockDim, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsr2gebsrNnz( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const int *csrSortedRowPtrA, + const int *csrSortedColIndA, const cusparseMatDescr_t descrC, + int *bsrSortedRowPtrC, int rowBlockDim, int colBlockDim, + int *nnzTotalDevHostPtr, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const int *, const int *, const cusparseMatDescr_t, int *, int, int, + int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcsr2gebsrNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedRowPtrA, + csrSortedColIndA, descrC, bsrSortedRowPtrC, rowBlockDim, + colBlockDim, nnzTotalDevHostPtr, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrC, float *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC, int rowBlockDim, + int colBlockDim, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const float *, const int *, const int *, const cusparseMatDescr_t, + float *, int *, int *, int, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDim, colBlockDim, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrC, double *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC, int rowBlockDim, + int colBlockDim, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const double *, const int *, const int *, const cusparseMatDescr_t, + double *, int *, int *, int, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDim, colBlockDim, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrC, cuComplex *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC, int rowBlockDim, + int colBlockDim, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuComplex *, const int *, const int *, const cusparseMatDescr_t, + cuComplex *, int *, int *, int, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDim, colBlockDim, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int m, int n, + const cusparseMatDescr_t descrA, const cuDoubleComplex *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const cusparseMatDescr_t descrC, cuDoubleComplex *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC, int rowBlockDim, + int colBlockDim, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, const cusparseMatDescr_t, + const cuDoubleComplex *, const int *, const int *, + const cusparseMatDescr_t, cuDoubleComplex *, int *, int *, int, int, + void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, m, n, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDim, colBlockDim, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgebsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const float *, const int *, const int *, int, + int, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgebsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgebsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const double *, const int *, const int *, int, + int, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgebsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgebsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + int, int, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgebsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgebsr2gebsr_bufferSize( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, + int *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, int, int, int, int, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgebsr2gebsr_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgebsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const float *, const int *, const int *, int, + int, int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseSgebsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgebsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const double *, const int *, const int *, int, + int, int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseDgebsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgebsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + int, int, int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseCgebsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgebsr2gebsr_bufferSizeExt( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, int rowBlockDimC, int colBlockDimC, size_t *pBufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, int, int, int, int, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseZgebsr2gebsr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, rowBlockDimC, colBlockDimC, pBufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseXgebsr2gebsrNnz( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const int *bsrSortedRowPtrA, + const int *bsrSortedColIndA, int rowBlockDimA, int colBlockDimA, + const cusparseMatDescr_t descrC, int *bsrSortedRowPtrC, int rowBlockDimC, + int colBlockDimC, int *nnzTotalDevHostPtr, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const int *, const int *, int, int, + const cusparseMatDescr_t, int *, int, int, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXgebsr2gebsrNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedRowPtrA, + bsrSortedColIndA, rowBlockDimA, colBlockDimA, descrC, + bsrSortedRowPtrC, rowBlockDimC, colBlockDimC, + nnzTotalDevHostPtr, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSgebsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const float *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, const cusparseMatDescr_t descrC, float *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC, int rowBlockDimC, + int colBlockDimC, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const float *, const int *, const int *, int, + int, const cusparseMatDescr_t, float *, int *, int *, int, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSgebsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDimC, colBlockDimC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDgebsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const double *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, const cusparseMatDescr_t descrC, double *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC, int rowBlockDimC, + int colBlockDimC, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const double *, const int *, const int *, int, + int, const cusparseMatDescr_t, double *, int *, int *, int, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDgebsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDimC, colBlockDimC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCgebsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const cuComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, const cusparseMatDescr_t descrC, cuComplex *bsrSortedValC, + int *bsrSortedRowPtrC, int *bsrSortedColIndC, int rowBlockDimC, + int colBlockDimC, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const cuComplex *, const int *, const int *, + int, int, const cusparseMatDescr_t, cuComplex *, int *, int *, int, int, + void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCgebsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDimC, colBlockDimC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZgebsr2gebsr( + cusparseHandle_t handle, cusparseDirection_t dirA, int mb, int nb, int nnzb, + const cusparseMatDescr_t descrA, const cuDoubleComplex *bsrSortedValA, + const int *bsrSortedRowPtrA, const int *bsrSortedColIndA, int rowBlockDimA, + int colBlockDimA, const cusparseMatDescr_t descrC, + cuDoubleComplex *bsrSortedValC, int *bsrSortedRowPtrC, + int *bsrSortedColIndC, int rowBlockDimC, int colBlockDimC, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseDirection_t, int, int, int, + const cusparseMatDescr_t, const cuDoubleComplex *, const int *, + const int *, int, int, const cusparseMatDescr_t, cuDoubleComplex *, int *, + int *, int, int, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZgebsr2gebsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, dirA, mb, nb, nnzb, descrA, bsrSortedValA, + bsrSortedRowPtrA, bsrSortedColIndA, rowBlockDimA, + colBlockDimA, descrC, bsrSortedValC, bsrSortedRowPtrC, + bsrSortedColIndC, rowBlockDimC, colBlockDimC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCreateIdentityPermutation(cusparseHandle_t handle, int n, int *p) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseHandle_t, int, int *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseCreateIdentityPermutation"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, n, p); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcoosort_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnz, const int *cooRowsA, + const int *cooColsA, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const int *, const int *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcoosort_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, cooRowsA, cooColsA, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcoosortByRow(cusparseHandle_t handle, + int m, int n, int nnz, + int *cooRowsA, int *cooColsA, + int *P, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, int *, int *, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcoosortByRow"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, cooRowsA, cooColsA, P, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcoosortByColumn(cusparseHandle_t handle, + int m, int n, int nnz, + int *cooRowsA, + int *cooColsA, int *P, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, int *, int *, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcoosortByColumn"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, cooRowsA, cooColsA, P, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsrsort_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnz, const int *csrRowPtrA, + const int *csrColIndA, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const int *, const int *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcsrsort_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrRowPtrA, csrColIndA, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcsrsort(cusparseHandle_t handle, int m, + int n, int nnz, + const cusparseMatDescr_t descrA, + const int *csrRowPtrA, + int *csrColIndA, int *P, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const int *, + int *, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcsrsort"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrRowPtrA, csrColIndA, P, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcscsort_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnz, const int *cscColPtrA, + const int *cscRowIndA, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const int *, const int *, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcscsort_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, cscColPtrA, cscRowIndA, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseXcscsort(cusparseHandle_t handle, int m, + int n, int nnz, + const cusparseMatDescr_t descrA, + const int *cscColPtrA, + int *cscRowIndA, int *P, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const int *, + int *, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseXcscsort"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, cscColPtrA, cscRowIndA, P, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsru2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnz, float *csrVal, + const int *csrRowPtr, int *csrColInd, csru2csrInfo_t info, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, float *, const int *, int *, + csru2csrInfo_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsru2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsru2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnz, double *csrVal, + const int *csrRowPtr, int *csrColInd, csru2csrInfo_t info, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, double *, const int *, int *, + csru2csrInfo_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsru2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsru2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnz, cuComplex *csrVal, + const int *csrRowPtr, int *csrColInd, csru2csrInfo_t info, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, cuComplex *, const int *, int *, + csru2csrInfo_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsru2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsru2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnz, cuDoubleComplex *csrVal, + const int *csrRowPtr, int *csrColInd, csru2csrInfo_t info, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, cuDoubleComplex *, const int *, int *, + csru2csrInfo_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsru2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, info, + pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsru2csr( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, float *csrVal, const int *csrRowPtr, + int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, float *, + const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsru2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsru2csr( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, double *csrVal, const int *csrRowPtr, + int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, double *, + const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsru2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsru2csr( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, cuComplex *csrVal, const int *csrRowPtr, + int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsru2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsru2csr( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, cuDoubleComplex *csrVal, + const int *csrRowPtr, int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsru2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseScsr2csru( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, float *csrVal, const int *csrRowPtr, + int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, float *, + const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseScsr2csru"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDcsr2csru( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, double *csrVal, const int *csrRowPtr, + int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, double *, + const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDcsr2csru"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCcsr2csru( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, cuComplex *csrVal, const int *csrRowPtr, + int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, cuComplex *, + const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCcsr2csru"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseZcsr2csru( + cusparseHandle_t handle, int m, int n, int nnz, + const cusparseMatDescr_t descrA, cuDoubleComplex *csrVal, + const int *csrRowPtr, int *csrColInd, csru2csrInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, + cuDoubleComplex *, const int *, int *, csru2csrInfo_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseZcsr2csru"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, descrA, csrVal, csrRowPtr, csrColInd, info, + pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneDense2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const float *A, int lda, + const float *threshold, const cusparseMatDescr_t descrC, + const float *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, int, const float *, + const cusparseMatDescr_t, const float *, const int *, const int *, + size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseSpruneDense2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneDense2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const double *A, int lda, + const double *threshold, const cusparseMatDescr_t descrC, + const double *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, int, const double *, + const cusparseMatDescr_t, const double *, const int *, const int *, + size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseDpruneDense2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneDense2csrNnz( + cusparseHandle_t handle, int m, int n, const float *A, int lda, + const float *threshold, const cusparseMatDescr_t descrC, int *csrRowPtrC, + int *nnzTotalDevHostPtr, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, int, const float *, + const cusparseMatDescr_t, int *, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpruneDense2csrNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, threshold, descrC, csrRowPtrC, + nnzTotalDevHostPtr, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneDense2csrNnz( + cusparseHandle_t handle, int m, int n, const double *A, int lda, + const double *threshold, const cusparseMatDescr_t descrC, + int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, int, const double *, + const cusparseMatDescr_t, int *, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDpruneDense2csrNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, threshold, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneDense2csr( + cusparseHandle_t handle, int m, int n, const float *A, int lda, + const float *threshold, const cusparseMatDescr_t descrC, + float *csrSortedValC, const int *csrSortedRowPtrC, int *csrSortedColIndC, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, int, const float *, + const cusparseMatDescr_t, float *, const int *, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpruneDense2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneDense2csr( + cusparseHandle_t handle, int m, int n, const double *A, int lda, + const double *threshold, const cusparseMatDescr_t descrC, + double *csrSortedValC, const int *csrSortedRowPtrC, int *csrSortedColIndC, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, int, const double *, + const cusparseMatDescr_t, double *, const int *, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDpruneDense2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneCsr2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const float *threshold, const cusparseMatDescr_t descrC, + const float *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, const float *, const cusparseMatDescr_t, + const float *, const int *, const int *, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseSpruneCsr2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneCsr2csr_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const double *threshold, const cusparseMatDescr_t descrC, + const double *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, const double *, const cusparseMatDescr_t, + const double *, const int *, const int *, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseDpruneCsr2csr_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneCsr2csrNnz( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const float *threshold, const cusparseMatDescr_t descrC, + int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, const float *, const cusparseMatDescr_t, int *, + int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpruneCsr2csrNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, threshold, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneCsr2csrNnz( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const double *threshold, const cusparseMatDescr_t descrC, + int *csrSortedRowPtrC, int *nnzTotalDevHostPtr, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, const double *, const cusparseMatDescr_t, int *, + int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDpruneCsr2csrNnz"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, threshold, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneCsr2csr( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const float *threshold, const cusparseMatDescr_t descrC, + float *csrSortedValC, const int *csrSortedRowPtrC, int *csrSortedColIndC, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, const float *, const cusparseMatDescr_t, + float *, const int *, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpruneCsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneCsr2csr( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, + const double *threshold, const cusparseMatDescr_t descrC, + double *csrSortedValC, const int *csrSortedRowPtrC, int *csrSortedColIndC, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, const double *, const cusparseMatDescr_t, + double *, const int *, int *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDpruneCsr2csr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, threshold, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneDense2csrByPercentage_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const float *A, int lda, + float percentage, const cusparseMatDescr_t descrC, + const float *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, pruneInfo_t info, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, int, float, + const cusparseMatDescr_t, const float *, const int *, const int *, + pruneInfo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseSpruneDense2csrByPercentage_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneDense2csrByPercentage_bufferSizeExt( + cusparseHandle_t handle, int m, int n, const double *A, int lda, + float percentage, const cusparseMatDescr_t descrC, + const double *csrSortedValC, const int *csrSortedRowPtrC, + const int *csrSortedColIndC, pruneInfo_t info, size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, int, float, + const cusparseMatDescr_t, const double *, const int *, const int *, + pruneInfo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseDpruneDense2csrByPercentage_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneDense2csrNnzByPercentage( + cusparseHandle_t handle, int m, int n, const float *A, int lda, + float percentage, const cusparseMatDescr_t descrC, int *csrRowPtrC, + int *nnzTotalDevHostPtr, pruneInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, int, float, + const cusparseMatDescr_t, int *, int *, pruneInfo_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseSpruneDense2csrNnzByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, percentage, descrC, csrRowPtrC, + nnzTotalDevHostPtr, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneDense2csrNnzByPercentage( + cusparseHandle_t handle, int m, int n, const double *A, int lda, + float percentage, const cusparseMatDescr_t descrC, int *csrRowPtrC, + int *nnzTotalDevHostPtr, pruneInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, int, float, + const cusparseMatDescr_t, int *, int *, pruneInfo_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseDpruneDense2csrNnzByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, percentage, descrC, csrRowPtrC, + nnzTotalDevHostPtr, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneDense2csrByPercentage( + cusparseHandle_t handle, int m, int n, const float *A, int lda, + float percentage, const cusparseMatDescr_t descrC, float *csrSortedValC, + const int *csrSortedRowPtrC, int *csrSortedColIndC, pruneInfo_t info, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const float *, int, float, + const cusparseMatDescr_t, float *, const int *, int *, pruneInfo_t, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseSpruneDense2csrByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneDense2csrByPercentage( + cusparseHandle_t handle, int m, int n, const double *A, int lda, + float percentage, const cusparseMatDescr_t descrC, double *csrSortedValC, + const int *csrSortedRowPtrC, int *csrSortedColIndC, pruneInfo_t info, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, const double *, int, float, + const cusparseMatDescr_t, double *, const int *, int *, pruneInfo_t, + void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseDpruneDense2csrByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, A, lda, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneCsr2csrByPercentage_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, float percentage, + const cusparseMatDescr_t descrC, const float *csrSortedValC, + const int *csrSortedRowPtrC, const int *csrSortedColIndC, pruneInfo_t info, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, float, const cusparseMatDescr_t, const float *, + const int *, const int *, pruneInfo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseSpruneCsr2csrByPercentage_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneCsr2csrByPercentage_bufferSizeExt( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, float percentage, + const cusparseMatDescr_t descrC, const double *csrSortedValC, + const int *csrSortedRowPtrC, const int *csrSortedColIndC, pruneInfo_t info, + size_t *pBufferSizeInBytes) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, float, const cusparseMatDescr_t, const double *, + const int *, const int *, pruneInfo_t, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseDpruneCsr2csrByPercentage_bufferSizeExt"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBufferSizeInBytes); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneCsr2csrNnzByPercentage( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, float percentage, + const cusparseMatDescr_t descrC, int *csrSortedRowPtrC, + int *nnzTotalDevHostPtr, pruneInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, float, const cusparseMatDescr_t, int *, int *, + pruneInfo_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseSpruneCsr2csrNnzByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, percentage, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneCsr2csrNnzByPercentage( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, float percentage, + const cusparseMatDescr_t descrC, int *csrSortedRowPtrC, + int *nnzTotalDevHostPtr, pruneInfo_t info, void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, float, const cusparseMatDescr_t, int *, int *, + pruneInfo_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseDpruneCsr2csrNnzByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, percentage, descrC, csrSortedRowPtrC, + nnzTotalDevHostPtr, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpruneCsr2csrByPercentage( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const float *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, float percentage, + const cusparseMatDescr_t descrC, float *csrSortedValC, + const int *csrSortedRowPtrC, int *csrSortedColIndC, pruneInfo_t info, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const float *, + const int *, const int *, float, const cusparseMatDescr_t, float *, + const int *, int *, pruneInfo_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseSpruneCsr2csrByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseDpruneCsr2csrByPercentage( + cusparseHandle_t handle, int m, int n, int nnzA, + const cusparseMatDescr_t descrA, const double *csrSortedValA, + const int *csrSortedRowPtrA, const int *csrSortedColIndA, float percentage, + const cusparseMatDescr_t descrC, double *csrSortedValC, + const int *csrSortedRowPtrC, int *csrSortedColIndC, pruneInfo_t info, + void *pBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const cusparseMatDescr_t, const double *, + const int *, const int *, float, const cusparseMatDescr_t, double *, + const int *, int *, pruneInfo_t, void *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseDpruneCsr2csrByPercentage"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnzA, descrA, csrSortedValA, csrSortedRowPtrA, + csrSortedColIndA, percentage, descrC, csrSortedValC, + csrSortedRowPtrC, csrSortedColIndC, info, pBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsr2cscEx2( + cusparseHandle_t handle, int m, int n, int nnz, const void *csrVal, + const int *csrRowPtr, const int *csrColInd, void *cscVal, int *cscColPtr, + int *cscRowInd, cudaDataType valType, cusparseAction_t copyValues, + cusparseIndexBase_t idxBase, cusparseCsr2CscAlg_t alg, void *buffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const void *, const int *, const int *, + void *, int *, int *, cudaDataType, cusparseAction_t, cusparseIndexBase_t, + cusparseCsr2CscAlg_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCsr2cscEx2"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, cscVal, + cscColPtr, cscRowInd, valType, copyValues, idxBase, alg, + buffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsr2cscEx2_bufferSize( + cusparseHandle_t handle, int m, int n, int nnz, const void *csrVal, + const int *csrRowPtr, const int *csrColInd, void *cscVal, int *cscColPtr, + int *cscRowInd, cudaDataType valType, cusparseAction_t copyValues, + cusparseIndexBase_t idxBase, cusparseCsr2CscAlg_t alg, size_t *bufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, int, int, int, const void *, const int *, const int *, + void *, int *, int *, cudaDataType, cusparseAction_t, cusparseIndexBase_t, + cusparseCsr2CscAlg_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCsr2cscEx2_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd, cscVal, + cscColPtr, cscRowInd, valType, copyValues, idxBase, alg, + bufferSize); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCreateSpVec(cusparseSpVecDescr_t *spVecDescr, int64_t size, int64_t nnz, + void *indices, void *values, cusparseIndexType_t idxType, + cusparseIndexBase_t idxBase, cudaDataType valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseSpVecDescr_t *, int64_t, int64_t, void *, void *, + cusparseIndexType_t, cusparseIndexBase_t, cudaDataType); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateSpVec"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spVecDescr, size, nnz, indices, values, idxType, idxBase, + valueType); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDestroySpVec(cusparseSpVecDescr_t spVecDescr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpVecDescr_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroySpVec"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spVecDescr); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpVecGet(cusparseSpVecDescr_t spVecDescr, + int64_t *size, int64_t *nnz, + void **indices, void **values, + cusparseIndexType_t *idxType, + cusparseIndexBase_t *idxBase, + cudaDataType *valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseSpVecDescr_t, int64_t *, int64_t *, void **, void **, + cusparseIndexType_t *, cusparseIndexBase_t *, cudaDataType *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpVecGet"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spVecDescr, size, nnz, indices, values, idxType, idxBase, + valueType); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpVecGetIndexBase( + cusparseSpVecDescr_t spVecDescr, cusparseIndexBase_t *idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpVecDescr_t, + cusparseIndexBase_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpVecGetIndexBase"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spVecDescr, idxBase); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpVecGetValues(cusparseSpVecDescr_t spVecDescr, void **values) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseSpVecDescr_t, void **); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpVecGetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spVecDescr, values); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpVecSetValues(cusparseSpVecDescr_t spVecDescr, void *values) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpVecDescr_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpVecSetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spVecDescr, values); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCreateDnVec(cusparseDnVecDescr_t *dnVecDescr, int64_t size, + void *values, cudaDataType valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseDnVecDescr_t *, int64_t, void *, cudaDataType); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateDnVec"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnVecDescr, size, values, valueType); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDestroyDnVec(cusparseDnVecDescr_t dnVecDescr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseDnVecDescr_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroyDnVec"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnVecDescr); +} + +cusparseStatus_t CUSPARSEAPI cusparseDnVecGet(cusparseDnVecDescr_t dnVecDescr, + int64_t *size, void **values, + cudaDataType *valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseDnVecDescr_t, int64_t *, void **, cudaDataType *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDnVecGet"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnVecDescr, size, values, valueType); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDnVecGetValues(cusparseDnVecDescr_t dnVecDescr, void **values) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseDnVecDescr_t, void **); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDnVecGetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnVecDescr, values); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDnVecSetValues(cusparseDnVecDescr_t dnVecDescr, void *values) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseDnVecDescr_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDnVecSetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnVecDescr, values); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDestroySpMat(cusparseSpMatDescr_t spMatDescr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpMatDescr_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroySpMat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpMatGetFormat( + cusparseSpMatDescr_t spMatDescr, cusparseFormat_t *format) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseSpMatDescr_t, cusparseFormat_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpMatGetFormat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, format); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpMatGetIndexBase( + cusparseSpMatDescr_t spMatDescr, cusparseIndexBase_t *idxBase) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpMatDescr_t, + cusparseIndexBase_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpMatGetIndexBase"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, idxBase); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpMatGetValues(cusparseSpMatDescr_t spMatDescr, void **values) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseSpMatDescr_t, void **); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpMatGetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, values); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpMatSetValues(cusparseSpMatDescr_t spMatDescr, void *values) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpMatDescr_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpMatSetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, values); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpMatGetSize(cusparseSpMatDescr_t spMatDescr, int64_t *rows, + int64_t *cols, int64_t *nnz) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseSpMatDescr_t, int64_t *, int64_t *, int64_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpMatGetSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, rows, cols, nnz); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpMatSetStridedBatch(cusparseSpMatDescr_t spMatDescr, int batchCount) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpMatDescr_t, int); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpMatSetStridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, batchCount); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpMatGetStridedBatch(cusparseSpMatDescr_t spMatDescr, int *batchCount) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpMatDescr_t, int *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpMatGetStridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, batchCount); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCsr( + cusparseSpMatDescr_t *spMatDescr, int64_t rows, int64_t cols, int64_t nnz, + void *csrRowOffsets, void *csrColInd, void *csrValues, + cusparseIndexType_t csrRowOffsetsType, cusparseIndexType_t csrColIndType, + cusparseIndexBase_t idxBase, cudaDataType valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseSpMatDescr_t *, int64_t, int64_t, int64_t, void *, void *, void *, + cusparseIndexType_t, cusparseIndexType_t, cusparseIndexBase_t, + cudaDataType); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateCsr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, rows, cols, nnz, csrRowOffsets, csrColInd, + csrValues, csrRowOffsetsType, csrColIndType, idxBase, + valueType); +} + +cusparseStatus_t CUSPARSEAPI cusparseCsrGet( + cusparseSpMatDescr_t spMatDescr, int64_t *rows, int64_t *cols, int64_t *nnz, + void **csrRowOffsets, void **csrColInd, void **csrValues, + cusparseIndexType_t *csrRowOffsetsType, cusparseIndexType_t *csrColIndType, + cusparseIndexBase_t *idxBase, cudaDataType *valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseSpMatDescr_t, int64_t *, int64_t *, int64_t *, void **, void **, + void **, cusparseIndexType_t *, cusparseIndexType_t *, + cusparseIndexBase_t *, cudaDataType *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCsrGet"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, rows, cols, nnz, csrRowOffsets, csrColInd, + csrValues, csrRowOffsetsType, csrColIndType, idxBase, + valueType); +} + +cusparseStatus_t CUSPARSEAPI +cusparseCsrSetPointers(cusparseSpMatDescr_t spMatDescr, void *csrRowOffsets, + void *csrColInd, void *csrValues) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpMatDescr_t, void *, + void *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCsrSetPointers"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, csrRowOffsets, csrColInd, csrValues); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCoo(cusparseSpMatDescr_t *spMatDescr, + int64_t rows, int64_t cols, + int64_t nnz, void *cooRowInd, + void *cooColInd, void *cooValues, + cusparseIndexType_t cooIdxType, + cusparseIndexBase_t idxBase, + cudaDataType valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseSpMatDescr_t *, int64_t, int64_t, int64_t, void *, void *, void *, + cusparseIndexType_t, cusparseIndexBase_t, cudaDataType); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateCoo"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, rows, cols, nnz, cooRowInd, cooColInd, cooValues, + cooIdxType, idxBase, valueType); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateCooAoS( + cusparseSpMatDescr_t *spMatDescr, int64_t rows, int64_t cols, int64_t nnz, + void *cooInd, void *cooValues, cusparseIndexType_t cooIdxType, + cusparseIndexBase_t idxBase, cudaDataType valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseSpMatDescr_t *, int64_t, int64_t, int64_t, void *, void *, + cusparseIndexType_t, cusparseIndexBase_t, cudaDataType); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateCooAoS"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, rows, cols, nnz, cooInd, cooValues, cooIdxType, + idxBase, valueType); +} + +cusparseStatus_t CUSPARSEAPI cusparseCooGet( + cusparseSpMatDescr_t spMatDescr, int64_t *rows, int64_t *cols, int64_t *nnz, + void **cooRowInd, // COO row indices + void **cooColInd, // COO column indices + void **cooValues, // COO values + cusparseIndexType_t *idxType, cusparseIndexBase_t *idxBase, + cudaDataType *valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseSpMatDescr_t, int64_t *, int64_t *, int64_t *, void **, void **, + void **, cusparseIndexType_t *, cusparseIndexBase_t *, cudaDataType *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCooGet"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, rows, cols, nnz, cooRowInd, cooColInd, cooValues, + idxType, idxBase, valueType); +} + +cusparseStatus_t CUSPARSEAPI cusparseCooAoSGet(cusparseSpMatDescr_t spMatDescr, + int64_t *rows, int64_t *cols, + int64_t *nnz, + void **cooInd, // COO indices + void **cooValues, // COO values + cusparseIndexType_t *idxType, + cusparseIndexBase_t *idxBase, + cudaDataType *valueType) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseSpMatDescr_t, int64_t *, int64_t *, int64_t *, void **, void **, + cusparseIndexType_t *, cusparseIndexBase_t *, cudaDataType *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCooAoSGet"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(spMatDescr, rows, cols, nnz, cooInd, cooValues, idxType, + idxBase, valueType); +} + +cusparseStatus_t CUSPARSEAPI cusparseCreateDnMat( + cusparseDnMatDescr_t *dnMatDescr, int64_t rows, int64_t cols, int64_t ld, + void *values, cudaDataType valueType, cusparseOrder_t order) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseDnMatDescr_t *, int64_t, int64_t, int64_t, void *, cudaDataType, + cusparseOrder_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseCreateDnMat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnMatDescr, rows, cols, ld, values, valueType, order); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDestroyDnMat(cusparseDnMatDescr_t dnMatDescr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseDnMatDescr_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDestroyDnMat"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnMatDescr); +} + +cusparseStatus_t CUSPARSEAPI cusparseDnMatGet(cusparseDnMatDescr_t dnMatDescr, + int64_t *rows, int64_t *cols, + int64_t *ld, void **values, + cudaDataType *type, + cusparseOrder_t *order) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseDnMatDescr_t, int64_t *, int64_t *, int64_t *, void **, + cudaDataType *, cusparseOrder_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDnMatGet"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnMatDescr, rows, cols, ld, values, type, order); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDnMatGetValues(cusparseDnMatDescr_t dnMatDescr, void **values) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseDnMatDescr_t, void **); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDnMatGetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnMatDescr, values); +} + +cusparseStatus_t CUSPARSEAPI +cusparseDnMatSetValues(cusparseDnMatDescr_t dnMatDescr, void *values) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseDnMatDescr_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDnMatSetValues"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnMatDescr, values); +} + +cusparseStatus_t CUSPARSEAPI cusparseDnMatSetStridedBatch( + cusparseDnMatDescr_t dnMatDescr, int batchCount, int64_t batchStride) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseDnMatDescr_t, int, int64_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDnMatSetStridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnMatDescr, batchCount, batchStride); +} + +cusparseStatus_t CUSPARSEAPI cusparseDnMatGetStridedBatch( + cusparseDnMatDescr_t dnMatDescr, int *batchCount, int64_t *batchStride) { + using FuncPtr = + cusparseStatus_t(CUSPARSEAPI *)(cusparseDnMatDescr_t, int *, int64_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseDnMatGetStridedBatch"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(dnMatDescr, batchCount, batchStride); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpVV_bufferSize( + cusparseHandle_t handle, cusparseOperation_t opX, cusparseSpVecDescr_t vecX, + cusparseDnVecDescr_t vecY, const void *result, cudaDataType computeType, + size_t *bufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseSpVecDescr_t, + cusparseDnVecDescr_t, const void *, cudaDataType, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpVV_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opX, vecX, vecY, result, computeType, bufferSize); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpVV(cusparseHandle_t handle, cusparseOperation_t opX, + cusparseSpVecDescr_t vecX, cusparseDnVecDescr_t vecY, void *result, + cudaDataType computeType, void *externalBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseSpVecDescr_t, + cusparseDnVecDescr_t, void *, cudaDataType, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpVV"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opX, vecX, vecY, result, computeType, externalBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpMV( + cusparseHandle_t handle, cusparseOperation_t opA, const void *alpha, + cusparseSpMatDescr_t matA, cusparseDnVecDescr_t vecX, const void *beta, + cusparseDnVecDescr_t vecY, cudaDataType computeType, cusparseSpMVAlg_t alg, + void *externalBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const void *, cusparseSpMatDescr_t, + cusparseDnVecDescr_t, const void *, cusparseDnVecDescr_t, cudaDataType, + cusparseSpMVAlg_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpMV"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opA, alpha, matA, vecX, beta, vecY, computeType, alg, + externalBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpMV_bufferSize( + cusparseHandle_t handle, cusparseOperation_t opA, const void *alpha, + cusparseSpMatDescr_t matA, cusparseDnVecDescr_t vecX, const void *beta, + cusparseDnVecDescr_t vecY, cudaDataType computeType, cusparseSpMVAlg_t alg, + size_t *bufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, const void *, cusparseSpMatDescr_t, + cusparseDnVecDescr_t, const void *, cusparseDnVecDescr_t, cudaDataType, + cusparseSpMVAlg_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpMV_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opA, alpha, matA, vecX, beta, vecY, computeType, alg, + bufferSize); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpMM( + cusparseHandle_t handle, cusparseOperation_t opA, cusparseOperation_t opB, + const void *alpha, cusparseSpMatDescr_t matA, cusparseDnMatDescr_t matB, + const void *beta, cusparseDnMatDescr_t matC, cudaDataType computeType, + cusparseSpMMAlg_t alg, void *externalBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, const void *, + cusparseSpMatDescr_t, cusparseDnMatDescr_t, const void *, + cusparseDnMatDescr_t, cudaDataType, cusparseSpMMAlg_t, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpMM"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opA, opB, alpha, matA, matB, beta, matC, computeType, + alg, externalBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpMM_bufferSize( + cusparseHandle_t handle, cusparseOperation_t opA, cusparseOperation_t opB, + const void *alpha, cusparseSpMatDescr_t matA, cusparseDnMatDescr_t matB, + const void *beta, cusparseDnMatDescr_t matC, cudaDataType computeType, + cusparseSpMMAlg_t alg, size_t *bufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, const void *, + cusparseSpMatDescr_t, cusparseDnMatDescr_t, const void *, + cusparseDnMatDescr_t, cudaDataType, cusparseSpMMAlg_t, size_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpMM_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opA, opB, alpha, matA, matB, beta, matC, computeType, + alg, bufferSize); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpGEMM_createDescr(cusparseSpGEMMDescr_t *descr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpGEMMDescr_t *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpGEMM_createDescr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(descr); +} + +cusparseStatus_t CUSPARSEAPI +cusparseSpGEMM_destroyDescr(cusparseSpGEMMDescr_t descr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)(cusparseSpGEMMDescr_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpGEMM_destroyDescr"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(descr); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpGEMM_workEstimation( + cusparseHandle_t handle, cusparseOperation_t opA, cusparseOperation_t opB, + const void *alpha, cusparseSpMatDescr_t matA, cusparseSpMatDescr_t matB, + const void *beta, cusparseSpMatDescr_t matC, cudaDataType computeType, + cusparseSpGEMMAlg_t alg, cusparseSpGEMMDescr_t spgemmDescr, + size_t *bufferSize1, void *externalBuffer1) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, const void *, + cusparseSpMatDescr_t, cusparseSpMatDescr_t, const void *, + cusparseSpMatDescr_t, cudaDataType, cusparseSpGEMMAlg_t, + cusparseSpGEMMDescr_t, size_t *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpGEMM_workEstimation"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opA, opB, alpha, matA, matB, beta, matC, computeType, + alg, spgemmDescr, bufferSize1, externalBuffer1); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpGEMM_compute( + cusparseHandle_t handle, cusparseOperation_t opA, cusparseOperation_t opB, + const void *alpha, cusparseSpMatDescr_t matA, cusparseSpMatDescr_t matB, + const void *beta, cusparseSpMatDescr_t matC, cudaDataType computeType, + cusparseSpGEMMAlg_t alg, cusparseSpGEMMDescr_t spgemmDescr, + size_t *bufferSize2, void *externalBuffer2) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, const void *, + cusparseSpMatDescr_t, cusparseSpMatDescr_t, const void *, + cusparseSpMatDescr_t, cudaDataType, cusparseSpGEMMAlg_t, + cusparseSpGEMMDescr_t, size_t *, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpGEMM_compute"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opA, opB, alpha, matA, matB, beta, matC, computeType, + alg, spgemmDescr, bufferSize2, externalBuffer2); +} + +cusparseStatus_t CUSPARSEAPI cusparseSpGEMM_copy( + cusparseHandle_t handle, cusparseOperation_t opA, cusparseOperation_t opB, + const void *alpha, cusparseSpMatDescr_t matA, cusparseSpMatDescr_t matB, + const void *beta, cusparseSpMatDescr_t matC, cudaDataType computeType, + cusparseSpGEMMAlg_t alg, cusparseSpGEMMDescr_t spgemmDescr) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, const void *, + cusparseSpMatDescr_t, cusparseSpMatDescr_t, const void *, + cusparseSpMatDescr_t, cudaDataType, cusparseSpGEMMAlg_t, + cusparseSpGEMMDescr_t); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseSpGEMM_copy"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opA, opB, alpha, matA, matB, beta, matC, computeType, + alg, spgemmDescr); +} + +cusparseStatus_t CUSPARSEAPI cusparseConstrainedGeMM( + cusparseHandle_t handle, cusparseOperation_t opA, cusparseOperation_t opB, + const void *alpha, cusparseDnMatDescr_t matA, cusparseDnMatDescr_t matB, + const void *beta, cusparseSpMatDescr_t matC, cudaDataType computeType, + void *externalBuffer) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, const void *, + cusparseDnMatDescr_t, cusparseDnMatDescr_t, const void *, + cusparseSpMatDescr_t, cudaDataType, void *); + static auto func_ptr = LoadSymbol<FuncPtr>("cusparseConstrainedGeMM"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opA, opB, alpha, matA, matB, beta, matC, computeType, + externalBuffer); +} + +cusparseStatus_t CUSPARSEAPI cusparseConstrainedGeMM_bufferSize( + cusparseHandle_t handle, cusparseOperation_t opA, cusparseOperation_t opB, + const void *alpha, cusparseDnMatDescr_t matA, cusparseDnMatDescr_t matB, + const void *beta, cusparseSpMatDescr_t matC, cudaDataType computeType, + size_t *bufferSize) { + using FuncPtr = cusparseStatus_t(CUSPARSEAPI *)( + cusparseHandle_t, cusparseOperation_t, cusparseOperation_t, const void *, + cusparseDnMatDescr_t, cusparseDnMatDescr_t, const void *, + cusparseSpMatDescr_t, cudaDataType, size_t *); + static auto func_ptr = + LoadSymbol<FuncPtr>("cusparseConstrainedGeMM_bufferSize"); + if (!func_ptr) return GetSymbolNotFoundError(); + return func_ptr(handle, opA, opB, alpha, matA, matB, beta, matC, computeType, + bufferSize); +} + +} // extern "C" diff --git a/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc b/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc index e22a243a70b..fd3b5f19913 100644 --- a/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc +++ b/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc @@ -132,6 +132,11 @@ bool GpuExecutor::UnloadGpuBinary(const void* gpu_binary) { VLOG(3) << "Unloading HSACO module " << module; GpuDriver::UnloadModule(context_, module); gpu_binary_to_module_.erase(module_it); + const char* mem_it = nullptr; + for (auto x : in_memory_modules_) { + if (x.second == module) mem_it = x.first; + } + if (mem_it != nullptr) in_memory_modules_.erase(mem_it); } return true; } diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index f56330b428a..9a780839be3 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -48,6 +48,7 @@ load( "//third_party/mkl_dnn:build_defs.bzl", "if_mkl_open_source_only", "if_mkl_v1_open_source_only", + "if_mkldnn_threadpool", ) load( "//third_party/ngraph:build_defs.bzl", @@ -327,6 +328,11 @@ def tf_copts( if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_mkl_v1_open_source_only(["-DENABLE_MKLDNN_V1"]) + + if_mkldnn_threadpool([ + "-DENABLE_MKLDNN_THREADPOOL", + "-DENABLE_MKLDNN_V1", + "-DINTEL_MKL_DNN_ONLY", + ]) + if_enable_mkl(["-DENABLE_MKL"]) + if_ngraph(["-DINTEL_NGRAPH=1"]) + if_android_arm(["-mfpu=neon"]) + @@ -348,7 +354,9 @@ def tf_copts( ) def tf_openmp_copts(): - return if_mkl_lnx_x64(["-fopenmp"]) + # TODO(intel-mkl): Remove -fopenmp for threadpool after removing all + # omp pragmas in tensorflow/core. + return if_mkl_lnx_x64(["-fopenmp"]) + if_mkldnn_threadpool(["-fopenmp"]) def tfe_xla_copts(): return select({ @@ -615,6 +623,9 @@ def tf_cc_shared_object( linkshared = 1, data = data + data_extra, linkopts = linkopts + _rpath_linkopts(name_os_full) + select({ + clean_dep("//tensorflow:ios"): [ + "-Wl,-install_name,@rpath/" + soname, + ], clean_dep("//tensorflow:macos"): [ "-Wl,-install_name,@rpath/" + soname, ], @@ -863,7 +874,7 @@ def tf_gen_op_wrappers_cc( clean_dep("//tensorflow/core:ops"), clean_dep("//tensorflow/core:protos_all_cc"), ]) + if_android([ - clean_dep("//tensorflow/core:android_tensorflow_lib"), + clean_dep("//tensorflow/core:portable_tensorflow_lib"), ]), copts = tf_copts(), alwayslink = 1, @@ -880,7 +891,7 @@ def tf_gen_op_wrappers_cc( clean_dep("//tensorflow/core:ops"), clean_dep("//tensorflow/core:protos_all_cc"), ]) + if_android([ - clean_dep("//tensorflow/core:android_tensorflow_lib"), + clean_dep("//tensorflow/core:portable_tensorflow_lib"), ]), copts = tf_copts(), alwayslink = 1, @@ -2207,6 +2218,15 @@ def tf_py_test( xla_enabled = False, grpc_enabled = False, tfrt_enabled = False, + # `tfrt_enabled` is set for some test targets, and if we enable + # TFRT tests just by that, this will enable TFRT builds for open source. + # TFRT open source is not fully integrated yet so we need a temporary + # workaround to enable TFRT only for internal builds. `tfrt_enabled_internal` + # will be set by `tensorflow.google.bzl`'s `tf_py_test` target, which is + # only applied for internal builds. + # TODO(b/156911178): Revert this temporary workaround once TFRT open source + # is fully integrated with TF. + tfrt_enabled_internal = False, **kwargs): """Create one or more python tests with extra tensorflow dependencies.""" xla_test_true_list = [] @@ -2250,7 +2270,7 @@ def tf_py_test( deps = depset(deps + xla_test_true_list), **kwargs ) - if tfrt_enabled: + if tfrt_enabled_internal: py_test( name = name + "_tfrt", size = size, @@ -2846,7 +2866,7 @@ def if_mlir(if_true, if_false = []): "//conditions:default": if_false, }) -def tfcompile_extra_flags(): +def tfcompile_target_cpu(): return "" def tf_external_workspace_visible(visibility): diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.experimental.-embedding-config-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.experimental.-embedding-config-spec.pbtxt index 46d0362a705..355c57269fd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.experimental.-embedding-config-spec.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.tpu.experimental.-embedding-config-spec.pbtxt @@ -35,6 +35,10 @@ tf_class { name: "table_to_config_dict" mtype: "<type \'property\'>" } + member { + name: "tensor_core_feature_columns" + mtype: "<type \'property\'>" + } member_method { name: "__init__" } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt index 41483f2b83d..e2bef6beaaa 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-sequence-features.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.experimental.SequenceFeatures" tf_class { is_instance: "<class \'tensorflow.python.keras.feature_column.sequence_feature_column.SequenceFeatures\'>" - is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2._BaseFeaturesLayer\'>" + is_instance: "<class \'tensorflow.python.keras.feature_column.base_feature_layer._BaseFeaturesLayer\'>" is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" is_instance: "<class \'tensorflow.python.module.module.Module\'>" is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt index ecda1603325..7ed6c7747a7 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.DenseFeatures" tf_class { - is_instance: "<class \'tensorflow.python.feature_column.dense_features.DenseFeatures\'>" - is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2._BaseFeaturesLayer\'>" + is_instance: "<class \'tensorflow.python.keras.feature_column.dense_features.DenseFeatures\'>" + is_instance: "<class \'tensorflow.python.keras.feature_column.base_feature_layer._BaseFeaturesLayer\'>" is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" is_instance: "<class \'tensorflow.python.module.module.Module\'>" is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt index 0407188ab6b..6cfcbf73e5d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.layers.experimental.preprocessing.CategoryCrossing" tf_class { - is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.categorical_crossing.CategoryCrossing\'>" + is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.category_crossing.CategoryCrossing\'>" is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" is_instance: "<class \'tensorflow.python.module.module.Module\'>" is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>" @@ -113,7 +113,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'depth\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.__metaclass__.pbtxt new file mode 100644 index 00000000000..e907d9a293b --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.keras.layers.experimental.preprocessing.CategoryEncoding.__metaclass__" +tf_class { + is_instance: "<type \'type\'>" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt new file mode 100644 index 00000000000..165a6de49a8 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt @@ -0,0 +1,234 @@ +path: "tensorflow.keras.layers.experimental.preprocessing.CategoryEncoding" +tf_class { + is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.category_encoding_v1.CategoryEncoding\'>" + is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.category_encoding.CategoryEncoding\'>" + is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer_v1.CombinerPreprocessingLayer\'>" + is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>" + is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>" + is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" + is_instance: "<class \'tensorflow.python.module.module.Module\'>" + is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>" + is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" + is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>" + is_instance: "<type \'object\'>" + member { + name: "activity_regularizer" + mtype: "<type \'property\'>" + } + member { + name: "dtype" + mtype: "<type \'property\'>" + } + member { + name: "dynamic" + mtype: "<type \'property\'>" + } + member { + name: "inbound_nodes" + mtype: "<type \'property\'>" + } + member { + name: "input" + mtype: "<type \'property\'>" + } + member { + name: "input_mask" + mtype: "<type \'property\'>" + } + member { + name: "input_shape" + mtype: "<type \'property\'>" + } + member { + name: "input_spec" + mtype: "<type \'property\'>" + } + member { + name: "losses" + mtype: "<type \'property\'>" + } + member { + name: "metrics" + mtype: "<type \'property\'>" + } + member { + name: "name" + mtype: "<type \'property\'>" + } + member { + name: "name_scope" + mtype: "<type \'property\'>" + } + member { + name: "non_trainable_variables" + mtype: "<type \'property\'>" + } + member { + name: "non_trainable_weights" + mtype: "<type \'property\'>" + } + member { + name: "outbound_nodes" + mtype: "<type \'property\'>" + } + member { + name: "output" + mtype: "<type \'property\'>" + } + member { + name: "output_mask" + mtype: "<type \'property\'>" + } + member { + name: "output_shape" + mtype: "<type \'property\'>" + } + member { + name: "stateful" + mtype: "<type \'property\'>" + } + member { + name: "submodules" + mtype: "<type \'property\'>" + } + member { + name: "trainable" + mtype: "<type \'property\'>" + } + member { + name: "trainable_variables" + mtype: "<type \'property\'>" + } + member { + name: "trainable_weights" + mtype: "<type \'property\'>" + } + member { + name: "updates" + mtype: "<type \'property\'>" + } + member { + name: "variables" + mtype: "<type \'property\'>" + } + member { + name: "weights" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'max_tokens\', \'output_mode\', \'sparse\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'count\', \'False\'], " + } + member_method { + name: "adapt" + argspec: "args=[\'self\', \'data\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'True\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_metric" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_signature" + argspec: "args=[\'self\', \'input_spec\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_num_elements" + argspec: "args=[\'self\', \'num_elements\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_tfidf_data" + argspec: "args=[\'self\', \'tfidf_data\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_name_scope" + argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt new file mode 100644 index 00000000000..e4a5619058d --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt @@ -0,0 +1,218 @@ +path: "tensorflow.keras.layers.experimental.preprocessing.Hashing" +tf_class { + is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.hashing.Hashing\'>" + is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" + is_instance: "<class \'tensorflow.python.module.module.Module\'>" + is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>" + is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" + is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>" + is_instance: "<type \'object\'>" + member { + name: "activity_regularizer" + mtype: "<type \'property\'>" + } + member { + name: "dtype" + mtype: "<type \'property\'>" + } + member { + name: "dynamic" + mtype: "<type \'property\'>" + } + member { + name: "inbound_nodes" + mtype: "<type \'property\'>" + } + member { + name: "input" + mtype: "<type \'property\'>" + } + member { + name: "input_mask" + mtype: "<type \'property\'>" + } + member { + name: "input_shape" + mtype: "<type \'property\'>" + } + member { + name: "input_spec" + mtype: "<type \'property\'>" + } + member { + name: "losses" + mtype: "<type \'property\'>" + } + member { + name: "metrics" + mtype: "<type \'property\'>" + } + member { + name: "name" + mtype: "<type \'property\'>" + } + member { + name: "name_scope" + mtype: "<type \'property\'>" + } + member { + name: "non_trainable_variables" + mtype: "<type \'property\'>" + } + member { + name: "non_trainable_weights" + mtype: "<type \'property\'>" + } + member { + name: "outbound_nodes" + mtype: "<type \'property\'>" + } + member { + name: "output" + mtype: "<type \'property\'>" + } + member { + name: "output_mask" + mtype: "<type \'property\'>" + } + member { + name: "output_shape" + mtype: "<type \'property\'>" + } + member { + name: "stateful" + mtype: "<type \'property\'>" + } + member { + name: "submodules" + mtype: "<type \'property\'>" + } + member { + name: "trainable" + mtype: "<type \'property\'>" + } + member { + name: "trainable_variables" + mtype: "<type \'property\'>" + } + member { + name: "trainable_weights" + mtype: "<type \'property\'>" + } + member { + name: "updates" + mtype: "<type \'property\'>" + } + member { + name: "variables" + mtype: "<type \'property\'>" + } + member { + name: "weights" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_bins\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_metric" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_signature" + argspec: "args=[\'self\', \'input_spec\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_name_scope" + argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt index 7036fb926a8..60c0bc92f81 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt @@ -113,7 +113,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'scale\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + argspec: "args=[\'self\', \'scale\', \'offset\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.0\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.pbtxt index 0964922ea26..a922b143910 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.pbtxt @@ -4,10 +4,18 @@ tf_module { name: "CategoryCrossing" mtype: "<type \'type\'>" } + member { + name: "CategoryEncoding" + mtype: "<type \'type\'>" + } member { name: "CenterCrop" mtype: "<type \'type\'>" } + member { + name: "Hashing" + mtype: "<type \'type\'>" + } member { name: "Normalization" mtype: "<type \'type\'>" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 44fb74ac63a..37a95cc88d1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -736,6 +736,10 @@ tf_module { name: "ComplexAbs" argspec: "args=[\'x\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], " } + member_method { + name: "CompressElement" + argspec: "args=[\'components\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "ComputeAccidentalHits" argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], " @@ -1472,6 +1476,10 @@ tf_module { name: "ExtractGlimpse" argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'uniform\', \'None\'], " } + member_method { + name: "ExtractGlimpseV2" + argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'uniform\', \'None\'], " + } member_method { name: "ExtractImagePatches" argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -4100,6 +4108,14 @@ tf_module { name: "SparseCross" argspec: "args=[\'indices\', \'values\', \'shapes\', \'dense_inputs\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_type\', \'internal_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "SparseCrossHashed" + argspec: "args=[\'indices\', \'values\', \'shapes\', \'dense_inputs\', \'num_buckets\', \'strong_hash\', \'salt\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseCrossV2" + argspec: "args=[\'indices\', \'values\', \'shapes\', \'dense_inputs\', \'sep\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "SparseDenseCwiseAdd" argspec: "args=[\'sp_indices\', \'sp_values\', \'sp_shape\', \'dense\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -4590,7 +4606,7 @@ tf_module { } member_method { name: "TPUReplicatedInput" - argspec: "args=[\'inputs\', \'is_mirrored_variable\', \'index\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\', \'None\'], " + argspec: "args=[\'inputs\', \'is_mirrored_variable\', \'index\', \'is_packed\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\', \'False\', \'None\'], " } member_method { name: "TPUReplicatedOutput" @@ -4948,6 +4964,10 @@ tf_module { name: "UnbatchGrad" argspec: "args=[\'original_input\', \'batch_index\', \'grad\', \'id\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], " } + member_method { + name: "UncompressElement" + argspec: "args=[\'compressed\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "UnicodeDecode" argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \"<dtype: \'int64\'>\", \'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt index f8f8edb26a8..9550418c2a6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt @@ -22,7 +22,7 @@ tf_module { } member_method { name: "cross" - argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'inputs\', \'name\', \'separator\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "cross_hashed" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adagrad.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adagrad.pbtxt new file mode 100644 index 00000000000..e2c6bbd43d9 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adagrad.pbtxt @@ -0,0 +1,10 @@ +path: "tensorflow.tpu.experimental.embedding.Adagrad" +tf_class { + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils.Adagrad\'>" + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils._Optimizer\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator_value\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'slot_variable_creation_fn\'], varargs=None, keywords=None, defaults=[\'0.001\', \'0.1\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adam.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adam.pbtxt new file mode 100644 index 00000000000..941e81acbbb --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-adam.pbtxt @@ -0,0 +1,10 @@ +path: "tensorflow.tpu.experimental.embedding.Adam" +tf_class { + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils.Adam\'>" + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils._Optimizer\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'lazy_adam\', \'sum_inside_sqrt\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'slot_variable_creation_fn\'], varargs=None, keywords=None, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'True\', \'True\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-feature-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-feature-config.pbtxt new file mode 100644 index 00000000000..b2c31d00ad8 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-feature-config.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.tpu.experimental.embedding.FeatureConfig" +tf_class { + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'table\', \'max_sequence_length\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-s-g-d.pbtxt new file mode 100644 index 00000000000..9a3f47406b8 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-s-g-d.pbtxt @@ -0,0 +1,10 @@ +path: "tensorflow.tpu.experimental.embedding.SGD" +tf_class { + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils.SGD\'>" + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils._Optimizer\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'0.01\', \'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-t-p-u-embedding.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-t-p-u-embedding.pbtxt new file mode 100644 index 00000000000..9cc8354b4bf --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-t-p-u-embedding.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.tpu.experimental.embedding.TPUEmbedding" +tf_class { + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2.TPUEmbedding\'>" + is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>" + is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" + is_instance: "<type \'object\'>" + member { + name: "embedding_tables" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'feature_config\', \'batch_size\', \'optimizer\', \'pipeline_execution_with_tensor_core\', \'initialize_tpu_embedding\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'gradients\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue" + argspec: "args=[\'self\', \'features\', \'weights\', \'training\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-table-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-table-config.pbtxt new file mode 100644 index 00000000000..6be35ed6fb6 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-table-config.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.tpu.experimental.embedding.TableConfig" +tf_class { + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils.TableConfig\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'vocabulary_size\', \'dim\', \'initializer\', \'optimizer\', \'combiner\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'mean\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.pbtxt new file mode 100644 index 00000000000..9d4f24f4edd --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.tpu.experimental.embedding" +tf_module { + member { + name: "Adagrad" + mtype: "<type \'type\'>" + } + member { + name: "Adam" + mtype: "<type \'type\'>" + } + member { + name: "FeatureConfig" + mtype: "<type \'type\'>" + } + member { + name: "SGD" + mtype: "<type \'type\'>" + } + member { + name: "TPUEmbedding" + mtype: "<type \'type\'>" + } + member { + name: "TableConfig" + mtype: "<type \'type\'>" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.pbtxt index ef1c8078cca..f9925518a1a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.pbtxt @@ -28,6 +28,10 @@ tf_module { name: "Topology" mtype: "<type \'type\'>" } + member { + name: "embedding" + mtype: "<type \'module\'>" + } member_method { name: "embedding_column" argspec: "args=[\'categorical_column\', \'dimension\', \'combiner\', \'initializer\', \'max_sequence_length\', \'learning_rate_fn\', \'embedding_lookup_device\', \'tensor_core_shape\', \'use_safe_embedding_lookup\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\', \'0\', \'None\', \'None\', \'None\', \'True\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt index 41483f2b83d..e2bef6beaaa 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-sequence-features.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.experimental.SequenceFeatures" tf_class { is_instance: "<class \'tensorflow.python.keras.feature_column.sequence_feature_column.SequenceFeatures\'>" - is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2._BaseFeaturesLayer\'>" + is_instance: "<class \'tensorflow.python.keras.feature_column.base_feature_layer._BaseFeaturesLayer\'>" is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" is_instance: "<class \'tensorflow.python.module.module.Module\'>" is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt index f7137f0d09b..3b4eb863387 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt @@ -1,8 +1,8 @@ path: "tensorflow.keras.layers.DenseFeatures" tf_class { - is_instance: "<class \'tensorflow.python.feature_column.dense_features_v2.DenseFeatures\'>" - is_instance: "<class \'tensorflow.python.feature_column.dense_features.DenseFeatures\'>" - is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2._BaseFeaturesLayer\'>" + is_instance: "<class \'tensorflow.python.keras.feature_column.dense_features_v2.DenseFeatures\'>" + is_instance: "<class \'tensorflow.python.keras.feature_column.dense_features.DenseFeatures\'>" + is_instance: "<class \'tensorflow.python.keras.feature_column.base_feature_layer._BaseFeaturesLayer\'>" is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" is_instance: "<class \'tensorflow.python.module.module.Module\'>" is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt index 0407188ab6b..6cfcbf73e5d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-crossing.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.layers.experimental.preprocessing.CategoryCrossing" tf_class { - is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.categorical_crossing.CategoryCrossing\'>" + is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.category_crossing.CategoryCrossing\'>" is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" is_instance: "<class \'tensorflow.python.module.module.Module\'>" is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>" @@ -113,7 +113,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'depth\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'depth\', \'name\', \'separator\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.__metaclass__.pbtxt new file mode 100644 index 00000000000..e907d9a293b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.__metaclass__.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.keras.layers.experimental.preprocessing.CategoryEncoding.__metaclass__" +tf_class { + is_instance: "<type \'type\'>" + member_method { + name: "__init__" + } + member_method { + name: "mro" + } + member_method { + name: "register" + argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt new file mode 100644 index 00000000000..2edcfbb6487 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-category-encoding.pbtxt @@ -0,0 +1,232 @@ +path: "tensorflow.keras.layers.experimental.preprocessing.CategoryEncoding" +tf_class { + is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.category_encoding.CategoryEncoding\'>" + is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>" + is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>" + is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" + is_instance: "<class \'tensorflow.python.module.module.Module\'>" + is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>" + is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" + is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>" + is_instance: "<type \'object\'>" + member { + name: "activity_regularizer" + mtype: "<type \'property\'>" + } + member { + name: "dtype" + mtype: "<type \'property\'>" + } + member { + name: "dynamic" + mtype: "<type \'property\'>" + } + member { + name: "inbound_nodes" + mtype: "<type \'property\'>" + } + member { + name: "input" + mtype: "<type \'property\'>" + } + member { + name: "input_mask" + mtype: "<type \'property\'>" + } + member { + name: "input_shape" + mtype: "<type \'property\'>" + } + member { + name: "input_spec" + mtype: "<type \'property\'>" + } + member { + name: "losses" + mtype: "<type \'property\'>" + } + member { + name: "metrics" + mtype: "<type \'property\'>" + } + member { + name: "name" + mtype: "<type \'property\'>" + } + member { + name: "name_scope" + mtype: "<type \'property\'>" + } + member { + name: "non_trainable_variables" + mtype: "<type \'property\'>" + } + member { + name: "non_trainable_weights" + mtype: "<type \'property\'>" + } + member { + name: "outbound_nodes" + mtype: "<type \'property\'>" + } + member { + name: "output" + mtype: "<type \'property\'>" + } + member { + name: "output_mask" + mtype: "<type \'property\'>" + } + member { + name: "output_shape" + mtype: "<type \'property\'>" + } + member { + name: "stateful" + mtype: "<type \'property\'>" + } + member { + name: "submodules" + mtype: "<type \'property\'>" + } + member { + name: "trainable" + mtype: "<type \'property\'>" + } + member { + name: "trainable_variables" + mtype: "<type \'property\'>" + } + member { + name: "trainable_weights" + mtype: "<type \'property\'>" + } + member { + name: "updates" + mtype: "<type \'property\'>" + } + member { + name: "variables" + mtype: "<type \'property\'>" + } + member { + name: "weights" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'max_tokens\', \'output_mode\', \'sparse\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'count\', \'False\'], " + } + member_method { + name: "adapt" + argspec: "args=[\'self\', \'data\', \'reset_state\'], varargs=None, keywords=None, defaults=[\'True\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_metric" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_signature" + argspec: "args=[\'self\', \'input_spec\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_num_elements" + argspec: "args=[\'self\', \'num_elements\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_tfidf_data" + argspec: "args=[\'self\', \'tfidf_data\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_name_scope" + argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt new file mode 100644 index 00000000000..e4a5619058d --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-hashing.pbtxt @@ -0,0 +1,218 @@ +path: "tensorflow.keras.layers.experimental.preprocessing.Hashing" +tf_class { + is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.hashing.Hashing\'>" + is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>" + is_instance: "<class \'tensorflow.python.module.module.Module\'>" + is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>" + is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" + is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>" + is_instance: "<type \'object\'>" + member { + name: "activity_regularizer" + mtype: "<type \'property\'>" + } + member { + name: "dtype" + mtype: "<type \'property\'>" + } + member { + name: "dynamic" + mtype: "<type \'property\'>" + } + member { + name: "inbound_nodes" + mtype: "<type \'property\'>" + } + member { + name: "input" + mtype: "<type \'property\'>" + } + member { + name: "input_mask" + mtype: "<type \'property\'>" + } + member { + name: "input_shape" + mtype: "<type \'property\'>" + } + member { + name: "input_spec" + mtype: "<type \'property\'>" + } + member { + name: "losses" + mtype: "<type \'property\'>" + } + member { + name: "metrics" + mtype: "<type \'property\'>" + } + member { + name: "name" + mtype: "<type \'property\'>" + } + member { + name: "name_scope" + mtype: "<type \'property\'>" + } + member { + name: "non_trainable_variables" + mtype: "<type \'property\'>" + } + member { + name: "non_trainable_weights" + mtype: "<type \'property\'>" + } + member { + name: "outbound_nodes" + mtype: "<type \'property\'>" + } + member { + name: "output" + mtype: "<type \'property\'>" + } + member { + name: "output_mask" + mtype: "<type \'property\'>" + } + member { + name: "output_shape" + mtype: "<type \'property\'>" + } + member { + name: "stateful" + mtype: "<type \'property\'>" + } + member { + name: "submodules" + mtype: "<type \'property\'>" + } + member { + name: "trainable" + mtype: "<type \'property\'>" + } + member { + name: "trainable_variables" + mtype: "<type \'property\'>" + } + member { + name: "trainable_weights" + mtype: "<type \'property\'>" + } + member { + name: "updates" + mtype: "<type \'property\'>" + } + member { + name: "variables" + mtype: "<type \'property\'>" + } + member { + name: "weights" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_bins\', \'salt\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_metric" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_signature" + argspec: "args=[\'self\', \'input_spec\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_name_scope" + argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt index 7036fb926a8..60c0bc92f81 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-rescaling.pbtxt @@ -113,7 +113,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'scale\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + argspec: "args=[\'self\', \'scale\', \'offset\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'0.0\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.pbtxt index 0964922ea26..a922b143910 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.pbtxt @@ -4,10 +4,18 @@ tf_module { name: "CategoryCrossing" mtype: "<type \'type\'>" } + member { + name: "CategoryEncoding" + mtype: "<type \'type\'>" + } member { name: "CenterCrop" mtype: "<type \'type\'>" } + member { + name: "Hashing" + mtype: "<type \'type\'>" + } member { name: "Normalization" mtype: "<type \'type\'>" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt index 227366f5f98..2ea4e8f84a6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt @@ -82,7 +82,7 @@ tf_module { } member_method { name: "bincount" - argspec: "args=[\'arr\', \'weights\', \'minlength\', \'maxlength\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"<dtype: \'int32\'>\", \'None\'], " + argspec: "args=[\'arr\', \'weights\', \'minlength\', \'maxlength\', \'dtype\', \'name\', \'axis\', \'binary_output\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"<dtype: \'int32\'>\", \'None\', \'None\', \'False\'], " } member_method { name: "ceil" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 44fb74ac63a..37a95cc88d1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -736,6 +736,10 @@ tf_module { name: "ComplexAbs" argspec: "args=[\'x\', \'Tout\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], " } + member_method { + name: "CompressElement" + argspec: "args=[\'components\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "ComputeAccidentalHits" argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], " @@ -1472,6 +1476,10 @@ tf_module { name: "ExtractGlimpse" argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'uniform\', \'None\'], " } + member_method { + name: "ExtractGlimpseV2" + argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'uniform\', \'None\'], " + } member_method { name: "ExtractImagePatches" argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -4100,6 +4108,14 @@ tf_module { name: "SparseCross" argspec: "args=[\'indices\', \'values\', \'shapes\', \'dense_inputs\', \'hashed_output\', \'num_buckets\', \'hash_key\', \'out_type\', \'internal_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "SparseCrossHashed" + argspec: "args=[\'indices\', \'values\', \'shapes\', \'dense_inputs\', \'num_buckets\', \'strong_hash\', \'salt\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "SparseCrossV2" + argspec: "args=[\'indices\', \'values\', \'shapes\', \'dense_inputs\', \'sep\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "SparseDenseCwiseAdd" argspec: "args=[\'sp_indices\', \'sp_values\', \'sp_shape\', \'dense\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -4590,7 +4606,7 @@ tf_module { } member_method { name: "TPUReplicatedInput" - argspec: "args=[\'inputs\', \'is_mirrored_variable\', \'index\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\', \'None\'], " + argspec: "args=[\'inputs\', \'is_mirrored_variable\', \'index\', \'is_packed\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\', \'False\', \'None\'], " } member_method { name: "TPUReplicatedOutput" @@ -4948,6 +4964,10 @@ tf_module { name: "UnbatchGrad" argspec: "args=[\'original_input\', \'batch_index\', \'grad\', \'id\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], " } + member_method { + name: "UncompressElement" + argspec: "args=[\'compressed\', \'output_types\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "UnicodeDecode" argspec: "args=[\'input\', \'input_encoding\', \'errors\', \'replacement_char\', \'replace_control_characters\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'replace\', \'65533\', \'False\', \"<dtype: \'int64\'>\", \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt index 67235bb2cf2..0028b7d8953 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt @@ -18,7 +18,7 @@ tf_module { } member_method { name: "cross" - argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'inputs\', \'name\', \'separator\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "cross_hashed" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adagrad.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adagrad.pbtxt new file mode 100644 index 00000000000..e2c6bbd43d9 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adagrad.pbtxt @@ -0,0 +1,10 @@ +path: "tensorflow.tpu.experimental.embedding.Adagrad" +tf_class { + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils.Adagrad\'>" + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils._Optimizer\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator_value\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'slot_variable_creation_fn\'], varargs=None, keywords=None, defaults=[\'0.001\', \'0.1\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adam.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adam.pbtxt new file mode 100644 index 00000000000..941e81acbbb --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-adam.pbtxt @@ -0,0 +1,10 @@ +path: "tensorflow.tpu.experimental.embedding.Adam" +tf_class { + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils.Adam\'>" + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils._Optimizer\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'beta_1\', \'beta_2\', \'epsilon\', \'lazy_adam\', \'sum_inside_sqrt\', \'use_gradient_accumulation\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\', \'slot_variable_creation_fn\'], varargs=None, keywords=None, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-07\', \'True\', \'True\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-feature-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-feature-config.pbtxt new file mode 100644 index 00000000000..b2c31d00ad8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-feature-config.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.tpu.experimental.embedding.FeatureConfig" +tf_class { + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'table\', \'max_sequence_length\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-s-g-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-s-g-d.pbtxt new file mode 100644 index 00000000000..9a3f47406b8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-s-g-d.pbtxt @@ -0,0 +1,10 @@ +path: "tensorflow.tpu.experimental.embedding.SGD" +tf_class { + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils.SGD\'>" + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils._Optimizer\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'learning_rate\', \'clip_weight_min\', \'clip_weight_max\', \'weight_decay_factor\', \'multiply_weight_decay_factor_by_learning_rate\'], varargs=None, keywords=None, defaults=[\'0.01\', \'None\', \'None\', \'None\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-t-p-u-embedding.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-t-p-u-embedding.pbtxt new file mode 100644 index 00000000000..9cc8354b4bf --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-t-p-u-embedding.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.tpu.experimental.embedding.TPUEmbedding" +tf_class { + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2.TPUEmbedding\'>" + is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>" + is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>" + is_instance: "<type \'object\'>" + member { + name: "embedding_tables" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'feature_config\', \'batch_size\', \'optimizer\', \'pipeline_execution_with_tensor_core\', \'initialize_tpu_embedding\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'gradients\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "dequeue" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "enqueue" + argspec: "args=[\'self\', \'features\', \'weights\', \'training\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-table-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-table-config.pbtxt new file mode 100644 index 00000000000..6be35ed6fb6 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-table-config.pbtxt @@ -0,0 +1,9 @@ +path: "tensorflow.tpu.experimental.embedding.TableConfig" +tf_class { + is_instance: "<class \'tensorflow.python.tpu.tpu_embedding_v2_utils.TableConfig\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'vocabulary_size\', \'dim\', \'initializer\', \'optimizer\', \'combiner\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'mean\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.pbtxt new file mode 100644 index 00000000000..9d4f24f4edd --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.tpu.experimental.embedding" +tf_module { + member { + name: "Adagrad" + mtype: "<type \'type\'>" + } + member { + name: "Adam" + mtype: "<type \'type\'>" + } + member { + name: "FeatureConfig" + mtype: "<type \'type\'>" + } + member { + name: "SGD" + mtype: "<type \'type\'>" + } + member { + name: "TPUEmbedding" + mtype: "<type \'type\'>" + } + member { + name: "TableConfig" + mtype: "<type \'type\'>" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.pbtxt index df31799828c..5c547f4f49b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.pbtxt @@ -12,6 +12,10 @@ tf_module { name: "Topology" mtype: "<type \'type\'>" } + member { + name: "embedding" + mtype: "<type \'module\'>" + } member_method { name: "initialize_tpu_system" argspec: "args=[\'cluster_resolver\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/ci_build/builds/docker_cpu_pip.sh b/tensorflow/tools/ci_build/builds/docker_cpu_pip.sh index 3bb8d8b7afa..cf0036fb98f 100755 --- a/tensorflow/tools/ci_build/builds/docker_cpu_pip.sh +++ b/tensorflow/tools/ci_build/builds/docker_cpu_pip.sh @@ -40,7 +40,7 @@ yes "" | python configure.py PIP_TEST_ROOT=pip_test_root mkdir -p ${PIP_TEST_ROOT} ln -s $(pwd)/tensorflow ${PIP_TEST_ROOT}/tensorflow -bazel test --define=no_tensorflow_py_deps=true \ +bazel --output_base=/tmp test --define=no_tensorflow_py_deps=true \ --test_lang_filters=py \ --build_tests_only \ -k \ diff --git a/tensorflow/tools/ci_build/builds/libtensorflow.sh b/tensorflow/tools/ci_build/builds/libtensorflow.sh index 44180b8bf84..a281afe7442 100755 --- a/tensorflow/tools/ci_build/builds/libtensorflow.sh +++ b/tensorflow/tools/ci_build/builds/libtensorflow.sh @@ -54,7 +54,7 @@ function build_libtensorflow_tarball() { BAZEL_OPTS="--config=opt --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0" export CC_OPT_FLAGS="-mavx -msse4.2" if [ "${TF_NEED_CUDA}" == "1" ]; then - BAZEL_OPTS="${BAZEL_OPTS} --config=cuda" + BAZEL_OPTS="${BAZEL_OPTS} --config=cuda --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain" export TF_NEED_ROCM=0 fi bazel clean --expunge diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index cc1156f8cc5..6db88755ac8 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -702,23 +702,37 @@ done # Print summary of build results COUNTER=0 echo "==== Summary of sanity check results ====" +TESTCASE_XML='' while [[ ${COUNTER} -lt "${#SANITY_STEPS[@]}" ]]; do INDEX=COUNTER ((INDEX++)) echo "${INDEX}. ${SANITY_STEPS[COUNTER]}: ${SANITY_STEPS_DESC[COUNTER]}" + TESTCASE_XML="${TESTCASE_XML} <testcase name=\"${SANITY_STEPS_DESC[COUNTER]}\" status=\"run\" classname=\"\" time=\"0\">" + if [[ ${STEP_EXIT_CODES[COUNTER]} == "0" ]]; then printf " ${COLOR_GREEN}PASS${COLOR_NC}\n" else printf " ${COLOR_RED}FAIL${COLOR_NC}\n" + TESTCASE_XML="${TESTCASE_XML} <failure message=\"\" type=\"\"/>" fi + TESTCASE_XML="${TESTCASE_XML} </testcase>" + ((COUNTER++)) done echo echo "${FAIL_COUNTER} failed; ${PASS_COUNTER} passed." +mkdir -p "${KOKORO_ARTIFACTS_DIR}/${KOKORO_JOB_NAME}/summary" +echo '<?xml version="1.0" encoding="UTF-8"?>'\ + '<testsuites name="1" tests="1" failures="0" errors="0" time="0">'\ + '<testsuite name="Kokoro Summary" tests="'"$((FAIL_COUNTER + PASS_COUNTER))"\ + '" failures="'"${FAIL_COUNTER}"'" errors="0" time="0">'\ + "${TESTCASE_XML}"'</testsuite></testsuites>'\ + > "${KOKORO_ARTIFACTS_DIR}/${KOKORO_JOB_NAME}/summary/sponge_log.xml" + echo if [[ ${FAIL_COUNTER} == "0" ]]; then printf "Sanity checks ${COLOR_GREEN}PASSED${COLOR_NC}\n" diff --git a/tensorflow/tools/ci_build/install/install_pi_python37_toolchain.sh b/tensorflow/tools/ci_build/install/install_pi_python37_toolchain.sh index 7688a081d6f..3bda56af648 100755 --- a/tensorflow/tools/ci_build/install/install_pi_python37_toolchain.sh +++ b/tensorflow/tools/ci_build/install/install_pi_python37_toolchain.sh @@ -15,12 +15,14 @@ # ============================================================================== dpkg --add-architecture armhf -echo 'deb [arch=armhf] http://ports.ubuntu.com/ xenial main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list -echo 'deb [arch=armhf] http://ports.ubuntu.com/ xenial-updates main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list -echo 'deb [arch=armhf] http://ports.ubuntu.com/ xenial-security main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list -echo 'deb [arch=armhf] http://ports.ubuntu.com/ xenial-backports main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list +dpkg --add-architecture arm64 +echo 'deb [arch=arm64,armhf] http://ports.ubuntu.com/ xenial main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list +echo 'deb [arch=arm64,armhf] http://ports.ubuntu.com/ xenial-updates main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list +echo 'deb [arch=arm64,armhf] http://ports.ubuntu.com/ xenial-security main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list +echo 'deb [arch=arm64,armhf] http://ports.ubuntu.com/ xenial-backports main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list sed -i 's#deb http://archive.ubuntu.com/ubuntu/#deb [arch=amd64] http://archive.ubuntu.com/ubuntu/#g' /etc/apt/sources.list yes | add-apt-repository ppa:deadsnakes/ppa apt-get update apt-get install -y python3.7 python3-numpy python3.7-dev python3-pip apt-get install -y libpython3.7-dev:armhf +apt-get install -y libpython3.7-dev:arm64 diff --git a/tensorflow/tools/ci_build/install/install_pi_python3_toolchain.sh b/tensorflow/tools/ci_build/install/install_pi_python3_toolchain.sh index 7c87a3fc7c5..b02c35c612d 100755 --- a/tensorflow/tools/ci_build/install/install_pi_python3_toolchain.sh +++ b/tensorflow/tools/ci_build/install/install_pi_python3_toolchain.sh @@ -15,11 +15,13 @@ # ============================================================================== dpkg --add-architecture armhf -echo 'deb [arch=armhf] http://ports.ubuntu.com/ xenial main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list -echo 'deb [arch=armhf] http://ports.ubuntu.com/ xenial-updates main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list -echo 'deb [arch=armhf] http://ports.ubuntu.com/ xenial-security main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list -echo 'deb [arch=armhf] http://ports.ubuntu.com/ xenial-backports main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list +dpkg --add-architecture arm64 +echo 'deb [arch=arm64,armhf] http://ports.ubuntu.com/ xenial main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list +echo 'deb [arch=arm64,armhf] http://ports.ubuntu.com/ xenial-updates main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list +echo 'deb [arch=arm64,armhf] http://ports.ubuntu.com/ xenial-security main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list +echo 'deb [arch=arm64,armhf] http://ports.ubuntu.com/ xenial-backports main restricted universe multiverse' >> /etc/apt/sources.list.d/armhf.list sed -i 's#deb http://archive.ubuntu.com/ubuntu/#deb [arch=amd64] http://archive.ubuntu.com/ubuntu/#g' /etc/apt/sources.list apt-get update apt-get install -y libpython3-all-dev:armhf +apt-get install -y libpython3-all-dev:arm64 apt-get install -y python3 python3-numpy python3-dev python3-pip diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh index 467b8dc8083..1b255682671 100755 --- a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh +++ b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh @@ -36,7 +36,7 @@ DOCKER_BINARY="docker" if [ "${TF_NEED_CUDA}" == "1" ]; then DOCKER_IMAGE="tf-tensorflow-gpu" DOCKER_BINARY="nvidia-docker" - DOCKER_FILE="Dockerfile.gpu" + DOCKER_FILE="Dockerfile.rbe.cuda10.1-cudnn7-ubuntu16.04-manylinux2010" fi if [ "${TF_NEED_ROCM}" == "1" ]; then DOCKER_IMAGE="tf-tensorflow-rocm" diff --git a/tensorflow/tools/ci_build/release/common_win.bat b/tensorflow/tools/ci_build/release/common_win.bat index d34c92736c0..464782dcefd 100644 --- a/tensorflow/tools/ci_build/release/common_win.bat +++ b/tensorflow/tools/ci_build/release/common_win.bat @@ -28,7 +28,7 @@ SET PATH=%PATH%;C:\%PYTHON_DIRECTORY% %PIP_EXE% install setuptools --upgrade %PIP_EXE% install future>=0.17.1 --no-deps -%PIP_EXE% install --force-reinstall tf-estimator-nightly --no-deps +%PIP_EXE% install --ignore-installed tf-estimator-nightly --no-deps %PIP_EXE% install tb-nightly --no-deps %PIP_EXE% install numpy --upgrade --no-deps %PIP_EXE% install opt_einsum --upgrade diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py2_full/pip.sh b/tensorflow/tools/ci_build/release/macos/cpu_py2_full/pip.sh index f6de18d81ac..0630c117036 100644 --- a/tensorflow/tools/ci_build/release/macos/cpu_py2_full/pip.sh +++ b/tensorflow/tools/ci_build/release/macos/cpu_py2_full/pip.sh @@ -44,7 +44,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="//tensorflow/python/..." export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial,-no_oss_py2,-v1only,-gpu,-tpu,-benchmark-test' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py2_full/pip_v1.sh b/tensorflow/tools/ci_build/release/macos/cpu_py2_full/pip_v1.sh index c64d9c00787..188e47fa74b 100644 --- a/tensorflow/tools/ci_build/release/macos/cpu_py2_full/pip_v1.sh +++ b/tensorflow/tools/ci_build/release/macos/cpu_py2_full/pip_v1.sh @@ -39,7 +39,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="//tensorflow/python/..." export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial,-no_oss_py2' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip.sh b/tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip.sh index 8c9b91dd55e..3f31033b2ac 100644 --- a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip.sh +++ b/tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip.sh @@ -44,7 +44,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="//tensorflow/python/..." export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial,-no_oss_py35,-gpu,-tpu,-benchmark-test' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip_v1.sh b/tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip_v1.sh index e03f4c4ce2f..dcbd5b504c8 100644 --- a/tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip_v1.sh +++ b/tensorflow/tools/ci_build/release/macos/cpu_py35_full/pip_v1.sh @@ -43,7 +43,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="//tensorflow/python/..." export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial,-no_oss_py35' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip.sh b/tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip.sh index a66dca3885e..26ee4ea8edb 100644 --- a/tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip.sh +++ b/tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip.sh @@ -44,7 +44,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="//tensorflow/python/..." export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial,-no_oss_py35,-v1only,-gpu,-tpu,-benchmark-test' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip_v1.sh b/tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip_v1.sh index dc153b16a43..3d04cf1d9ba 100644 --- a/tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip_v1.sh +++ b/tensorflow/tools/ci_build/release/macos/cpu_py36_full/pip_v1.sh @@ -42,7 +42,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="//tensorflow/python/..." export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip.sh b/tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip.sh index 5d75224a45c..ed577db961a 100644 --- a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip.sh +++ b/tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip.sh @@ -44,7 +44,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="//tensorflow/python/..." export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial,-no_oss_py37,-v1only,-gpu,-tpu,-benchmark-test' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip_v1.sh b/tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip_v1.sh index afe933a1912..c3840aa2dc8 100644 --- a/tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip_v1.sh +++ b/tensorflow/tools/ci_build/release/macos/cpu_py37_full/pip_v1.sh @@ -42,7 +42,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="//tensorflow/python/..." export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/macos/cpu_py38_full/pip.sh b/tensorflow/tools/ci_build/release/macos/cpu_py38_full/pip.sh index a5a5b6a34c4..f8eda5a7520 100644 --- a/tensorflow/tools/ci_build/release/macos/cpu_py38_full/pip.sh +++ b/tensorflow/tools/ci_build/release/macos/cpu_py38_full/pip.sh @@ -44,7 +44,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="//tensorflow/python/..." export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-nomac,-no_mac,-no_oss,-oss_serial,-no_oss_py38,-v1only,-gpu,-tpu,-benchmark-test' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py2_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py2_full/pip.sh index ad14d8724b8..8524bbbad03 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py2_full/pip.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py2_full/pip.sh @@ -46,7 +46,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py2,-v1only' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow_cpu" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py2_full/pip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py2_full/pip_v1.sh index a4d9bb1de03..bd2e27e8781 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py2_full/pip_v1.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py2_full/pip_v1.sh @@ -43,7 +43,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="//tensorflow/python/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-gpu,-tpu,-benchmark-test,-no_oss_py2' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow_cpu" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip.sh index 3842410edb2..5d0cbacb0b7 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip.sh @@ -45,7 +45,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py35,-v1only' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow_cpu" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip_v1.sh index cd8cdd98014..1e2665f4120 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip_v1.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py35_full/pip_v1.sh @@ -42,7 +42,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="//tensorflow/python/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py35' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow_cpu" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip.sh index d23ce016080..25c4de88cdd 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip.sh @@ -45,7 +45,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py36,-v1only' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow_cpu" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip_v1.sh index 084bfeb3a22..c4d78dc3fe5 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip_v1.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py36_full/pip_v1.sh @@ -42,7 +42,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="//tensorflow/python/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py36' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow_cpu" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip.sh index 9cded426bde..940cef32ef8 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip.sh @@ -45,7 +45,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py37,-v1only' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow_cpu" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip_v1.sh index 2df3c0e61e7..2208327388f 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip_v1.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py37_full/pip_v1.sh @@ -42,7 +42,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="//tensorflow/python/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py37' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow_cpu" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/pip.sh index 366f2464612..a27d1f863d6 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/pip.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/cpu_py38_full/pip.sh @@ -45,7 +45,7 @@ export TF_TEST_FLAGS="--define=no_tensorflow_py_deps=true --test_lang_filters=py export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" export TF_TEST_FILTER_TAGS='-no_oss,-oss_serial,-no_oss_py38,-v1only' -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME="tensorflow_cpu" export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py2_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py2_full/pip.sh index 12290d1b0b5..dd618031c0d 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py2_full/pip.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py2_full/pip.sh @@ -58,7 +58,7 @@ export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filt --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/..." export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME=${PROJECT_NAME} export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py2_full/pip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py2_full/pip_v1.sh index d5e5c76ce82..db0c3a22c06 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py2_full/pip_v1.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py2_full/pip_v1.sh @@ -56,7 +56,7 @@ export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filt --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " export TF_TEST_TARGETS="//tensorflow/python/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME=${PROJECT_NAME} export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip.sh index be97cc4bfa8..0e8cd8cd784 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip.sh @@ -59,7 +59,7 @@ export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filt --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME=${PROJECT_NAME} export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip_v1.sh index a3104e88395..4bbbd50724b 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip_v1.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py35_full/pip_v1.sh @@ -56,7 +56,7 @@ export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filt --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " export TF_TEST_TARGETS="//tensorflow/python/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME=${PROJECT_NAME} export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip.sh index 15f7db11a87..0b26173ca5f 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip.sh @@ -59,7 +59,7 @@ export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filt --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME=${PROJECT_NAME} export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip_v1.sh index c1fc598eed6..484daa63cb8 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip_v1.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py36_full/pip_v1.sh @@ -56,7 +56,7 @@ export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filt --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " export TF_TEST_TARGETS="//tensorflow/python/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME=${PROJECT_NAME} export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip.sh index 56f2a7f66e9..00047b775b1 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip.sh @@ -59,7 +59,7 @@ export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filt --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME=${PROJECT_NAME} export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip_v1.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip_v1.sh index e5d3fda2b73..50cf3d61e4a 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip_v1.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py37_full/pip_v1.sh @@ -56,7 +56,7 @@ export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filt --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " export TF_TEST_TARGETS="//tensorflow/python/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME=${PROJECT_NAME} export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/pip.sh b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/pip.sh index 28b633c390e..9aa5fdf68c8 100644 --- a/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/pip.sh +++ b/tensorflow/tools/ci_build/release/ubuntu_16/gpu_py38_full/pip.sh @@ -59,7 +59,7 @@ export TF_TEST_FLAGS="--test_tag_filters=${TF_TEST_FILTER_TAGS} --build_tag_filt --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute " export TF_TEST_TARGETS="${DEFAULT_BAZEL_TARGETS} -//tensorflow/lite/... " export TF_PIP_TESTS="test_pip_virtualenv_non_clean test_pip_virtualenv_clean" -export IS_NIGHTLY=0 # Not nightly +#export IS_NIGHTLY=0 # Not nightly; uncomment if building from tf repo. export TF_PROJECT_NAME=${PROJECT_NAME} export TF_PIP_TEST_ROOT="pip_test" diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 4b8289a6202..01a3696823d 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -43,8 +43,6 @@ from setuptools import setup from setuptools.command.install import install as InstallCommandBase from setuptools.dist import Distribution -DOCLINES = __doc__.split('\n') - # This version string is semver compatible, but incompatible with pip. # For pip, we will remove all '-' characters from this string, and use the # result for pip. @@ -93,6 +91,16 @@ if 'tf_nightly' in project_name: elif 'tensorflow_estimator' in pkg: REQUIRED_PACKAGES[i] = 'tf-estimator-nightly' +DOCLINES = __doc__.split('\n') +if project_name.endswith('-gpu'): + project_name_no_gpu = project_name[:-len('-gpu')] + _GPU_PACKAGE_NOTE = 'Note that %s package by default supports both CPU and '\ + 'GPU. %s has the same content and exists solely for backward '\ + 'compatiblity. Please migrate to %s for GPU support.'\ + % (project_name_no_gpu, project_name, project_name_no_gpu) + DOCLINES.append(_GPU_PACKAGE_NOTE) + + # pylint: disable=line-too-long CONSOLE_SCRIPTS = [ 'toco_from_protos = tensorflow.lite.toco.python.toco_from_protos:main', @@ -232,6 +240,7 @@ headers = ( list(find_files('*.proto', 'tensorflow/compiler')) + list(find_files('*.proto', 'tensorflow/core')) + list(find_files('*.proto', 'tensorflow/python')) + + list(find_files('*.def', 'tensorflow/compiler')) + list(find_files('*.h', 'tensorflow/c')) + list(find_files('*.h', 'tensorflow/cc')) + list(find_files('*.h', 'tensorflow/compiler')) + diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index c3d097a8362..217edee0f86 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -164,11 +164,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "XNNPACK", - sha256 = "0440d9ad632945f10992664be84eb0c0c76581f8474df3c124aa30350981126c", - strip_prefix = "XNNPACK-d9a7e85c30a2bea7b6b263f21f066a93cb2b4dee", + sha256 = "05904bb15b7a5abadc261c16e6be3ac2314d6d4384aa16349b7354d9fa8bbb4f", + strip_prefix = "XNNPACK-1e5f80293b3c0197aaf44f3adb9329401fd36ed4", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/d9a7e85c30a2bea7b6b263f21f066a93cb2b4dee.zip", - "https://github.com/google/XNNPACK/archive/d9a7e85c30a2bea7b6b263f21f066a93cb2b4dee.zip", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/XNNPACK/archive/1e5f80293b3c0197aaf44f3adb9329401fd36ed4.zip", + "https://github.com/google/XNNPACK/archive/1e5f80293b3c0197aaf44f3adb9329401fd36ed4.zip", ], ) @@ -184,11 +184,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "pthreadpool", - sha256 = "c4b148fba41fc937fdf96bc195caadf0cf0be83f1c3e335ef5355934d4501f83", - strip_prefix = "pthreadpool-e918b206d26b1f3b2100b0edabf445c18708d2b7", + sha256 = "9f5fb7f87dc778d9c1d638826344b762afa23884d0252526337ae710264faef3", + strip_prefix = "pthreadpool-18a7156cb9be8e534acefade42e46d4209600c35", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/Maratyszcza/pthreadpool/archive/e918b206d26b1f3b2100b0edabf445c18708d2b7.zip", - "https://github.com/Maratyszcza/pthreadpool/archive/e918b206d26b1f3b2100b0edabf445c18708d2b7.zip", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/Maratyszcza/pthreadpool/archive/18a7156cb9be8e534acefade42e46d4209600c35.zip", + "https://github.com/Maratyszcza/pthreadpool/archive/18a7156cb9be8e534acefade42e46d4209600c35.zip", ], ) @@ -200,11 +200,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "mkl_dnn", build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"), - sha256 = "31e78581e59d7e60d4becaba3834fc6a5bf2dccdae3e16b7f70d89ceab38423f", - strip_prefix = "mkl-dnn-0.21.3", + sha256 = "a0211aeb5e7dad50b97fa5dffc1a2fe2fe732572d4164e1ee8750a2ede43fbec", + strip_prefix = "oneDNN-0.21.3", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/archive/v0.21.3.tar.gz", - "https://github.com/intel/mkl-dnn/archive/v0.21.3.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/oneapi-src/oneDNN/archive/v0.21.3.tar.gz", + "https://github.com/oneapi-src/oneDNN/archive/v0.21.3.tar.gz", ], ) @@ -237,11 +237,11 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): name = "eigen_archive", build_file = clean_dep("//third_party:eigen.BUILD"), patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"), - sha256 = "2c7c0aec4271dfca6b8a7707e2112f67c4cb3bdf7c89c0e98d3fcd39707c4468", # SHARED_EIGEN_SHA - strip_prefix = "eigen-49f1aeb60d9f759859fce0d16aa5d1ecc7168d51", + sha256 = "854eabe6817e38d7738fde6ec39c3dfc55fd5e68b2523de8cae936f391a38a69", # SHARED_EIGEN_SHA + strip_prefix = "eigen-cc86a31e20b48b0f03d714b4d1b1f50d52848d36", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/49f1aeb60d9f759859fce0d16aa5d1ecc7168d51/eigen-49f1aeb60d9f759859fce0d16aa5d1ecc7168d51.tar.gz", - "https://gitlab.com/libeigen/eigen/-/archive/49f1aeb60d9f759859fce0d16aa5d1ecc7168d51/eigen-49f1aeb60d9f759859fce0d16aa5d1ecc7168d51.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/gitlab.com/libeigen/eigen/-/archive/cc86a31e20b48b0f03d714b4d1b1f50d52848d36/eigen-cc86a31e20b48b0f03d714b4d1b1f50d52848d36.tar.gz", + "https://gitlab.com/libeigen/eigen/-/archive/cc86a31e20b48b0f03d714b4d1b1f50d52848d36/eigen-cc86a31e20b48b0f03d714b4d1b1f50d52848d36.tar.gz", ], ) @@ -655,8 +655,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. - LLVM_COMMIT = "bfa200ebcf3706fde0dde335a3c1fa3fe1b3ba3f" - LLVM_SHA256 = "72deefcfe20434cb27a31ff9503c348dcf21065dbd27e9fa54c1fb3f5089b8e1" + LLVM_COMMIT = "1108f5c737dbdab0277874a7e5b237491839c43a" + LLVM_SHA256 = "bbdaaa145a5a8eed8e6a0f06a3b9965f32b03286eddea5f50c5af2d1f3d008df" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), diff --git a/third_party/FP16/workspace.bzl b/third_party/FP16/workspace.bzl index 441ef6b15e1..31746d6c371 100644 --- a/third_party/FP16/workspace.bzl +++ b/third_party/FP16/workspace.bzl @@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive") def repo(): third_party_http_archive( name = "FP16", - strip_prefix = "FP16-3c54eacb74f6f5e39077300c5564156c424d77ba", - sha256 = "0d56bb92f649ec294dbccb13e04865e3c82933b6f6735d1d7145de45da700156", + strip_prefix = "FP16-4dfe081cf6bcd15db339cf2680b9281b8451eeb3", + sha256 = "d973501a40c55126b31accc2d9f08d931ec3cc190c0430309a5e341d3c0ce32a", urls = [ - "https://mirror.bazel.build/github.com/Maratyszcza/FP16/archive/3c54eacb74f6f5e39077300c5564156c424d77ba.zip", - "https://github.com/Maratyszcza/FP16/archive/3c54eacb74f6f5e39077300c5564156c424d77ba.zip", + "https://mirror.bazel.build/github.com/Maratyszcza/FP16/archive/4dfe081cf6bcd15db339cf2680b9281b8451eeb3.zip", + "https://github.com/Maratyszcza/FP16/archive/4dfe081cf6bcd15db339cf2680b9281b8451eeb3.zip", ], build_file = "//third_party/FP16:BUILD.bazel", ) diff --git a/third_party/aws/aws-c-common.bazel b/third_party/aws/aws-c-common.bazel index a66fbcb1164..ab9406805c2 100644 --- a/third_party/aws/aws-c-common.bazel +++ b/third_party/aws/aws-c-common.bazel @@ -14,7 +14,6 @@ cc_library( srcs = select({ "@org_tensorflow//tensorflow:linux_aarch64": glob([ "source/posix/*.c", - "source/arch/*.c" ]), "@org_tensorflow//tensorflow:linux_x86_64": glob([ "source/posix/*.c", diff --git a/third_party/cpuinfo/workspace.bzl b/third_party/cpuinfo/workspace.bzl index 922ab022486..e7aff433892 100644 --- a/third_party/cpuinfo/workspace.bzl +++ b/third_party/cpuinfo/workspace.bzl @@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive") def repo(): third_party_http_archive( name = "cpuinfo", - strip_prefix = "cpuinfo-0cc563acb9baac39f2c1349bc42098c4a1da59e3", - sha256 = "80625d0b69a3d69b70c2236f30db2c542d0922ccf9bb51a61bc39c49fac91a35", + strip_prefix = "cpuinfo-19b9316c71e4e45b170a664bf62ddefd7ac9feb5", + sha256 = "e0a485c072de957668eb324c49d726dc0fd736cfb9436b334325f20d93085003", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pytorch/cpuinfo/archive/0cc563acb9baac39f2c1349bc42098c4a1da59e3.tar.gz", - "https://github.com/pytorch/cpuinfo/archive/0cc563acb9baac39f2c1349bc42098c4a1da59e3.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pytorch/cpuinfo/archive/19b9316c71e4e45b170a664bf62ddefd7ac9feb5.zip", + "https://github.com/pytorch/cpuinfo/archive/19b9316c71e4e45b170a664bf62ddefd7ac9feb5.zip", ], build_file = "//third_party/cpuinfo:BUILD.bazel", ) diff --git a/third_party/flatbuffers/build_defs.bzl b/third_party/flatbuffers/build_defs.bzl index d07ad18630f..9be627119cf 100644 --- a/third_party/flatbuffers/build_defs.bzl +++ b/third_party/flatbuffers/build_defs.bzl @@ -472,6 +472,7 @@ def flatbuffer_java_library( native.java_library( name = name, srcs = [out_srcjar], + javacopts = ["-source 7 -target 7"], deps = [ "@flatbuffers//:runtime_java", ], @@ -562,7 +563,6 @@ def flatbuffer_android_library( srcs, custom_package = "", package_prefix = "", - javacopts = None, include_paths = DEFAULT_INCLUDE_PATHS, flatc_args = DEFAULT_FLATC_ARGS, visibility = None): @@ -575,7 +575,6 @@ def flatbuffer_android_library( namespace in the schema files will be used. (optional) package_prefix: like custom_package, but prefixes to the existing namespace. (optional) - javacopts: List of options to pass to javac. include_paths: List of paths that includes files can be found in. (optional) flatc_args: List of additional arguments to pass to flatc. (optional) visibility: Visibility setting for the android_library rule. (optional) @@ -604,6 +603,7 @@ def flatbuffer_android_library( android_library( name = name, srcs = [out_srcjar], + javacopts = ["-source 7 -target 7"], visibility = visibility, deps = [ "@flatbuffers//:runtime_android", diff --git a/third_party/gpus/check_cuda_libs.py b/third_party/gpus/check_cuda_libs.py index b7b36e6466e..479380da975 100644 --- a/third_party/gpus/check_cuda_libs.py +++ b/third_party/gpus/check_cuda_libs.py @@ -59,7 +59,7 @@ def check_cuda_lib(path, check_soname=True): objdump = which("objdump") if check_soname and objdump is not None and not _is_windows(): # Decode is necessary as in py3 the return type changed from str to bytes - output = subprocess.check_output([objdump, "-p", path]).decode("ascii") + output = subprocess.check_output([objdump, "-p", path]).decode("utf-8") output = [line for line in output.splitlines() if "SONAME" in line] sonames = [line.strip().split(" ")[-1] for line in output] if not any([soname == os.path.basename(path) for soname in sonames]): @@ -86,4 +86,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl b/third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl index e50592fd857..4acc05ff88c 100644 --- a/third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl +++ b/third_party/gpus/crosstool/cc_toolchain_config.bzl.tpl @@ -12,1426 +12,237 @@ load( "tool", "tool_path", "variable_with_value", + "with_feature_set", ) -load( - "@bazel_tools//tools/build_defs/cc:action_names.bzl", - "ASSEMBLE_ACTION_NAME", - "CC_FLAGS_MAKE_VARIABLE_ACTION_NAME", - "CLIF_MATCH_ACTION_NAME", - "CPP_COMPILE_ACTION_NAME", - "CPP_HEADER_PARSING_ACTION_NAME", - "CPP_LINK_DYNAMIC_LIBRARY_ACTION_NAME", - "CPP_LINK_EXECUTABLE_ACTION_NAME", - "CPP_LINK_NODEPS_DYNAMIC_LIBRARY_ACTION_NAME", - "CPP_LINK_STATIC_LIBRARY_ACTION_NAME", - "CPP_MODULE_CODEGEN_ACTION_NAME", - "CPP_MODULE_COMPILE_ACTION_NAME", - "C_COMPILE_ACTION_NAME", - "LINKSTAMP_COMPILE_ACTION_NAME", - "LTO_BACKEND_ACTION_NAME", - "LTO_INDEXING_ACTION_NAME", - "OBJCPP_COMPILE_ACTION_NAME", - "OBJCPP_EXECUTABLE_ACTION_NAME", - "OBJC_ARCHIVE_ACTION_NAME", - "OBJC_COMPILE_ACTION_NAME", - "OBJC_EXECUTABLE_ACTION_NAME", - "OBJC_FULLY_LINK_ACTION_NAME", - "PREPROCESS_ASSEMBLE_ACTION_NAME", - "STRIP_ACTION_NAME", -) +load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "ACTION_NAMES") -ACTION_NAMES = struct( - c_compile = C_COMPILE_ACTION_NAME, - cpp_compile = CPP_COMPILE_ACTION_NAME, - linkstamp_compile = LINKSTAMP_COMPILE_ACTION_NAME, - cc_flags_make_variable = CC_FLAGS_MAKE_VARIABLE_ACTION_NAME, - cpp_module_codegen = CPP_MODULE_CODEGEN_ACTION_NAME, - cpp_header_parsing = CPP_HEADER_PARSING_ACTION_NAME, - cpp_module_compile = CPP_MODULE_COMPILE_ACTION_NAME, - assemble = ASSEMBLE_ACTION_NAME, - preprocess_assemble = PREPROCESS_ASSEMBLE_ACTION_NAME, - lto_indexing = LTO_INDEXING_ACTION_NAME, - lto_backend = LTO_BACKEND_ACTION_NAME, - cpp_link_executable = CPP_LINK_EXECUTABLE_ACTION_NAME, - cpp_link_dynamic_library = CPP_LINK_DYNAMIC_LIBRARY_ACTION_NAME, - cpp_link_nodeps_dynamic_library = CPP_LINK_NODEPS_DYNAMIC_LIBRARY_ACTION_NAME, - cpp_link_static_library = CPP_LINK_STATIC_LIBRARY_ACTION_NAME, - strip = STRIP_ACTION_NAME, - objc_archive = OBJC_ARCHIVE_ACTION_NAME, - objc_compile = OBJC_COMPILE_ACTION_NAME, - objc_executable = OBJC_EXECUTABLE_ACTION_NAME, - objc_fully_link = OBJC_FULLY_LINK_ACTION_NAME, - objcpp_compile = OBJCPP_COMPILE_ACTION_NAME, - objcpp_executable = OBJCPP_EXECUTABLE_ACTION_NAME, - clif_match = CLIF_MATCH_ACTION_NAME, - objcopy_embed_data = "objcopy_embed_data", - ld_embed_data = "ld_embed_data", -) +def all_assembly_actions(): + return [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ] -def _impl(ctx): - if (ctx.attr.cpu == "darwin"): - toolchain_identifier = "local_darwin" - elif (ctx.attr.cpu == "local"): - toolchain_identifier = "local_linux" - elif (ctx.attr.cpu == "x64_windows"): - toolchain_identifier = "local_windows" - else: - fail("Unreachable") +def all_compile_actions(): + return [ + ACTION_NAMES.assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.preprocess_assemble, + ] - host_system_name = "local" +def all_c_compile_actions(): + return [ + ACTION_NAMES.c_compile, + ] - target_system_name = "local" +def all_cpp_compile_actions(): + return [ + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.linkstamp_compile, + ] - if (ctx.attr.cpu == "darwin"): - target_cpu = "darwin" - elif (ctx.attr.cpu == "local"): - target_cpu = "local" - elif (ctx.attr.cpu == "x64_windows"): - target_cpu = "x64_windows" - else: - fail("Unreachable") +def all_preprocessed_actions(): + return [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.preprocess_assemble, + ] - if (ctx.attr.cpu == "local"): - target_libc = "local" - elif (ctx.attr.cpu == "darwin"): - target_libc = "macosx" - elif (ctx.attr.cpu == "x64_windows"): - target_libc = "msvcrt" - else: - fail("Unreachable") - - if (ctx.attr.cpu == "darwin" or - ctx.attr.cpu == "local"): - compiler = "compiler" - elif (ctx.attr.cpu == "x64_windows"): - compiler = "msvc-cl" - else: - fail("Unreachable") - - abi_version = "local" - - abi_libc_version = "local" - - cc_target_os = None - - builtin_sysroot = ctx.attr.builtin_sysroot - - all_link_actions = [ +def all_link_actions(): + return [ ACTION_NAMES.cpp_link_executable, ACTION_NAMES.cpp_link_dynamic_library, ACTION_NAMES.cpp_link_nodeps_dynamic_library, ] - cpp_link_dynamic_library_action = action_config( - action_name = ACTION_NAMES.cpp_link_dynamic_library, - implies = [ - "nologo", - "shared_flag", - "linkstamps", - "output_execpath_flags", - "input_param_flags", - "user_link_flags", - "linker_subsystem_flag", - "linker_param_file", - "msvc_env", - "no_stripping", - "has_configured_linker_path", - "def_file", +def all_executable_link_actions(): + return [ + ACTION_NAMES.cpp_link_executable, + ] + +def all_shared_library_link_actions(): + return [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ] + +def all_archive_actions(): + return [ACTION_NAMES.cpp_link_static_library] + +def all_strip_actions(): + return [ACTION_NAMES.strip] + +def _library_to_link(flag_prefix, value, iterate = None): + return flag_group( + flags = [ + "{}%{{libraries_to_link.{}}}".format( + flag_prefix, + iterate if iterate else "name", + ), ], - tools = [tool(path = ctx.attr.msvc_link_path)], + iterate_over = ("libraries_to_link." + iterate if iterate else None), + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = value, + ), ) - cpp_link_nodeps_dynamic_library_action = action_config( - action_name = ACTION_NAMES.cpp_link_nodeps_dynamic_library, - implies = [ - "nologo", - "shared_flag", - "linkstamps", - "output_execpath_flags", - "input_param_flags", - "user_link_flags", - "linker_subsystem_flag", - "linker_param_file", - "msvc_env", - "no_stripping", - "has_configured_linker_path", - "def_file", - ], - tools = [tool(path = ctx.attr.msvc_link_path)], +def _surround_static_library(prefix, suffix): + return [ + flag_group( + flags = [prefix, "%{libraries_to_link.name}", suffix], + expand_if_true = "libraries_to_link.is_whole_archive", + ), + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_false = "libraries_to_link.is_whole_archive", + ), + ] + +def _prefix_static_library(prefix): + return [ + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_false = "libraries_to_link.is_whole_archive", + ), + flag_group( + flags = [prefix + "%{libraries_to_link.name}"], + expand_if_true = "libraries_to_link.is_whole_archive", + ), + ] + +def _static_library_to_link(alwayslink_prefix, alwayslink_suffix = None): + if alwayslink_suffix: + flag_groups = _surround_static_library(alwayslink_prefix, alwayslink_suffix) + else: + flag_groups = _prefix_static_library(alwayslink_prefix) + return flag_group( + flag_groups = flag_groups, + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "static_library", + ), ) - cpp_link_static_library_action = action_config( - action_name = ACTION_NAMES.cpp_link_static_library, - implies = [ - "nologo", - "archiver_flags", - "input_param_flags", - "linker_param_file", - "msvc_env", - ], - tools = [tool(path = ctx.attr.msvc_lib_path)], +def _iterate_flag_group(iterate_over, flags = [], flag_groups = []): + return flag_group( + iterate_over = iterate_over, + expand_if_available = iterate_over, + flag_groups = flag_groups, + flags = flags, ) - assemble_action = action_config( - action_name = ACTION_NAMES.assemble, - implies = [ - "compiler_input_flags", - "compiler_output_flags", - "nologo", - "msvc_env", - "sysroot", - ], - tools = [tool(path = ctx.attr.msvc_ml_path)], +def _libraries_to_link_group(flavour): + if flavour == "linux": + return _iterate_flag_group( + iterate_over = "libraries_to_link", + flag_groups = [ + flag_group( + flags = ["-Wl,--start-lib"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + _library_to_link("", "object_file_group", "object_files"), + flag_group( + flags = ["-Wl,--end-lib"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + _library_to_link("", "object_file"), + _library_to_link("", "interface_library"), + _static_library_to_link("-Wl,-whole-archive", "-Wl,-no-whole-archive"), + _library_to_link("-l", "dynamic_library"), + _library_to_link("-l:", "versioned_dynamic_library"), + ], + ) + elif flavour == "darwin": + return _iterate_flag_group( + iterate_over = "libraries_to_link", + flag_groups = [ + _library_to_link("", "object_file_group", "object_files"), + _library_to_link("", "object_file"), + _library_to_link("", "interface_library"), + _static_library_to_link("-Wl,-force_load,"), + _library_to_link("-l", "dynamic_library"), + _library_to_link("-l:", "versioned_dynamic_library"), + ], + ) + elif flavour == "msvc": + return _iterate_flag_group( + iterate_over = "libraries_to_link", + flag_groups = [ + _library_to_link("", "object_file_group", "object_files"), + _library_to_link("", "object_file"), + _library_to_link("", "interface_library"), + _static_library_to_link("/WHOLEARCHIVE:"), + ], + ) + +def _action_configs_with_tool(path, actions): + return [ + action_config( + action_name = name, + enabled = True, + tools = [tool(path = path)], + ) + for name in actions + ] + +def _action_configs(assembly_path, c_compiler_path, cc_compiler_path, archiver_path, linker_path, strip_path): + return _action_configs_with_tool( + assembly_path, + all_assembly_actions(), + ) + _action_configs_with_tool( + c_compiler_path, + all_c_compile_actions(), + ) + _action_configs_with_tool( + cc_compiler_path, + all_cpp_compile_actions(), + ) + _action_configs_with_tool( + archiver_path, + all_archive_actions(), + ) + _action_configs_with_tool( + linker_path, + all_link_actions(), + ) + _action_configs_with_tool( + strip_path, + all_strip_actions(), ) - preprocess_assemble_action = action_config( - action_name = ACTION_NAMES.preprocess_assemble, - implies = [ - "compiler_input_flags", - "compiler_output_flags", - "nologo", - "msvc_env", - "sysroot", - ], - tools = [tool(path = ctx.attr.msvc_ml_path)], - ) - - c_compile_action = action_config( - action_name = ACTION_NAMES.c_compile, - implies = [ - "compiler_input_flags", - "compiler_output_flags", - "nologo", - "msvc_env", - "parse_showincludes", - "user_compile_flags", - "sysroot", - "unfiltered_compile_flags", - ], - tools = [tool(path = ctx.attr.msvc_cl_path)], - ) - - cpp_compile_action = action_config( - action_name = ACTION_NAMES.cpp_compile, - implies = [ - "compiler_input_flags", - "compiler_output_flags", - "nologo", - "msvc_env", - "parse_showincludes", - "user_compile_flags", - "sysroot", - "unfiltered_compile_flags", - ], - tools = [tool(path = ctx.attr.msvc_cl_path)], - ) - - cpp_link_executable_action = action_config( - action_name = ACTION_NAMES.cpp_link_executable, - implies = [ - "nologo", - "linkstamps", - "output_execpath_flags", - "input_param_flags", - "user_link_flags", - "linker_subsystem_flag", - "linker_param_file", - "msvc_env", - "no_stripping", - ], - tools = [tool(path = ctx.attr.msvc_link_path)], - ) - - if (ctx.attr.cpu == "darwin" or - ctx.attr.cpu == "local"): - action_configs = [] - elif (ctx.attr.cpu == "x64_windows"): - action_configs = [ - assemble_action, - preprocess_assemble_action, - c_compile_action, - cpp_compile_action, - cpp_link_executable_action, - cpp_link_dynamic_library_action, - cpp_link_nodeps_dynamic_library_action, - cpp_link_static_library_action, +def _tool_paths(cpu, ctx): + if cpu in ["local", "darwin"]: + return [ + tool_path(name = "gcc", path = ctx.attr.host_compiler_path), + tool_path(name = "ar", path = ctx.attr.host_compiler_prefix + ( + "/ar" if cpu == "local" else "/libtool" + )), + tool_path(name = "compat-ld", path = ctx.attr.host_compiler_prefix + "/ld"), + tool_path(name = "cpp", path = ctx.attr.host_compiler_prefix + "/cpp"), + tool_path(name = "dwp", path = ctx.attr.host_compiler_prefix + "/dwp"), + tool_path(name = "gcov", path = ctx.attr.host_compiler_prefix + "/gcov"), + tool_path(name = "ld", path = ctx.attr.host_compiler_prefix + "/ld"), + tool_path(name = "nm", path = ctx.attr.host_compiler_prefix + "/nm"), + tool_path(name = "objcopy", path = ctx.attr.host_compiler_prefix + "/objcopy"), + tool_path(name = "objdump", path = ctx.attr.host_compiler_prefix + "/objdump"), + tool_path(name = "strip", path = ctx.attr.host_compiler_prefix + "/strip"), ] - else: - fail("Unreachable") - - no_windows_export_all_symbols_feature = feature(name = "no_windows_export_all_symbols") - - pic_feature = feature( - name = "pic", - enabled = True, - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [ - flag_group(flags = ["-fPIC"], expand_if_available = "pic"), - flag_group( - flags = ["-fPIE"], - expand_if_not_available = "pic", - ), - ], - ), - ], - ) - - preprocessor_defines_feature = feature( - name = "preprocessor_defines", - enabled = True, - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.assemble, - ACTION_NAMES.preprocess_assemble, - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_header_parsing, - ACTION_NAMES.cpp_module_compile, - ], - flag_groups = [ - flag_group( - flags = ["/D%{preprocessor_defines}"], - iterate_over = "preprocessor_defines", - ), - ], - ), - ], - ) - - generate_pdb_file_feature = feature( - name = "generate_pdb_file", - requires = [ - feature_set(features = ["dbg"]), - feature_set(features = ["fastbuild"]), - ], - ) - - linkstamps_feature = feature( - name = "linkstamps", - flag_sets = [ - flag_set( - actions = all_link_actions, - flag_groups = [ - flag_group( - flags = ["%{linkstamp_paths}"], - iterate_over = "linkstamp_paths", - expand_if_available = "linkstamp_paths", - ), - ], - ), - ], - ) - - unfiltered_compile_flags_feature = feature( - name = "unfiltered_compile_flags", - flag_sets = ([ - flag_set( - actions = [ - ACTION_NAMES.preprocess_assemble, - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_header_parsing, - ACTION_NAMES.cpp_module_compile, - ACTION_NAMES.cpp_module_codegen, - ], - flag_groups = [ - flag_group( - flags = ctx.attr.host_unfiltered_compile_flags, - ), - ], - ), - ] if ctx.attr.host_unfiltered_compile_flags else []), - ) - - determinism_feature = feature( - name = "determinism", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [ - flag_group( - flags = [ - "-Wno-builtin-macro-redefined", - "-D__DATE__=\"redacted\"", - "-D__TIMESTAMP__=\"redacted\"", - "-D__TIME__=\"redacted\"", - ], - ), - ], - ), - ], - ) - - nologo_feature = feature( - name = "nologo", - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_module_compile, - ACTION_NAMES.cpp_module_codegen, - ACTION_NAMES.cpp_header_parsing, - ACTION_NAMES.assemble, - ACTION_NAMES.preprocess_assemble, - ACTION_NAMES.cpp_link_executable, - ACTION_NAMES.cpp_link_dynamic_library, - ACTION_NAMES.cpp_link_nodeps_dynamic_library, - ACTION_NAMES.cpp_link_static_library, - ], - flag_groups = [flag_group(flags = ["/nologo"])], - ), - ], - ) - - supports_pic_feature = feature(name = "supports_pic", enabled = True) - - output_execpath_flags_feature = feature( - name = "output_execpath_flags", - flag_sets = [ - flag_set( - actions = all_link_actions, - flag_groups = [ - flag_group( - flags = ["/OUT:%{output_execpath}"], - expand_if_available = "output_execpath", - ), - ], - ), - ], - ) - - default_link_flags_feature = feature( - name = "default_link_flags", - enabled = True, - flag_sets = [ - flag_set( - actions = all_link_actions, - flag_groups = [flag_group(flags = ["/MACHINE:X64"])], - ), - ], - ) - - if (ctx.attr.cpu == "local"): - hardening_feature = feature( - name = "hardening", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [ - flag_group( - flags = [ - "-U_FORTIFY_SOURCE", - "-D_FORTIFY_SOURCE=1", - "-fstack-protector", - ], - ), - ], - ), - flag_set( - actions = [ - ACTION_NAMES.cpp_link_dynamic_library, - ACTION_NAMES.cpp_link_nodeps_dynamic_library, - ], - flag_groups = [flag_group(flags = ["-Wl,-z,relro,-z,now"])], - ), - flag_set( - actions = [ACTION_NAMES.cpp_link_executable], - flag_groups = [flag_group(flags = ["-pie", "-Wl,-z,relro,-z,now"])], - ), - ], - ) - elif (ctx.attr.cpu == "darwin"): - hardening_feature = feature( - name = "hardening", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [ - flag_group( - flags = [ - "-U_FORTIFY_SOURCE", - "-D_FORTIFY_SOURCE=1", - "-fstack-protector", - ], - ), - ], - ), - flag_set( - actions = [ACTION_NAMES.cpp_link_executable], - flag_groups = [flag_group(flags = ["-pie"])], - ), - ], - ) - else: - hardening_feature = None - - supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", enabled = True) - - targets_windows_feature = feature( - name = "targets_windows", - enabled = True, - implies = ["copy_dynamic_libraries_to_binary"], - ) - - msvc_env_feature = feature( - name = "msvc_env", - env_sets = [ - env_set( - actions = [ - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_module_compile, - ACTION_NAMES.cpp_module_codegen, - ACTION_NAMES.cpp_header_parsing, - ACTION_NAMES.assemble, - ACTION_NAMES.preprocess_assemble, - ACTION_NAMES.cpp_link_executable, - ACTION_NAMES.cpp_link_dynamic_library, - ACTION_NAMES.cpp_link_nodeps_dynamic_library, - ACTION_NAMES.cpp_link_static_library, - ], - env_entries = [ - env_entry(key = "PATH", value = ctx.attr.msvc_env_path), - env_entry( - key = "INCLUDE", - value = ctx.attr.msvc_env_include, - ), - env_entry(key = "LIB", value = ctx.attr.msvc_env_lib), - env_entry(key = "TMP", value = ctx.attr.msvc_env_tmp), - env_entry(key = "TEMP", value = ctx.attr.msvc_env_tmp), - ], - ), - ], - ) - - linker_subsystem_flag_feature = feature( - name = "linker_subsystem_flag", - flag_sets = [ - flag_set( - actions = all_link_actions, - flag_groups = [flag_group(flags = ["/SUBSYSTEM:CONSOLE"])], - ), - ], - ) - - dynamic_link_msvcrt_no_debug_feature = feature( - name = "dynamic_link_msvcrt_no_debug", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [flag_group(flags = ["/MD"])], - ), - flag_set( - actions = all_link_actions, - flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrt.lib"])], - ), - ], - requires = [ - feature_set(features = ["fastbuild"]), - feature_set(features = ["opt"]), - ], - ) - - warnings_feature = feature( - name = "warnings", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [ - flag_group( - flags = ["-Wall"] + ctx.attr.host_compiler_warnings, - ), - ], - ), - ], - ) - - dynamic_link_msvcrt_debug_feature = feature( - name = "dynamic_link_msvcrt_debug", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [flag_group(flags = ["/MDd"])], - ), - flag_set( - actions = all_link_actions, - flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrtd.lib"])], - ), - ], - requires = [feature_set(features = ["dbg"])], - ) - - compiler_output_flags_feature = feature( - name = "compiler_output_flags", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.assemble], - flag_groups = [ - flag_group( - flag_groups = [ - flag_group( - flags = ["/Fo%{output_file}", "/Zi"], - expand_if_not_available = "output_preprocess_file", - ), - ], - expand_if_available = "output_file", - expand_if_not_available = "output_assembly_file", - ), - ], - ), - flag_set( - actions = [ - ACTION_NAMES.preprocess_assemble, - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_header_parsing, - ACTION_NAMES.cpp_module_compile, - ACTION_NAMES.cpp_module_codegen, - ], - flag_groups = [ - flag_group( - flag_groups = [ - flag_group( - flags = ["/Fo%{output_file}"], - expand_if_not_available = "output_preprocess_file", - ), - ], - expand_if_available = "output_file", - expand_if_not_available = "output_assembly_file", - ), - flag_group( - flag_groups = [ - flag_group( - flags = ["/Fa%{output_file}"], - expand_if_available = "output_assembly_file", - ), - ], - expand_if_available = "output_file", - ), - flag_group( - flag_groups = [ - flag_group( - flags = ["/P", "/Fi%{output_file}"], - expand_if_available = "output_preprocess_file", - ), - ], - expand_if_available = "output_file", - ), - ], - ), - ], - ) - - if ctx.attr.compiler == "clang": - default_compile_flags_feature = feature( - name = "default_compile_flags", - enabled = True, - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.assemble, - ACTION_NAMES.preprocess_assemble, - ACTION_NAMES.linkstamp_compile, - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_header_parsing, - ACTION_NAMES.cpp_module_compile, - ACTION_NAMES.cpp_module_codegen, - ACTION_NAMES.lto_backend, - ACTION_NAMES.clif_match, - ], - flag_groups = [ - flag_group( - flags = [ - "-fexperimental-new-pass-manager", - ], - ), - ], - ), - ], - ) - - elif ctx.attr.compiler == "msvc": - default_compile_flags_feature = feature( - name = "default_compile_flags", - enabled = True, - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.assemble, - ACTION_NAMES.preprocess_assemble, - ACTION_NAMES.linkstamp_compile, - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_header_parsing, - ACTION_NAMES.cpp_module_compile, - ACTION_NAMES.cpp_module_codegen, - ACTION_NAMES.lto_backend, - ACTION_NAMES.clif_match, - ], - flag_groups = [ - flag_group( - flags = [ - "/DCOMPILER_MSVC", - "/DNOMINMAX", - "/D_WIN32_WINNT=0x0600", - "/D_CRT_SECURE_NO_DEPRECATE", - "/D_CRT_SECURE_NO_WARNINGS", - "/D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS", - "/bigobj", - "/Zm500", - "/J", - "/Gy", - "/GF", - "/EHsc", - "/wd4351", - "/wd4291", - "/wd4250", - "/wd4996", - ], - ), - ], - ), - ], - ) - - else: - default_compile_flags_feature = feature( - name = "default_compile_flags") - - static_link_msvcrt_debug_feature = feature( - name = "static_link_msvcrt_debug", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [flag_group(flags = ["/MTd"])], - ), - flag_set( - actions = all_link_actions, - flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmtd.lib"])], - ), - ], - requires = [feature_set(features = ["dbg"])], - ) - - static_link_msvcrt_feature = feature(name = "static_link_msvcrt") - - if (ctx.attr.cpu == "darwin" or - ctx.attr.cpu == "local"): - dbg_feature = feature( - name = "dbg", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [flag_group(flags = ["-g"])], - ), - ], - implies = ["common"], - ) - elif (ctx.attr.cpu == "x64_windows"): - dbg_feature = feature( - name = "dbg", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [flag_group(flags = ["/Od", "/Z7", "/DDEBUG"])], - ), - flag_set( - actions = all_link_actions, - flag_groups = [flag_group(flags = ["/DEBUG:FULL", "/INCREMENTAL:NO"])], - ), - ], - implies = ["generate_pdb_file"], - ) - else: - dbg_feature = None - - undefined_dynamic_feature = feature( - name = "undefined-dynamic", - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.cpp_link_dynamic_library, - ACTION_NAMES.cpp_link_nodeps_dynamic_library, - ACTION_NAMES.cpp_link_executable, - ], - flag_groups = [flag_group(flags = ["-undefined", "dynamic_lookup"])], - ), - ], - ) - - parse_showincludes_feature = feature( - name = "parse_showincludes", - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.preprocess_assemble, - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_module_compile, - ACTION_NAMES.cpp_header_parsing, - ], - flag_groups = [flag_group(flags = ["/showIncludes"])], - ), - ], - ) - - linker_param_file_feature = feature( - name = "linker_param_file", - flag_sets = [ - flag_set( - actions = all_link_actions + - [ACTION_NAMES.cpp_link_static_library], - flag_groups = [ - flag_group( - flags = ["@%{linker_param_file}"], - expand_if_available = "linker_param_file", - ), - ], - ), - ], - ) - - static_link_msvcrt_no_debug_feature = feature( - name = "static_link_msvcrt_no_debug", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [flag_group(flags = ["/MT"])], - ), - flag_set( - actions = all_link_actions, - flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmt.lib"])], - ), - ], - requires = [ - feature_set(features = ["fastbuild"]), - feature_set(features = ["opt"]), - ], - ) - - supports_interface_shared_libraries_feature = feature( - name = "supports_interface_shared_libraries", - enabled = True, - ) - - disable_assertions_feature = feature( - name = "disable-assertions", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [flag_group(flags = ["-DNDEBUG"])], - ), - ], - ) - - if (ctx.attr.cpu == "x64_windows"): - fastbuild_feature = feature( - name = "fastbuild", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [flag_group(flags = ["/Od", "/Z7", "/DDEBUG"])], - ), - flag_set( - actions = all_link_actions, - flag_groups = [ - flag_group(flags = ["/DEBUG:FASTLINK", "/INCREMENTAL:NO"]), - ], - ), - ], - implies = ["generate_pdb_file"], - ) - elif (ctx.attr.cpu == "darwin" or - ctx.attr.cpu == "local"): - fastbuild_feature = feature(name = "fastbuild", implies = ["common"]) - else: - fastbuild_feature = None - - user_compile_flags_feature = feature( - name = "user_compile_flags", - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.preprocess_assemble, - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_header_parsing, - ACTION_NAMES.cpp_module_compile, - ACTION_NAMES.cpp_module_codegen, - ], - flag_groups = [ - flag_group( - flags = ["%{user_compile_flags}"], - iterate_over = "user_compile_flags", - expand_if_available = "user_compile_flags", - ), - ], - ), - ], - ) - - compiler_input_flags_feature = feature( - name = "compiler_input_flags", - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.assemble, - ACTION_NAMES.preprocess_assemble, - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_header_parsing, - ACTION_NAMES.cpp_module_compile, - ACTION_NAMES.cpp_module_codegen, - ], - flag_groups = [ - flag_group( - flags = ["/c", "%{source_file}"], - expand_if_available = "source_file", - ), - ], - ), - ], - ) - - no_legacy_features_feature = feature(name = "no_legacy_features") - - archiver_flags_feature = feature( - name = "archiver_flags", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.cpp_link_static_library], - flag_groups = [ - flag_group( - flags = ["/OUT:%{output_execpath}"], - expand_if_available = "output_execpath", - ), - ], - ), - ], - ) - - redirector_feature = feature( - name = "redirector", - enabled = True, - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_module_compile, - ACTION_NAMES.cpp_module_codegen, - ACTION_NAMES.cpp_header_parsing, - ACTION_NAMES.assemble, - ACTION_NAMES.preprocess_assemble, - ], - flag_groups = [ - flag_group( - flags = [ - "-B", - "external/local_config_cuda/crosstool/windows/msvc_wrapper_for_nvcc.py", - ], - ), - ], - ), - ], - ) - - linker_bin_path_feature = feature( - name = "linker-bin-path", - flag_sets = [ - flag_set( - actions = all_link_actions, - flag_groups = [flag_group(flags = ["-B" + ctx.attr.linker_bin_path])], - ), - ], - ) - - if (ctx.attr.cpu == "local"): - opt_feature = feature( - name = "opt", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [ - flag_group( - flags = ["-g0", "-O2", "-ffunction-sections", "-fdata-sections"], - ), - ], - ), - flag_set( - actions = [ - ACTION_NAMES.cpp_link_dynamic_library, - ACTION_NAMES.cpp_link_nodeps_dynamic_library, - ACTION_NAMES.cpp_link_executable, - ], - flag_groups = [flag_group(flags = ["-Wl,--gc-sections"])], - ), - ], - implies = ["common", "disable-assertions"], - ) - elif (ctx.attr.cpu == "darwin"): - opt_feature = feature( - name = "opt", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [ - flag_group( - flags = ["-g0", "-O2", "-ffunction-sections", "-fdata-sections"], - ), - ], - ), - ], - implies = ["common", "disable-assertions"], - ) - elif (ctx.attr.cpu == "x64_windows"): - opt_feature = feature( - name = "opt", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [flag_group(flags = ["/O2", "/DNDEBUG"])], - ), - ], - ) - else: - opt_feature = None - - include_paths_feature = feature( - name = "include_paths", - enabled = True, - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.assemble, - ACTION_NAMES.preprocess_assemble, - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_header_parsing, - ACTION_NAMES.cpp_module_compile, - ], - flag_groups = [ - flag_group( - flags = ["/I%{quote_include_paths}"], - iterate_over = "quote_include_paths", - ), - flag_group( - flags = ["/I%{include_paths}"], - iterate_over = "include_paths", - ), - flag_group( - flags = ["/I%{system_include_paths}"], - iterate_over = "system_include_paths", - ), - ], - ), - ], - ) - - shared_flag_feature = feature( - name = "shared_flag", - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.cpp_link_dynamic_library, - ACTION_NAMES.cpp_link_nodeps_dynamic_library, - ], - flag_groups = [flag_group(flags = ["/DLL"])], - ), - ], - ) - - windows_export_all_symbols_feature = feature(name = "windows_export_all_symbols") - - frame_pointer_feature = feature( - name = "frame-pointer", - flag_sets = [ - flag_set( - actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], - flag_groups = [flag_group(flags = ["-fno-omit-frame-pointer"])], - ), - ], - ) - - build_id_feature = feature( - name = "build-id", - flag_sets = [ - flag_set( - actions = all_link_actions, - flag_groups = [ - flag_group( - flags = ["-Wl,--build-id=md5", "-Wl,--hash-style=gnu"], - ), - ], - ), - ], - ) - - sysroot_feature = feature( - name = "sysroot", - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.assemble, - ACTION_NAMES.preprocess_assemble, - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_header_parsing, - ACTION_NAMES.cpp_module_compile, - ACTION_NAMES.cpp_module_codegen, - ACTION_NAMES.cpp_link_executable, - ACTION_NAMES.cpp_link_dynamic_library, - ACTION_NAMES.cpp_link_nodeps_dynamic_library, - ], - flag_groups = [ - flag_group( - flags = ["--sysroot=%{sysroot}"], - iterate_over = "sysroot", - expand_if_available = "sysroot", - ), - ], - ), - ], - ) - - cuda_path_feature = feature( - name = "cuda_path", - enabled = True, - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.assemble, - ACTION_NAMES.preprocess_assemble, - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_header_parsing, - ACTION_NAMES.cpp_module_compile, - ACTION_NAMES.cpp_module_codegen, - ACTION_NAMES.cpp_link_executable, - ACTION_NAMES.cpp_link_dynamic_library, - ACTION_NAMES.cpp_link_nodeps_dynamic_library, - ], - flag_groups = [ - flag_group( - flags = ["--cuda-path=" + ctx.attr.cuda_path], - ), - ], - ), - ], - ) - - def_file_feature = feature( - name = "def_file", - flag_sets = [ - flag_set( - actions = all_link_actions, - flag_groups = [ - flag_group( - flags = ["/DEF:%{def_file_path}", "/ignore:4070"], - expand_if_available = "def_file_path", - ), - ], - ), - ], - ) - - if (ctx.attr.cpu == "darwin"): - stdlib_feature = feature( - name = "stdlib", - flag_sets = [ - flag_set( - actions = all_link_actions, - flag_groups = [flag_group(flags = ["-lc++"])], - ), - ], - ) - elif (ctx.attr.cpu == "local"): - stdlib_feature = feature( - name = "stdlib", - flag_sets = [ - flag_set( - actions = all_link_actions, - flag_groups = [flag_group(flags = ["-lstdc++"])], - ), - ], - ) - else: - stdlib_feature = None - - no_stripping_feature = feature(name = "no_stripping") - - alwayslink_feature = feature( - name = "alwayslink", - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.cpp_link_dynamic_library, - ACTION_NAMES.cpp_link_nodeps_dynamic_library, - ACTION_NAMES.cpp_link_executable, - ], - flag_groups = [flag_group(flags = ["-Wl,-no-as-needed"])], - ), - ], - ) - - input_param_flags_feature = feature( - name = "input_param_flags", - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.cpp_link_dynamic_library, - ACTION_NAMES.cpp_link_nodeps_dynamic_library, - ], - flag_groups = [ - flag_group( - flags = ["/IMPLIB:%{interface_library_output_path}"], - expand_if_available = "interface_library_output_path", - ), - ], - ), - flag_set( - actions = all_link_actions + - [ACTION_NAMES.cpp_link_static_library], - flag_groups = [ - flag_group( - iterate_over = "libraries_to_link", - flag_groups = [ - flag_group( - iterate_over = "libraries_to_link.object_files", - flag_groups = [flag_group(flags = ["%{libraries_to_link.object_files}"])], - expand_if_equal = variable_with_value( - name = "libraries_to_link.type", - value = "object_file_group", - ), - ), - flag_group( - flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])], - expand_if_equal = variable_with_value( - name = "libraries_to_link.type", - value = "object_file", - ), - ), - flag_group( - flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])], - expand_if_equal = variable_with_value( - name = "libraries_to_link.type", - value = "interface_library", - ), - ), - flag_group( - flag_groups = [ - flag_group( - flags = ["%{libraries_to_link.name}"], - expand_if_false = "libraries_to_link.is_whole_archive", - ), - flag_group( - flags = ["/WHOLEARCHIVE:%{libraries_to_link.name}"], - expand_if_true = "libraries_to_link.is_whole_archive", - ), - ], - expand_if_equal = variable_with_value( - name = "libraries_to_link.type", - value = "static_library", - ), - ), - ], - expand_if_available = "libraries_to_link", - ), - ], - ), - ], - ) - - if (ctx.attr.cpu == "local"): - no_canonical_prefixes_feature = feature( - name = "no-canonical-prefixes", - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_link_executable, - ACTION_NAMES.cpp_link_dynamic_library, - ACTION_NAMES.cpp_link_nodeps_dynamic_library, - ], - flag_groups = [ - flag_group( - flags = [ - "-no-canonical-prefixes", - ] + ctx.attr.extra_no_canonical_prefixes_flags, - ), - ], - ), - ], - ) - elif (ctx.attr.cpu == "darwin"): - no_canonical_prefixes_feature = feature( - name = "no-canonical-prefixes", - flag_sets = [ - flag_set( - actions = [ - ACTION_NAMES.c_compile, - ACTION_NAMES.cpp_compile, - ACTION_NAMES.cpp_link_executable, - ACTION_NAMES.cpp_link_dynamic_library, - ACTION_NAMES.cpp_link_nodeps_dynamic_library, - ], - flag_groups = [flag_group(flags = ["-no-canonical-prefixes"])], - ), - ], - ) - else: - no_canonical_prefixes_feature = None - - has_configured_linker_path_feature = feature(name = "has_configured_linker_path") - - copy_dynamic_libraries_to_binary_feature = feature(name = "copy_dynamic_libraries_to_binary") - - user_link_flags_feature = feature( - name = "user_link_flags", - flag_sets = [ - flag_set( - actions = all_link_actions, - flag_groups = [ - flag_group( - flags = ["%{user_link_flags}"], - iterate_over = "user_link_flags", - expand_if_available = "user_link_flags", - ), - ], - ), - ], - ) - - if (ctx.attr.cpu == "local"): - common_feature = feature( - name = "common", - implies = [ - "stdlib", - "determinism", - "alwayslink", - "hardening", - "warnings", - "frame-pointer", - "build-id", - "no-canonical-prefixes", - "linker-bin-path", - ], - ) - elif (ctx.attr.cpu == "darwin"): - common_feature = feature( - name = "common", - implies = [ - "stdlib", - "determinism", - "hardening", - "warnings", - "frame-pointer", - "no-canonical-prefixes", - "linker-bin-path", - "undefined-dynamic", - ], - ) - else: - common_feature = None - - if (ctx.attr.cpu == "local"): - features = [ - default_compile_flags_feature, - stdlib_feature, - determinism_feature, - alwayslink_feature, - pic_feature, - hardening_feature, - warnings_feature, - frame_pointer_feature, - build_id_feature, - no_canonical_prefixes_feature, - disable_assertions_feature, - linker_bin_path_feature, - common_feature, - opt_feature, - fastbuild_feature, - dbg_feature, - supports_dynamic_linker_feature, - supports_pic_feature, - ] - if ctx.attr.cuda_path: - features += [cuda_path_feature] - elif (ctx.attr.cpu == "darwin"): - features = [ - stdlib_feature, - determinism_feature, - pic_feature, - hardening_feature, - warnings_feature, - frame_pointer_feature, - no_canonical_prefixes_feature, - disable_assertions_feature, - linker_bin_path_feature, - undefined_dynamic_feature, - common_feature, - opt_feature, - fastbuild_feature, - dbg_feature, - supports_dynamic_linker_feature, - supports_pic_feature, - ] - elif (ctx.attr.cpu == "x64_windows"): - features = [ - no_legacy_features_feature, - redirector_feature, - nologo_feature, - has_configured_linker_path_feature, - no_stripping_feature, - targets_windows_feature, - copy_dynamic_libraries_to_binary_feature, - default_compile_flags_feature, - msvc_env_feature, - include_paths_feature, - preprocessor_defines_feature, - parse_showincludes_feature, - generate_pdb_file_feature, - shared_flag_feature, - linkstamps_feature, - output_execpath_flags_feature, - archiver_flags_feature, - input_param_flags_feature, - linker_subsystem_flag_feature, - user_link_flags_feature, - default_link_flags_feature, - linker_param_file_feature, - static_link_msvcrt_feature, - static_link_msvcrt_no_debug_feature, - dynamic_link_msvcrt_no_debug_feature, - static_link_msvcrt_debug_feature, - dynamic_link_msvcrt_debug_feature, - dbg_feature, - fastbuild_feature, - opt_feature, - user_compile_flags_feature, - sysroot_feature, - unfiltered_compile_flags_feature, - compiler_output_flags_feature, - compiler_input_flags_feature, - def_file_feature, - windows_export_all_symbols_feature, - no_windows_export_all_symbols_feature, - supports_dynamic_linker_feature, - supports_interface_shared_libraries_feature, - ] - else: - fail("Unreachable") - - cxx_builtin_include_directories = ctx.attr.builtin_include_directories - - if (ctx.attr.cpu == "x64_windows"): - tool_paths = [ + elif cpu == "x64_windows": + return [ tool_path(name = "ar", path = ctx.attr.msvc_lib_path), tool_path(name = "ml", path = ctx.attr.msvc_ml_path), tool_path(name = "cpp", path = ctx.attr.msvc_cl_path), @@ -1452,58 +263,766 @@ def _impl(ctx): path = "wrapper/bin/msvc_nop.bat", ), ] - elif (ctx.attr.cpu == "local"): - tool_paths = [ - tool_path(name = "gcc", path = ctx.attr.host_compiler_path), - tool_path(name = "ar", path = ctx.attr.host_compiler_prefix + "/ar"), - tool_path(name = "compat-ld", path = ctx.attr.host_compiler_prefix + "/ld"), - tool_path(name = "cpp", path = ctx.attr.host_compiler_prefix + "/cpp"), - tool_path(name = "dwp", path = ctx.attr.host_compiler_prefix + "/dwp"), - tool_path(name = "gcov", path = ctx.attr.host_compiler_prefix + "/gcov"), - tool_path(name = "ld", path = ctx.attr.host_compiler_prefix + "/ld"), - tool_path(name = "nm", path = ctx.attr.host_compiler_prefix + "/nm"), - tool_path(name = "objcopy", path = ctx.attr.host_compiler_prefix + "/objcopy"), - tool_path(name = "objdump", path = ctx.attr.host_compiler_prefix + "/objdump"), - tool_path(name = "strip", path = ctx.attr.host_compiler_prefix + "/strip"), + else: + fail("Unreachable") + +def _sysroot_group(): + return flag_group( + flags = ["--sysroot=%{sysroot}"], + expand_if_available = "sysroot", + ) + +def _no_canonical_prefixes_group(extra_flags): + return flag_group( + flags = [ + "-no-canonical-prefixes", + ] + extra_flags, + ) + +def _cuda_set(cuda_path, actions): + if cuda_path: + return flag_set( + actions = actions, + flag_groups = [ + flag_group( + flags = ["--cuda-path=" + cuda_path], + ), + ], + ) + else: + return [] + +def _nologo(): + return flag_group(flags = ["/nologo"]) + +def _features(cpu, compiler, ctx): + if cpu in ["local", "darwin"]: + return [ + feature(name = "no_legacy_features"), + feature( + name = "all_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_compile_actions(), + flag_groups = [ + flag_group( + flags = ["-MD", "-MF", "%{dependency_file}"], + expand_if_available = "dependency_file", + ), + flag_group( + flags = ["-gsplit-dwarf"], + expand_if_available = "per_object_debug_info_file", + ), + ], + ), + flag_set( + actions = all_preprocessed_actions(), + flag_groups = [ + flag_group( + flags = ["-frandom-seed=%{output_file}"], + expand_if_available = "output_file", + ), + _iterate_flag_group( + flags = ["-D%{preprocessor_defines}"], + iterate_over = "preprocessor_defines", + ), + _iterate_flag_group( + flags = ["-include", "%{includes}"], + iterate_over = "includes", + ), + _iterate_flag_group( + flags = ["-iquote", "%{quote_include_paths}"], + iterate_over = "quote_include_paths", + ), + _iterate_flag_group( + flags = ["-I%{include_paths}"], + iterate_over = "include_paths", + ), + _iterate_flag_group( + flags = ["-isystem", "%{system_include_paths}"], + iterate_over = "system_include_paths", + ), + _iterate_flag_group( + flags = ["-F", "%{framework_include_paths}"], + iterate_over = "framework_include_paths", + ), + ], + ), + flag_set( + actions = all_cpp_compile_actions(), + flag_groups = [ + flag_group(flags = ["-fexperimental-new-pass-manager"]), + ] if compiler == "clang" else [], + ), + flag_set( + actions = all_compile_actions(), + flag_groups = [ + flag_group( + flags = [ + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + ], + ), + flag_group( + flags = ["-fPIC"], + expand_if_available = "pic", + ), + flag_group( + flags = ["-fPIE"], + expand_if_not_available = "pic", + ), + flag_group( + flags = [ + "-U_FORTIFY_SOURCE", + "-D_FORTIFY_SOURCE=1", + "-fstack-protector", + "-Wall", + ] + ctx.attr.host_compiler_warnings + [ + "-fno-omit-frame-pointer", + ], + ), + _no_canonical_prefixes_group( + ctx.attr.extra_no_canonical_prefixes_flags, + ), + ], + ), + flag_set( + actions = all_compile_actions(), + flag_groups = [flag_group(flags = ["-DNDEBUG"])], + with_features = [with_feature_set(features = ["disable-assertions"])], + ), + flag_set( + actions = all_compile_actions(), + flag_groups = [ + flag_group( + flags = [ + "-g0", + "-O2", + "-ffunction-sections", + "-fdata-sections", + ], + ), + ], + with_features = [with_feature_set(features = ["opt"])], + ), + flag_set( + actions = all_compile_actions(), + flag_groups = [flag_group(flags = ["-g"])], + with_features = [with_feature_set(features = ["dbg"])], + ), + ] + _cuda_set( + ctx.attr.cuda_path, + all_compile_actions, + ) + [ + flag_set( + actions = all_compile_actions(), + flag_groups = [ + _iterate_flag_group( + flags = ["%{user_compile_flags}"], + iterate_over = "user_compile_flags", + ), + _sysroot_group(), + flag_group( + expand_if_available = "source_file", + flags = ["-c", "%{source_file}"], + ), + flag_group( + expand_if_available = "output_assembly_file", + flags = ["-S"], + ), + flag_group( + expand_if_available = "output_preprocess_file", + flags = ["-E"], + ), + flag_group( + expand_if_available = "output_file", + flags = ["-o", "%{output_file}"], + ), + ], + ), + ], + ), + feature( + name = "all_archive_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_archive_actions(), + flag_groups = [ + flag_group( + expand_if_available = "linker_param_file", + flags = ["@%{linker_param_file}"], + ), + flag_group(flags = ["rcsD"]), + flag_group( + flags = ["%{output_execpath}"], + expand_if_available = "output_execpath", + ), + flag_group( + iterate_over = "libraries_to_link", + flag_groups = [ + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file", + ), + ), + flag_group( + flags = ["%{libraries_to_link.object_files}"], + iterate_over = "libraries_to_link.object_files", + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + ], + expand_if_available = "libraries_to_link", + ), + ], + ), + ], + ), + feature( + name = "all_link_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_shared_library_link_actions(), + flag_groups = [flag_group(flags = ["-shared"])], + ), + flag_set( + actions = all_link_actions(), + flag_groups = [ + flag_group( + flags = ["@%{linker_param_file}"], + expand_if_available = "linker_param_file", + ), + _iterate_flag_group( + flags = ["%{linkstamp_paths}"], + iterate_over = "linkstamp_paths", + ), + flag_group( + flags = ["-o", "%{output_execpath}"], + expand_if_available = "output_execpath", + ), + _iterate_flag_group( + flags = ["-L%{library_search_directories}"], + iterate_over = "library_search_directories", + ), + _iterate_flag_group( + iterate_over = "runtime_library_search_directories", + flags = [ + "-Wl,-rpath,$ORIGIN/%{runtime_library_search_directories}", + ] if cpu == "local" else [ + "-Wl,-rpath,@loader_path/%{runtime_library_search_directories}", + ], + ), + _libraries_to_link_group("darwin" if cpu == "darwin" else "linux"), + _iterate_flag_group( + flags = ["%{user_link_flags}"], + iterate_over = "user_link_flags", + ), + flag_group( + flags = ["-Wl,--gdb-index"], + expand_if_available = "is_using_fission", + ), + flag_group( + flags = ["-Wl,-S"], + expand_if_available = "strip_debug_symbols", + ), + flag_group(flags = ["-lc++" if cpu == "darwin" else "-lstdc++"]), + _no_canonical_prefixes_group( + ctx.attr.extra_no_canonical_prefixes_flags, + ), + ], + ), + flag_set( + actions = all_executable_link_actions(), + flag_groups = [flag_group(flags = ["-pie"])], + ), + ] + ([ + flag_set( + actions = all_link_actions(), + flag_groups = [flag_group(flags = [ + "-Wl,-z,relro,-z,now", + ])], + ), + ] if cpu == "local" else []) + [ + flag_set( + actions = all_link_actions(), + flag_groups = [flag_group(flags = ["-Wl,-no-as-needed"])], + with_features = [with_feature_set(features = ["alwayslink"])], + ), + flag_set( + actions = all_link_actions(), + flag_groups = [ + flag_group(flags = ["-B" + ctx.attr.linker_bin_path]), + ], + ), + ] + ([flag_set( + actions = all_link_actions(), + flag_groups = [ + flag_group(flags = ["-Wl,--gc-sections"]), + flag_group( + flags = ["-Wl,--build-id=md5", "-Wl,--hash-style=gnu"], + ), + ], + )] if cpu == "local" else []) + ([ + flag_set( + actions = all_link_actions(), + flag_groups = [flag_group(flags = ["-undefined", "dynamic_lookup"])], + ), + ] if cpu == "darwin" else []) + _cuda_set( + ctx.attr.cuda_path, + all_link_actions(), + ) + [ + flag_set( + actions = all_link_actions(), + flag_groups = [ + _sysroot_group(), + ], + ), + ], + ), + feature(name = "alwayslink", enabled = cpu == "local"), + feature(name = "opt"), + feature(name = "fastbuild"), + feature(name = "dbg"), + feature(name = "supports_dynamic_linker", enabled = True), + feature(name = "pic", enabled = True), + feature(name = "supports_pic", enabled = True), + feature(name = "has_configured_linker_path", enabled = True), ] - elif (ctx.attr.cpu == "darwin"): - tool_paths = [ - tool_path(name = "gcc", path = ctx.attr.host_compiler_path), - tool_path(name = "ar", path = ctx.attr.host_compiler_prefix + "/libtool"), - tool_path(name = "compat-ld", path = ctx.attr.host_compiler_prefix + "/ld"), - tool_path(name = "cpp", path = ctx.attr.host_compiler_prefix + "/cpp"), - tool_path(name = "dwp", path = ctx.attr.host_compiler_prefix + "/dwp"), - tool_path(name = "gcov", path = ctx.attr.host_compiler_prefix + "/gcov"), - tool_path(name = "ld", path = ctx.attr.host_compiler_prefix + "/ld"), - tool_path(name = "nm", path = ctx.attr.host_compiler_prefix + "/nm"), - tool_path(name = "objcopy", path = ctx.attr.host_compiler_prefix + "/objcopy"), - tool_path(name = "objdump", path = ctx.attr.host_compiler_prefix + "/objdump"), - tool_path(name = "strip", path = ctx.attr.host_compiler_prefix + "/strip"), + elif cpu == "x64_windows": + return [ + feature(name = "no_legacy_features"), + feature( + name = "common_flags", + enabled = True, + env_sets = [ + env_set( + actions = all_compile_actions() + all_link_actions() + all_archive_actions(), + env_entries = [ + env_entry(key = "PATH", value = ctx.attr.msvc_env_path), + env_entry(key = "INCLUDE", value = ctx.attr.msvc_env_include), + env_entry(key = "LIB", value = ctx.attr.msvc_env_lib), + env_entry(key = "TMP", value = ctx.attr.msvc_env_tmp), + env_entry(key = "TEMP", value = ctx.attr.msvc_env_tmp), + ], + ), + ], + ), + feature( + name = "all_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_compile_actions(), + flag_groups = [ + flag_group( + flags = [ + "-B", + "external/local_config_cuda/crosstool/windows/msvc_wrapper_for_nvcc.py", + ], + ), + _nologo(), + flag_group( + flags = [ + "/DCOMPILER_MSVC", + "/DNOMINMAX", + "/D_WIN32_WINNT=0x0600", + "/D_CRT_SECURE_NO_DEPRECATE", + "/D_CRT_SECURE_NO_WARNINGS", + "/D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS", + "/bigobj", + "/Zm500", + "/J", + "/Gy", + "/GF", + "/EHsc", + "/wd4351", + "/wd4291", + "/wd4250", + "/wd4996", + ], + ), + _iterate_flag_group( + flags = ["/I%{quote_include_paths}"], + iterate_over = "quote_include_paths", + ), + _iterate_flag_group( + flags = ["/I%{include_paths}"], + iterate_over = "include_paths", + ), + _iterate_flag_group( + flags = ["/I%{system_include_paths}"], + iterate_over = "system_include_paths", + ), + _iterate_flag_group( + flags = ["/D%{preprocessor_defines}"], + iterate_over = "preprocessor_defines", + ), + ], + ), + flag_set( + actions = all_preprocessed_actions(), + flag_groups = [flag_group(flags = ["/showIncludes"])], + ), + flag_set( + actions = all_compile_actions(), + flag_groups = [flag_group(flags = ["/MT"])], + with_features = [with_feature_set(features = ["static_link_msvcrt_no_debug"])], + ), + flag_set( + actions = all_compile_actions(), + flag_groups = [flag_group(flags = ["/MD"])], + with_features = [with_feature_set(features = ["dynamic_link_msvcrt_no_debug"])], + ), + flag_set( + actions = all_compile_actions(), + flag_groups = [flag_group(flags = ["/MTd"])], + with_features = [with_feature_set(features = ["static_link_msvcrt_debug"])], + ), + flag_set( + actions = all_compile_actions(), + flag_groups = [flag_group(flags = ["/MDd"])], + with_features = [with_feature_set(features = ["dynamic_link_msvcrt_debug"])], + ), + flag_set( + actions = all_compile_actions(), + flag_groups = [flag_group(flags = ["/Od", "/Z7", "/DDEBUG"])], + with_features = [with_feature_set(features = ["dbg"])], + ), + flag_set( + actions = all_compile_actions(), + flag_groups = [flag_group(flags = ["/Od", "/Z7", "/DDEBUG"])], + with_features = [with_feature_set(features = ["fastbuild"])], + ), + flag_set( + actions = all_compile_actions(), + flag_groups = [flag_group(flags = ["/O2", "/DNDEBUG"])], + with_features = [with_feature_set(features = ["opt"])], + ), + flag_set( + actions = all_preprocessed_actions(), + flag_groups = [ + _iterate_flag_group( + flags = ["%{user_compile_flags}"], + iterate_over = "user_compile_flags", + ), + ] + ([ + flag_group(flags = ctx.attr.host_unfiltered_compile_flags), + ] if ctx.attr.host_unfiltered_compile_flags else []), + ), + flag_set( + actions = [ACTION_NAMES.assemble], + flag_groups = [ + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fo%{output_file}", "/Zi"], + expand_if_not_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + expand_if_not_available = "output_assembly_file", + ), + ], + ), + flag_set( + actions = all_preprocessed_actions(), + flag_groups = [ + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fo%{output_file}"], + expand_if_not_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + expand_if_not_available = "output_assembly_file", + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fa%{output_file}"], + expand_if_available = "output_assembly_file", + ), + ], + expand_if_available = "output_file", + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["/P", "/Fi%{output_file}"], + expand_if_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + ), + ], + ), + flag_set( + actions = all_compile_actions(), + flag_groups = [ + flag_group( + flags = ["/c", "%{source_file}"], + expand_if_available = "source_file", + ), + ], + ), + ], + ), + feature( + name = "all_archive_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_archive_actions(), + flag_groups = [ + _nologo(), + flag_group( + flags = ["/OUT:%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + ), + ], + ), + feature( + name = "all_link_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_shared_library_link_actions(), + flag_groups = [flag_group(flags = ["/DLL"])], + ), + flag_set( + actions = all_link_actions(), + flag_groups = [ + _nologo(), + _iterate_flag_group( + flags = ["%{linkstamp_paths}"], + iterate_over = "linkstamp_paths", + ), + flag_group( + flags = ["/OUT:%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + ), + flag_set( + actions = all_shared_library_link_actions(), + flag_groups = [ + flag_group( + flags = ["/IMPLIB:%{interface_library_output_path}"], + expand_if_available = "interface_library_output_path", + ), + ], + ), + flag_set( + actions = all_link_actions() + + all_archive_actions(), + flag_groups = [ + _libraries_to_link_group("msvc"), + ], + ), + flag_set( + actions = all_link_actions(), + flag_groups = [ + flag_group(flags = ["/SUBSYSTEM:CONSOLE"]), + _iterate_flag_group( + flags = ["%{user_link_flags}"], + iterate_over = "user_link_flags", + ), + flag_group(flags = ["/MACHINE:X64"]), + ], + ), + flag_set( + actions = all_link_actions() + + all_archive_actions(), + flag_groups = [ + flag_group( + flags = ["@%{linker_param_file}"], + expand_if_available = "linker_param_file", + ), + ], + ), + flag_set( + actions = all_link_actions(), + flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmt.lib"])], + with_features = [with_feature_set(features = ["static_link_msvcrt_no_debug"])], + ), + flag_set( + actions = all_link_actions(), + flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrt.lib"])], + with_features = [with_feature_set(features = ["dynamic_link_msvcrt_no_debug"])], + ), + flag_set( + actions = all_link_actions(), + flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmtd.lib"])], + with_features = [with_feature_set(features = ["static_link_msvcrt_debug"])], + ), + flag_set( + actions = all_link_actions(), + flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrtd.lib"])], + with_features = [with_feature_set(features = ["dynamic_link_msvcrt_debug"])], + ), + flag_set( + actions = all_link_actions(), + flag_groups = [flag_group(flags = ["/DEBUG:FULL", "/INCREMENTAL:NO"])], + with_features = [with_feature_set(features = ["dbg"])], + ), + flag_set( + actions = all_link_actions(), + flag_groups = [ + flag_group(flags = ["/DEBUG:FASTLINK", "/INCREMENTAL:NO"]), + ], + with_features = [with_feature_set(features = ["fastbuild"])], + ), + flag_set( + actions = all_link_actions(), + flag_groups = [ + flag_group( + flags = ["/DEF:%{def_file_path}", "/ignore:4070"], + expand_if_available = "def_file_path", + ), + ], + ), + ], + ), + feature(name = "parse_showincludes", enabled = True), + feature(name = "no_stripping", enabled = True), + feature( + name = "targets_windows", + enabled = True, + implies = ["copy_dynamic_libraries_to_binary"], + ), + feature(name = "copy_dynamic_libraries_to_binary"), + feature( + name = "generate_pdb_file", + requires = [ + feature_set(features = ["dbg"]), + feature_set(features = ["fastbuild"]), + ], + ), + feature(name = "static_link_msvcrt"), + feature( + name = "static_link_msvcrt_no_debug", + requires = [ + feature_set(features = ["fastbuild"]), + feature_set(features = ["opt"]), + ], + ), + feature( + name = "dynamic_link_msvcrt_no_debug", + requires = [ + feature_set(features = ["fastbuild"]), + feature_set(features = ["opt"]), + ], + ), + feature( + name = "static_link_msvcrt_debug", + requires = [feature_set(features = ["dbg"])], + ), + feature( + name = "dynamic_link_msvcrt_debug", + requires = [feature_set(features = ["dbg"])], + ), + feature( + name = "dbg", + implies = ["generate_pdb_file"], + ), + feature( + name = "fastbuild", + implies = ["generate_pdb_file"], + ), + feature( + name = "opt", + ), + feature(name = "windows_export_all_symbols"), + feature(name = "no_windows_export_all_symbols"), + feature(name = "supports_dynamic_linker", enabled = True), + feature( + name = "supports_interface_shared_libraries", + enabled = True, + ), + feature(name = "has_configured_linker_path", enabled = True), ] else: fail("Unreachable") +def _impl(ctx): + cpu = ctx.attr.cpu + compiler = ctx.attr.compiler + + if (cpu == "darwin"): + toolchain_identifier = "local_darwin" + target_cpu = "darwin" + target_libc = "macosx" + compiler = "compiler" + action_configs = _action_configs( + assembly_path = ctx.attr.host_compiler_path, + c_compiler_path = ctx.attr.host_compiler_path, + cc_compiler_path = ctx.attr.host_compiler_path, + archiver_path = ctx.attr.host_compiler_prefix + "/libtool", + linker_path = ctx.attr.host_compiler_path, + strip_path = ctx.attr.host_compiler_prefix + "/strip", + ) + elif (cpu == "local"): + toolchain_identifier = "local_linux" + target_cpu = "local" + target_libc = "local" + compiler = "compiler" + action_configs = _action_configs( + assembly_path = ctx.attr.host_compiler_path, + c_compiler_path = ctx.attr.host_compiler_path, + cc_compiler_path = ctx.attr.host_compiler_path, + archiver_path = ctx.attr.host_compiler_prefix + "/ar", + linker_path = ctx.attr.host_compiler_path, + strip_path = ctx.attr.host_compiler_prefix + "/strip", + ) + elif (cpu == "x64_windows"): + toolchain_identifier = "local_windows" + target_cpu = "x64_windows" + target_libc = "msvcrt" + compiler = "msvc-cl" + action_configs = _action_configs( + assembly_path = ctx.attr.msvc_ml_path, + c_compiler_path = ctx.attr.msvc_cl_path, + cc_compiler_path = ctx.attr.msvc_cl_path, + archiver_path = ctx.attr.msvc_lib_path, + linker_path = ctx.attr.msvc_link_path, + strip_path = "fake_tool_strip_not_supported", + ) + else: + fail("Unreachable") + out = ctx.actions.declare_file(ctx.label.name) ctx.actions.write(out, "Fake executable") return [ cc_common.create_cc_toolchain_config_info( ctx = ctx, - features = features, + features = _features(cpu, compiler, ctx), action_configs = action_configs, artifact_name_patterns = [], - cxx_builtin_include_directories = cxx_builtin_include_directories, + cxx_builtin_include_directories = ctx.attr.builtin_include_directories, toolchain_identifier = toolchain_identifier, - host_system_name = host_system_name, - target_system_name = target_system_name, + host_system_name = "local", + target_system_name = "local", target_cpu = target_cpu, target_libc = target_libc, compiler = compiler, - abi_version = abi_version, - abi_libc_version = abi_libc_version, - tool_paths = tool_paths, + abi_version = "local", + abi_libc_version = "local", + tool_paths = _tool_paths(cpu, ctx), make_variables = [], - builtin_sysroot = builtin_sysroot, - cc_target_os = cc_target_os, + builtin_sysroot = ctx.attr.builtin_sysroot, + cc_target_os = None, ), DefaultInfo( executable = out, @@ -1514,6 +1033,7 @@ cc_toolchain_config = rule( implementation = _impl, attrs = { "cpu": attr.string(mandatory = True, values = ["darwin", "local", "x64_windows"]), + "compiler": attr.string(values = ["clang", "msvc", "unknown"], default = "unknown"), "builtin_include_directories": attr.string_list(), "extra_no_canonical_prefixes_flags": attr.string_list(), "host_compiler_path": attr.string(), @@ -1531,7 +1051,6 @@ cc_toolchain_config = rule( "msvc_lib_path": attr.string(default = "msvc_not_used"), "msvc_link_path": attr.string(default = "msvc_not_used"), "msvc_ml_path": attr.string(default = "msvc_not_used"), - "compiler": attr.string(values = ["clang", "msvc", "unknown"], default="unknown"), }, provides = [CcToolchainConfigInfo], executable = True, diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl index 303339e77f7..9cc06ef99f5 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl @@ -221,8 +221,12 @@ def InvokeNvcc(argv, log=False): nvccopts = '-D_FORCE_INLINES ' for capability in GetOptionValue(argv, "--cuda-gpu-arch"): capability = capability[len('sm_'):] - nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s,compute_%s\" ' % ( - capability, capability, capability) + nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s\" ' % (capability, + capability) + for capability in GetOptionValue(argv, '--cuda-include-ptx'): + capability = capability[len('sm_'):] + nvccopts += r'-gencode=arch=compute_%s,\"code=compute_%s\" ' % (capability, + capability) nvccopts += nvcc_compiler_options nvccopts += undefines nvccopts += defines diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index f5ac7b39dfd..89275128a9c 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -179,7 +179,7 @@ def InvokeHipcc(argv, log=False): # Also we need to retain warning about uninitialised shared variable as # warning only, even when -Werror option is specified. if HIPCC_IS_HIPCLANG: - hipccopts += ' --include=hip/hip_runtime.h -Wno-error=cuda-shared-init ' + hipccopts += ' --include=hip/hip_runtime.h ' hipccopts += ' ' + hipcc_compiler_options # Use -fno-gpu-rdc by default for early GPU kernel finalization # This flag would trigger GPU kernels be generated at compile time, instead @@ -258,6 +258,8 @@ def main(): gpu_linker_flags.append('-L' + HIP_RUNTIME_PATH) gpu_linker_flags.append('-Wl,-rpath=' + HIP_RUNTIME_PATH) gpu_linker_flags.append('-l' + HIP_RUNTIME_LIBRARY) + if HIPCC_IS_HIPCLANG: + gpu_linker_flags.append("-lrt") if VERBOSE: print(' '.join([CPU_COMPILER] + gpu_linker_flags)) return subprocess.call([CPU_COMPILER] + gpu_linker_flags) diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl index de6512e3088..c00e7077b59 100644 --- a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl +++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl @@ -138,10 +138,18 @@ def InvokeNvcc(argv, log=False): nvccopts = ['-D_FORCE_INLINES'] compute_capabilities, argv = GetOptionValue(argv, "--cuda-gpu-arch") for capability in compute_capabilities: - print(capability) capability = capability[len('sm_'):] - nvccopts += [r'-gencode=arch=compute_%s,"code=sm_%s,compute_%s"' % ( - capability, capability, capability)] + nvccopts += [ + r'-gencode=arch=compute_%s,"code=sm_%s"' % (capability, capability) + ] + compute_capabilities, argv = GetOptionValue(argv, '--cuda-include-ptx') + for capability in compute_capabilities: + capability = capability[len('sm_'):] + nvccopts += [ + r'-gencode=arch=compute_%s,"code=compute_%s"' % (capability, capability) + ] + _, argv = GetOptionValue(argv, '--no-cuda-include-ptx') + nvccopts += nvcc_compiler_options nvccopts += undefines nvccopts += defines diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index aa8a2f0226d..35e86d8d77b 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -66,8 +66,6 @@ _TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO" _TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG" _PYTHON_BIN_PATH = "PYTHON_BIN_PATH" -_DEFAULT_CUDA_COMPUTE_CAPABILITIES = ["3.5", "5.2"] - def to_list_of_strings(elements): """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'. @@ -410,18 +408,40 @@ _NVCC_VERSION_PREFIX = "Cuda compilation tools, release " _DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR" def compute_capabilities(repository_ctx): - """Returns a list of strings representing cuda compute capabilities.""" - capabilities_str = get_host_environ(repository_ctx, _TF_CUDA_COMPUTE_CAPABILITIES) - if capabilities_str == None: - return _DEFAULT_CUDA_COMPUTE_CAPABILITIES - capabilities = capabilities_str.split(",") - for capability in capabilities: - # Workaround for Skylark's lack of support for regex. This check should - # be equivalent to checking: - # if re.match("[0-9]+.[0-9]+", capability) == None: + """Returns a list of strings representing cuda compute capabilities. + + Args: + repository_ctx: the repo rule's context. + Returns: list of cuda architectures to compile for. 'compute_xy' refers to + both PTX and SASS, 'sm_xy' refers to SASS only. + """ + capabilities = get_host_environ( + repository_ctx, + _TF_CUDA_COMPUTE_CAPABILITIES, + "compute_35,compute_52", + ).split(",") + + # Map old 'x.y' capabilities to 'compute_xy'. + for i, capability in enumerate(capabilities): parts = capability.split(".") - if len(parts) != 2 or not parts[0].isdigit() or not parts[1].isdigit(): + if len(parts) != 2: + continue + capabilities[i] = "compute_%s%s" % (parts[0], parts[1]) + + # Make list unique + capabilities = dict(zip(capabilities, capabilities)).keys() + + # Validate capabilities. + for capability in capabilities: + if not capability.startswith(("compute_", "sm_")): auto_configure_fail("Invalid compute capability: %s" % capability) + for prefix in ["compute_", "sm_"]: + if not capability.startswith(prefix): + continue + if len(capability) == len(prefix) + 2 and capability[-2:].isdigit(): + continue + auto_configure_fail("Invalid compute capability: %s" % capability) + return capabilities def lib_name(base_name, cpu_value, version = None, static = False): @@ -809,23 +829,35 @@ def make_copy_files_rule(repository_ctx, name, srcs, outs): cmd = \"""%s \""", )""" % (name, "\n".join(outs), " && \\\n".join(cmds)) -def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir): - """Returns a rule to recursively copy a directory.""" +def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir, exceptions = None): + """Returns a rule to recursively copy a directory. + If exceptions is not None, it must be a list of files or directories in + 'src_dir'; these will be excluded from copying. + """ src_dir = _norm_path(src_dir) out_dir = _norm_path(out_dir) outs = read_dir(repository_ctx, src_dir) + post_cmd = "" + if exceptions != None: + outs = [x for x in outs if not any([ + x.startswith(src_dir + "/" + y) + for y in exceptions + ])] outs = [(' "%s",' % out.replace(src_dir, out_dir)) for out in outs] # '@D' already contains the relative path for a single file, see # http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables out_dir = "$(@D)/%s" % out_dir if len(outs) > 1 else "$(@D)" + if exceptions != None: + for x in exceptions: + post_cmd += " ; rm -fR " + out_dir + "/" + x return """genrule( name = "%s", outs = [ %s ], - cmd = \"""cp -rLf "%s/." "%s/" \""", -)""" % (name, "\n".join(outs), src_dir, out_dir) + cmd = \"""cp -rLf "%s/." "%s/" %s\""", +)""" % (name, "\n".join(outs), src_dir, out_dir, post_cmd) def _flag_enabled(repository_ctx, flag_name): return get_host_environ(repository_ctx, flag_name) == "1" @@ -837,22 +869,15 @@ def _tf_sysroot(repository_ctx): return get_host_environ(repository_ctx, _TF_SYSROOT, "") def _compute_cuda_extra_copts(repository_ctx, compute_capabilities): - capability_flags = [ - "--cuda-gpu-arch=sm_" + cap.replace(".", "") - for cap in compute_capabilities - ] + capability_flags = ["--no-cuda-include-ptx=all"] + for capability in compute_capabilities: + if capability.startswith("compute_"): + capability = capability.replace("compute_", "sm_") + capability_flags.append("--cuda-include-ptx=%s" % capability) + capability_flags.append("--cuda-gpu-arch=%s" % capability) + return str(capability_flags) -def _compute_cuda_gpu_architectures(repository_ctx, compute_capabilities): - gpu_architectures = [ - "sm_" + capability.replace(".", "") - for capability in compute_capabilities - ] - - # Make the list unique. - gpu_architectures = dict(zip(gpu_architectures, gpu_architectures)).keys() - return str(gpu_architectures) - def _tpl_path(repository_ctx, filename): return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % filename)) @@ -984,10 +1009,7 @@ def _create_local_cuda_repository(repository_ctx): repository_ctx, cuda_config.compute_capabilities, ), - "%{cuda_gpu_architectures}": _compute_cuda_gpu_architectures( - repository_ctx, - cuda_config.compute_capabilities, - ), + "%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities), }, ) diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 3c345e6724b..4cfec2459e4 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -615,6 +615,7 @@ def _create_local_rocm_repository(repository_ctx): name = "rocm-include", src_dir = rocm_toolkit_path + "/include", out_dir = "rocm/include", + exceptions = ["gtest", "gmock"], ), make_copy_dir_rule( repository_ctx, diff --git a/third_party/mkl/build_defs.bzl b/third_party/mkl/build_defs.bzl index 4b8fb83eb09..bd0686523bc 100644 --- a/third_party/mkl/build_defs.bzl +++ b/third_party/mkl/build_defs.bzl @@ -107,6 +107,7 @@ def mkl_deps(): return select({ "@org_tensorflow//third_party/mkl_dnn:build_with_mkl_dnn_only": ["@mkl_dnn"], "@org_tensorflow//third_party/mkl_dnn:build_with_mkl_dnn_v1_only": ["@mkl_dnn_v1//:mkl_dnn"], + "@org_tensorflow//third_party/mkl_dnn:build_with_mkldnn_threadpool": ["@mkl_dnn_v1//:mkl_dnn"], "@org_tensorflow//third_party/mkl:build_with_mkl_ml_only": ["@org_tensorflow//third_party/mkl:intel_binary_blob"], "@org_tensorflow//third_party/mkl:build_with_mkl": [ "@org_tensorflow//third_party/mkl:intel_binary_blob", diff --git a/third_party/mkl_dnn/BUILD b/third_party/mkl_dnn/BUILD index 774e5b0e2c0..fe558322916 100644 --- a/third_party/mkl_dnn/BUILD +++ b/third_party/mkl_dnn/BUILD @@ -27,6 +27,15 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "build_with_mkldnn_threadpool", + define_values = { + "build_with_mkl": "true", + "build_with_mkldnn_threadpool": "true", + }, + visibility = ["//visibility:public"], +) + bzl_library( name = "build_defs_bzl", srcs = ["build_defs.bzl"], diff --git a/third_party/mkl_dnn/build_defs.bzl b/third_party/mkl_dnn/build_defs.bzl index af05333c947..bd3b4b94f29 100644 --- a/third_party/mkl_dnn/build_defs.bzl +++ b/third_party/mkl_dnn/build_defs.bzl @@ -29,3 +29,19 @@ def if_mkl_v1_open_source_only(if_true, if_false = []): "@org_tensorflow//third_party/mkl_dnn:build_with_mkl_dnn_v1_only": if_true, "//conditions:default": if_false, }) + +def if_mkldnn_threadpool(if_true, if_false = []): + """Returns `if_true` if MKL-DNN v1.x is used. + + Shorthand for select()'ing on whether we're building with + MKL-DNN v1.x open source library only with user specified threadpool, without depending on MKL binary form. + + Returns a select statement which evaluates to if_true if we're building + with MKL-DNN v1.x open source library only with user specified threadpool. Otherwise, the + select statement evaluates to if_false. + + """ + return select({ + "@org_tensorflow//third_party/mkl_dnn:build_with_mkldnn_threadpool": if_true, + "//conditions:default": if_false, + }) diff --git a/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/mkl_dnn/mkldnn_v1.BUILD index 243ec00a60f..313f81c8108 100644 --- a/third_party/mkl_dnn/mkldnn_v1.BUILD +++ b/third_party/mkl_dnn/mkldnn_v1.BUILD @@ -4,6 +4,7 @@ load( "@org_tensorflow//third_party/mkl_dnn:build_defs.bzl", "if_mkl_open_source_only", "if_mkl_v1_open_source_only", + "if_mkldnn_threadpool", ) load( "@org_tensorflow//third_party:common.bzl", @@ -18,15 +19,26 @@ config_setting( }, ) +_DNNL_RUNTIME_OMP = { + "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_OMP", + "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_OMP", + "#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE", +} + +_DNNL_RUNTIME_THREADPOOL = { + "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_THREADPOOL", + "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_THREADPOOL", + "#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE", +} + template_rule( name = "dnnl_config_h", src = "include/dnnl_config.h.in", out = "include/dnnl_config.h", - substitutions = { - "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_OMP", - "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_OMP", - "#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE", - }, + substitutions = if_mkldnn_threadpool( + _DNNL_RUNTIME_THREADPOOL, + if_false = _DNNL_RUNTIME_OMP, + ), ) # Create the file mkldnn_version.h with MKL-DNN version numbers. @@ -59,9 +71,10 @@ cc_library( "src/cpu/**/*.cpp", "src/cpu/**/*.hpp", "src/cpu/xbyak/*.h", - ]) + if_mkl_v1_open_source_only([ + ]) + [ ":dnnl_config_h", - ]) + [":dnnl_version_h"], + ":dnnl_version_h", + ], hdrs = glob(["include/*"]), copts = [ "-fexceptions", diff --git a/third_party/mlir/BUILD b/third_party/mlir/BUILD index 1bddf2180bc..a57088432e2 100644 --- a/third_party/mlir/BUILD +++ b/third_party/mlir/BUILD @@ -175,6 +175,7 @@ filegroup( filegroup( name = "AffineOpsTdFiles", srcs = [ + "include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td", "include/mlir/Dialect/Affine/IR/AffineOps.td", "include/mlir/Dialect/Affine/IR/AffineOpsBase.td", "include/mlir/Interfaces/LoopLikeInterface.td", @@ -207,6 +208,26 @@ gentbl( ], ) +gentbl( + name = "AffineMemoryOpInterfacesIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + "-gen-op-interface-decls", + "include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h.inc", + ), + ( + "-gen-op-interface-defs", + "include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td", + td_srcs = [ + ":AffineOpsTdFiles", + ], +) + ##---------------------------------------------------------------------------## # AVX512 dialect. ##---------------------------------------------------------------------------## @@ -462,6 +483,7 @@ cc_library( ]), includes = ["include"], deps = [ + ":AffineMemoryOpInterfacesIncGen", ":AffineOpsIncGen", ":EDSC", ":IR", @@ -677,6 +699,7 @@ cc_library( deps = [ ":CallOpInterfaces", ":CommonFolders", + ":ControlFlowInterfaces", ":Dialect", ":IR", ":InferTypeOpInterface", @@ -788,9 +811,6 @@ cc_library( "lib/Support/*.h", ], exclude = [ - # TODO(herhut): Move JitRunner out of Support so that Support does not - # depend on dialect. - "lib/Support/JitRunner.cpp", # TODO(jpienaar): Move this out, else Support depends on Analysis/ "lib/Support/MlirOptMain.cpp", ], @@ -1156,6 +1176,28 @@ cc_library( ], ) +cc_library( + name = "GPURuntimeTransforms", + srcs = [ + "lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp", + "lib/Conversion/PassDetail.h", + ], + hdrs = [ + "include/mlir/Conversion/GPUCommon/GPUCommonPass.h", + ], + includes = ["include"], + deps = [ + ":ConversionPassIncGen", + ":GPUDialect", + ":IR", + ":LLVMDialect", + ":Pass", + ":Support", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", + ], +) + gentbl( name = "GPUToNVVMGen", strip_include_prefix = "lib/Conversion/GPUToNVVM", @@ -1268,7 +1310,6 @@ cc_library( name = "GPUToCUDATransforms", srcs = [ "lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp", - "lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp", "lib/Conversion/PassDetail.h", ], hdrs = ["include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"], @@ -2005,9 +2046,9 @@ cc_library( ) cc_library( - name = "LoopsToGPU", - srcs = ["lib/Conversion/LoopsToGPU/LoopsToGPU.cpp"], - hdrs = ["include/mlir/Conversion/LoopsToGPU/LoopsToGPU.h"], + name = "SCFToGPU", + srcs = ["lib/Conversion/SCFToGPU/SCFToGPU.cpp"], + hdrs = ["include/mlir/Conversion/SCFToGPU/SCFToGPU.h"], includes = ["include"], deps = [ ":Affine", @@ -2027,22 +2068,22 @@ cc_library( ) cc_library( - name = "LoopsToGPUPass", + name = "SCFToGPUPass", srcs = [ - "lib/Conversion/LoopsToGPU/LoopsToGPUPass.cpp", "lib/Conversion/PassDetail.h", + "lib/Conversion/SCFToGPU/SCFToGPUPass.cpp", ], hdrs = [ - "include/mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h", + "include/mlir/Conversion/SCFToGPU/SCFToGPUPass.h", ], includes = ["include"], deps = [ ":Affine", ":ConversionPassIncGen", ":GPUDialect", - ":LoopsToGPU", ":Pass", ":SCFDialect", + ":SCFToGPU", ":StandardOps", ":Support", ":Transforms", @@ -2053,11 +2094,11 @@ cc_library( cc_library( name = "CFGTransforms", srcs = [ - "lib/Conversion/LoopToStandard/LoopToStandard.cpp", "lib/Conversion/PassDetail.h", + "lib/Conversion/SCFToStandard/SCFToStandard.cpp", ], hdrs = [ - "include/mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h", + "include/mlir/Conversion/SCFToStandard/SCFToStandard.h", ], includes = ["include"], deps = [ @@ -2232,10 +2273,10 @@ gentbl( cc_library( name = "SideEffects", srcs = [ - "lib/Interfaces/SideEffects.cpp", + "lib/Interfaces/SideEffectInterfaces.cpp", ], hdrs = [ - "include/mlir/Interfaces/SideEffects.h", + "include/mlir/Interfaces/SideEffectInterfaces.h", ], includes = ["include"], deps = [ @@ -2449,6 +2490,7 @@ cc_library( includes = ["include"], deps = [ ":Analysis", + ":GPURuntimeTransforms", ":GPUToNVVMTransforms", ":GPUToROCDLTransforms", ":GPUToSPIRVTransforms", @@ -2458,6 +2500,7 @@ cc_library( ":LLVMTransforms", ":LinalgToLLVM", ":LinalgToSPIRV", + ":LinalgToStandard", ":NVVMDialect", ":Parser", ":Pass", @@ -2467,7 +2510,7 @@ cc_library( ":Support", ":Transforms", ":VectorToLLVM", - ":VectorToLoops", + ":VectorToSCF", "@llvm-project//llvm:support", "@llvm-project//mlir/test:TestAffine", "@llvm-project//mlir/test:TestDialect", @@ -2527,6 +2570,7 @@ cc_library( ":ConversionPassIncGen", ":GPUDialect", ":GPUPassIncGen", + ":GPURuntimeTransforms", ":GPUToCUDATransforms", ":GPUToNVVMTransforms", ":GPUToROCDLTransforms", @@ -2543,15 +2587,16 @@ cc_library( ":LinalgPassIncGen", ":LinalgToLLVM", ":LinalgToSPIRV", + ":LinalgToStandard", ":LinalgTransforms", ":LoopPassIncGen", - ":LoopsToGPUPass", ":NVVMDialect", ":OpenMPDialect", ":QuantOps", ":QuantPassIncGen", ":ROCDLDialect", ":SCFDialect", + ":SCFToGPUPass", ":SCFTransforms", ":SDBM", ":SPIRVDialect", @@ -2584,6 +2629,7 @@ cc_library( srcs = [ "tools/mlir-opt/mlir-opt.cpp", ], + copts = ["-DMLIR_INCLUDE_TESTS"], deps = [ ":AllPassesAndDialectsNoRegistration", ":Analysis", @@ -2600,11 +2646,11 @@ cc_binary( deps = [ ":Analysis", ":IR", - ":LoopsToGPUPass", ":MlirOptLib", ":MlirOptMain", ":OpenMPDialect", ":QuantOps", + ":SCFToGPUPass", ":Transforms", "@llvm-project//llvm:all_targets", "@llvm-project//llvm:support", @@ -2619,8 +2665,8 @@ cc_binary( cc_library( name = "MlirJitRunner", - srcs = ["lib/Support/JitRunner.cpp"], - hdrs = ["include/mlir/Support/JitRunner.h"], + srcs = ["lib/ExecutionEngine/JitRunner.cpp"], + hdrs = ["include/mlir/ExecutionEngine/JitRunner.h"], includes = ["include"], deps = [ ":AllPassesAndDialectsNoRegistration", @@ -2680,6 +2726,7 @@ cc_binary( srcs = ["tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp"], linkshared = True, deps = [ + ":mlir_c_runner_utils", "//third_party/gpus/cuda:cuda_headers", "//third_party/gpus/cuda:cuda_runtime", "//third_party/gpus/cuda:libcuda", @@ -2729,6 +2776,7 @@ cc_binary( ":AllPassesAndDialectsNoRegistration", ":ExecutionEngineUtils", ":GPUDialect", + ":GPURuntimeTransforms", ":GPUToNVVMTransforms", ":GPUToROCDLTransforms", ":GPUTransforms", @@ -3115,7 +3163,32 @@ cc_library( ":Support", ":Transforms", ":VectorToLLVM", - ":VectorToLoops", + ":VectorToSCF", + "@llvm-project//llvm:core", + "@llvm-project//llvm:support", + ], +) + +cc_library( + name = "LinalgToStandard", + srcs = glob([ + "lib/Conversion/LinalgToStandard/*.cpp", + "lib/Conversion/LinalgToStandard/*.h", + ]) + ["lib/Conversion/PassDetail.h"], + hdrs = glob([ + "include/mlir/Conversion/LinalgToStandard/*.h", + ]), + includes = ["include"], + deps = [ + ":Affine", + ":ConversionPassIncGen", + ":IR", + ":LinalgOps", + ":Pass", + ":SCFDialect", + ":StandardOps", + ":Support", + ":Transforms", "@llvm-project//llvm:core", "@llvm-project//llvm:support", ], @@ -3328,13 +3401,13 @@ cc_library( ) cc_library( - name = "VectorToLoops", + name = "VectorToSCF", srcs = glob([ - "lib/Conversion/VectorToLoops/*.cpp", - "lib/Conversion/VectorToLoops/*.h", + "lib/Conversion/VectorToSCF/*.cpp", + "lib/Conversion/VectorToSCF/*.h", ]), hdrs = glob([ - "include/mlir/Conversion/VectorToLoops/*.h", + "include/mlir/Conversion/VectorToSCF/*.h", ]), includes = ["include"], deps = [ diff --git a/third_party/mlir/test.BUILD b/third_party/mlir/test.BUILD index eb5d8a650eb..24b310f076e 100644 --- a/third_party/mlir/test.BUILD +++ b/third_party/mlir/test.BUILD @@ -171,7 +171,7 @@ cc_library( "@llvm-project//mlir:Transforms", "@llvm-project//mlir:VectorOps", "@llvm-project//mlir:VectorToLLVM", - "@llvm-project//mlir:VectorToLoops", + "@llvm-project//mlir:VectorToSCF", ], ) diff --git a/third_party/nccl/build_defs.bzl.tpl b/third_party/nccl/build_defs.bzl.tpl index e734e49f9dc..9268af7c890 100644 --- a/third_party/nccl/build_defs.bzl.tpl +++ b/third_party/nccl/build_defs.bzl.tpl @@ -1,6 +1,6 @@ """Repository rule for NCCL.""" -load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts") +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "cuda_gpu_architectures") load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain") def _gen_device_srcs_impl(ctx): @@ -84,6 +84,7 @@ def _device_link_impl(ctx): cubins = [] images = [] for arch in ctx.attr.gpu_archs: + arch = arch.replace("compute_", "sm_") # PTX is JIT-linked at runtime. cubin = ctx.actions.declare_file("%s_%s.cubin" % (name, arch)) register_h = ctx.actions.declare_file("%s_register_%s.h" % (name, arch)) ctx.actions.run( @@ -285,7 +286,7 @@ def cuda_rdc_library(name, hdrs = None, copts = None, linkstatic = True, **kwarg name = dlink_hdrs, deps = [lib], out = dlink_cc, - gpu_archs = %{gpu_architectures}, + gpu_archs = cuda_gpu_architectures(), nvlink_args = select({ "@org_tensorflow//tensorflow:linux_x86_64": ["--cpu-arch=X86_64"], "@org_tensorflow//tensorflow:linux_ppc64le": ["--cpu-arch=PPC64LE"], diff --git a/third_party/nccl/nccl_configure.bzl b/third_party/nccl/nccl_configure.bzl index 92acb204097..d59e861d70b 100644 --- a/third_party/nccl/nccl_configure.bzl +++ b/third_party/nccl/nccl_configure.bzl @@ -13,7 +13,6 @@ load( "//third_party/gpus:cuda_configure.bzl", - "compute_capabilities", "enable_cuda", "find_cuda_config", ) @@ -84,16 +83,7 @@ def _create_local_nccl_repository(repository_ctx): # Alias to open source build from @nccl_archive. repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT) - # TODO(csigg): implement and reuse in cuda_configure.bzl. - gpu_architectures = [ - "sm_" + capability.replace(".", "") - for capability in compute_capabilities(repository_ctx) - ] - - # Round-about way to make the list unique. - gpu_architectures = dict(zip(gpu_architectures, gpu_architectures)).keys() config_wrap = { - "%{gpu_architectures}": str(gpu_architectures), "%{use_bin2c_path}": "False", } if (int(cuda_major), int(cuda_minor)) <= (10, 1): diff --git a/third_party/psimd/workspace.bzl b/third_party/psimd/workspace.bzl index 03d010c3db8..768fd6da839 100644 --- a/third_party/psimd/workspace.bzl +++ b/third_party/psimd/workspace.bzl @@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive") def repo(): third_party_http_archive( name = "psimd", - strip_prefix = "psimd-85427dd4c8521cc037a1ffa6fcd25c55fafc8a00", - sha256 = "db23c2bc4a58d6f40c181797e43103300edac7cf9d286ca81590543f66ab95d2", + strip_prefix = "psimd-072586a71b55b7f8c584153d223e95687148a900", + sha256 = "dc615342bcbe51ca885323e51b68b90ed9bb9fa7df0f4419dbfa0297d5e837b7", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/Maratyszcza/psimd/archive/85427dd4c8521cc037a1ffa6fcd25c55fafc8a00.zip", - "https://github.com/Maratyszcza/psimd/archive/85427dd4c8521cc037a1ffa6fcd25c55fafc8a00.zip", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip", + "https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip", ], build_file = "//third_party/psimd:BUILD.bazel", ) diff --git a/third_party/toolchains/embedded/arm-linux/arm_linux_toolchain_configure.bzl b/third_party/toolchains/embedded/arm-linux/arm_linux_toolchain_configure.bzl index da4282d0215..af34133f27c 100644 --- a/third_party/toolchains/embedded/arm-linux/arm_linux_toolchain_configure.bzl +++ b/third_party/toolchains/embedded/arm-linux/arm_linux_toolchain_configure.bzl @@ -10,6 +10,16 @@ def _tpl(repository_ctx, tpl, substitutions = {}, out = None): ) def _arm_linux_toolchain_configure_impl(repository_ctx): + # We need to find a cross-compilation include directory for Python, so look + # for an environment variable. Be warned, this crosstool template is only + # regenerated on the first run of Bazel, so if you change the variable after + # it may not be reflected in later builds. Doing a shutdown and clean of Bazel + # doesn't fix this, you'll need to delete the generated file at something like: + # external/local_config_arm_compiler/CROSSTOOL in your Bazel install. + if "CROSSTOOL_PYTHON_INCLUDE_PATH" in repository_ctx.os.environ: + python_include_path = repository_ctx.os.environ["CROSSTOOL_PYTHON_INCLUDE_PATH"] + else: + python_include_path = "/usr/include/python3.5" _tpl(repository_ctx, "cc_config.bzl", { "%{AARCH64_COMPILER_PATH}%": str(repository_ctx.path( repository_ctx.attr.aarch64_repo, @@ -17,6 +27,7 @@ def _arm_linux_toolchain_configure_impl(repository_ctx): "%{ARMHF_COMPILER_PATH}%": str(repository_ctx.path( repository_ctx.attr.armhf_repo, )), + "%{PYTHON_INCLUDE_PATH}%": python_include_path, }) repository_ctx.symlink(repository_ctx.attr.build_file, "BUILD") diff --git a/third_party/toolchains/embedded/arm-linux/cc_config.bzl.tpl b/third_party/toolchains/embedded/arm-linux/cc_config.bzl.tpl index 06aaaecfa74..afbea6a3e34 100644 --- a/third_party/toolchains/embedded/arm-linux/cc_config.bzl.tpl +++ b/third_party/toolchains/embedded/arm-linux/cc_config.bzl.tpl @@ -252,6 +252,10 @@ def _impl(ctx): "%{AARCH64_COMPILER_PATH}%/aarch64-linux-gnu/include/c++/8.3.0/", "-isystem", "%{AARCH64_COMPILER_PATH}%/aarch64-linux-gnu/libc/usr/include/", + "-isystem", + "%{PYTHON_INCLUDE_PATH}%", + "-isystem", + "/usr/include/", ], ), ], @@ -347,6 +351,10 @@ def _impl(ctx): "%{ARMHF_COMPILER_PATH}%/arm-linux-gnueabihf/include/c++/8.3.0/", "-isystem", "%{ARMHF_COMPILER_PATH}%/arm-linux-gnueabihf/libc/usr/include/", + "-isystem", + "%{PYTHON_INCLUDE_PATH}%", + "-isystem", + "/usr/include/", ], ), ], @@ -466,6 +474,7 @@ def _impl(ctx): "%{AARCH64_COMPILER_PATH}%/lib/gcc/aarch64-linux-gnu/8.3.0/include-fixed", "%{AARCH64_COMPILER_PATH}%/aarch64-linux-gnu/include/c++/8.3.0/", "%{AARCH64_COMPILER_PATH}%/aarch64-linux-gnu/libc/usr/include/", + "/usr/include", ] elif (ctx.attr.cpu == "armhf"): cxx_builtin_include_directories = [ @@ -473,6 +482,7 @@ def _impl(ctx): "%{ARMHF_COMPILER_PATH}%/lib/gcc/arm-linux-gnueabihf/8.3.0/include-fixed", "%{ARMHF_COMPILER_PATH}%/arm-linux-gnueabihf/include/c++/8.3.0/", "%{ARMHF_COMPILER_PATH}%/arm-linux-gnueabihf/libc/usr/include/", + "/usr/include", ] else: fail("Unreachable")