From 55a21bf47ee09fef538674788420cc6475d0d502 Mon Sep 17 00:00:00 2001 From: Raman Sarokin Date: Wed, 30 Sep 2020 13:30:55 -0700 Subject: [PATCH] Added example of using OpenCL internal API. PiperOrigin-RevId: 334667219 Change-Id: I4598ee0ca7e596615454514c3a281d225b15d4d7 --- .../lite/delegates/gpu/cl/testing/BUILD | 45 ++-- .../gpu/cl/testing/internal_api_samples.cc | 224 ++++++++++++++++++ .../cl/testing/run_internal_api_samples.sh | 101 ++++++++ 3 files changed, 356 insertions(+), 14 deletions(-) create mode 100644 tensorflow/lite/delegates/gpu/cl/testing/internal_api_samples.cc create mode 100755 tensorflow/lite/delegates/gpu/cl/testing/run_internal_api_samples.sh diff --git a/tensorflow/lite/delegates/gpu/cl/testing/BUILD b/tensorflow/lite/delegates/gpu/cl/testing/BUILD index c82190ca0e6..9eb5a9445a6 100644 --- a/tensorflow/lite/delegates/gpu/cl/testing/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/testing/BUILD @@ -3,20 +3,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -cc_binary( - name = "performance_profiling", - srcs = ["performance_profiling.cc"], - deps = [ - "//tensorflow/lite/delegates/gpu/cl:environment", - "//tensorflow/lite/delegates/gpu/cl:inference_context", - "//tensorflow/lite/delegates/gpu/common:model", - "//tensorflow/lite/delegates/gpu/common:status", - "//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader", - "//tensorflow/lite/kernels:builtin_ops", - "@com_google_absl//absl/time", - ], -) - cc_binary( name = "delegate_testing", srcs = ["delegate_testing.cc"], @@ -34,3 +20,34 @@ cc_binary( "@com_google_absl//absl/time", ], ) + +cc_binary( + name = "internal_api_samples", + srcs = ["internal_api_samples.cc"], + deps = [ + "//tensorflow/lite/delegates/gpu:api", + "//tensorflow/lite/delegates/gpu/cl:api", + "//tensorflow/lite/delegates/gpu/cl:environment", + "//tensorflow/lite/delegates/gpu/cl:inference_context", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:kernel_util", + "@com_google_absl//absl/time", + ], +) + +cc_binary( + name = "performance_profiling", + srcs = ["performance_profiling.cc"], + deps = [ + "//tensorflow/lite/delegates/gpu/cl:environment", + "//tensorflow/lite/delegates/gpu/cl:inference_context", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:status", + "//tensorflow/lite/delegates/gpu/common/testing:tflite_model_reader", + "//tensorflow/lite/kernels:builtin_ops", + "@com_google_absl//absl/time", + ], +) diff --git a/tensorflow/lite/delegates/gpu/cl/testing/internal_api_samples.cc b/tensorflow/lite/delegates/gpu/cl/testing/internal_api_samples.cc new file mode 100644 index 00000000000..3e9b614c8c4 --- /dev/null +++ b/tensorflow/lite/delegates/gpu/cl/testing/internal_api_samples.cc @@ -0,0 +1,224 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <chrono>  // NOLINT(build/c++11)
+#include <cmath>
+#include <cstdint>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/time/time.h"
+#include "tensorflow/lite/delegates/gpu/api.h"
+#include "tensorflow/lite/delegates/gpu/cl/api.h"
+#include "tensorflow/lite/delegates/gpu/cl/environment.h"
+#include "tensorflow/lite/delegates/gpu/cl/inference_context.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/testing/tflite_model_reader.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/register.h"
+
+namespace tflite {
+namespace gpu {
+namespace cl {
+namespace {
+void FillInputTensors(tflite::Interpreter* interpreter) {
+  for (int k = 0; k < interpreter->inputs().size(); ++k) {
+    TfLiteTensor* tensor_ptr = interpreter->tensor(interpreter->inputs()[k]);
+    const auto tensor_elements_count = tflite::NumElements(tensor_ptr);
+    if (tensor_ptr->type == kTfLiteFloat32) {
+      float* p = interpreter->typed_input_tensor<float>(k);
+      for (int i = 0; i < tensor_elements_count; ++i) {
+        p[i] = std::sin(i);
+      }
+    } else {
+      std::cout << "No support for non-Float32 input/output tensors"
+                << std::endl;
+    }
+  }
+}
+
+void CompareCPUGPUResults(tflite::Interpreter* cpu,
+                          const std::vector<int64_t>& outputs,
+                          const std::vector<std::vector<float>>& gpu,
+                          float eps) {
+  for (int i = 0; i < gpu.size(); ++i) {
+    TfLiteTensor* tensor_ptr = cpu->tensor(outputs[i]);
+    const float* cpu_out = tensor_ptr->data.f;
+    const float* gpu_out = gpu[i].data();
+    const int kMaxPrint = 10;
+    int printed = 0;
+    int total_different = 0;
+    // tensor_ptr->bytes / 4 is the element count (4 bytes per float).
+    for (int k = 0; k < tensor_ptr->bytes / 4; ++k) {
+      const float abs_diff = fabs(cpu_out[k] - gpu_out[k]);
+      if (abs_diff > eps) {
+        total_different++;
+        if (printed < kMaxPrint) {
+          std::cout << "Output #" << i << ": element #" << k
+                    << ": CPU value - " << cpu_out[k] << ", GPU value - "
+                    << gpu_out[k] << ", abs diff - " << abs_diff << std::endl;
+          printed++;
+        }
+        if (printed == kMaxPrint) {
+          std::cout << "Printed " << kMaxPrint
+                    << " different elements, threshold - " << eps
+                    << ", next different elements skipped" << std::endl;
+          printed++;
+        }
+      }
+    }
+    std::cout << "Total " << total_different
+              << " different elements, for output #" << i << ", threshold - "
+              << eps << std::endl;
+  }
+}
+}  // namespace
+
+// Runs the model with the OpenCL internal API and compares correctness with
+// the TFLite CPU backend.
+absl::Status RunModelSampleWithInternalAPI(const std::string& model_name) {
+  auto flatbuffer = tflite::FlatBufferModel::BuildFromFile(model_name.c_str());
+
+  ops::builtin::BuiltinOpResolver op_resolver;
+  InterpreterBuilder tfl_builder(*flatbuffer, op_resolver);
+
+  // CPU.
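+  // The interpreter built below is the CPU reference path: it runs the same
+  // model through the regular TFLite kernels, and its outputs are compared
+  // element by element against the outputs produced through the GPU
+  // internal API.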
+  std::unique_ptr<tflite::Interpreter> cpu_inference;
+  tfl_builder(&cpu_inference);
+  if (!cpu_inference) {
+    return absl::InternalError("Failed to build CPU inference.");
+  }
+  auto status = cpu_inference->AllocateTensors();
+  if (status != kTfLiteOk) {
+    return absl::InternalError("Failed to AllocateTensors for CPU inference.");
+  }
+  for (int k = 0; k < cpu_inference->inputs().size(); ++k) {
+    TfLiteTensor* tensor_ptr =
+        cpu_inference->tensor(cpu_inference->inputs()[k]);
+    if (tensor_ptr->type != kTfLiteFloat32) {
+      return absl::InvalidArgumentError(
+          "Internal API supports only F32 input tensors");
+    }
+  }
+  for (int k = 0; k < cpu_inference->outputs().size(); ++k) {
+    TfLiteTensor* tensor_ptr =
+        cpu_inference->tensor(cpu_inference->outputs()[k]);
+    if (tensor_ptr->type != kTfLiteFloat32) {
+      return absl::InvalidArgumentError(
+          "Internal API supports only F32 output tensors");
+    }
+  }
+  FillInputTensors(cpu_inference.get());
+  status = cpu_inference->Invoke();
+  if (status != kTfLiteOk) {
+    return absl::InternalError("Failed to Invoke CPU inference.");
+  }
+
+  GraphFloat32 graph_cl;
+  RETURN_IF_ERROR(BuildFromFlatBuffer(*flatbuffer, op_resolver, &graph_cl));
+
+  // Collects the TFLite tensor indices of the graph inputs/outputs so the
+  // same buffers can be bound to the GPU runner below.
+  auto inputs = graph_cl.inputs();
+  auto outputs = graph_cl.outputs();
+  std::vector<int64_t> in_refs(inputs.size());
+  std::vector<int64_t> out_refs(outputs.size());
+  for (int i = 0; i < inputs.size(); ++i) {
+    in_refs[i] = inputs[i]->tensor.ref;
+  }
+  for (int i = 0; i < outputs.size(); ++i) {
+    out_refs[i] = outputs[i]->tensor.ref;
+  }
+
+  Environment env;
+  RETURN_IF_ERROR(CreateEnvironment(&env));
+
+  std::unique_ptr<InferenceEnvironment> inf_env;
+  // Initializes the inference environment with the existing OpenCL device,
+  // context and command queue.
+  InferenceEnvironmentOptions env_options;
+  env_options.device = env.device().id();
+  env_options.context = env.context().context();
+  env_options.command_queue = env.queue()->queue();
+  RETURN_IF_ERROR(NewInferenceEnvironment(env_options, &inf_env, nullptr));
+
+  std::unique_ptr<InferenceBuilder> builder;
+  // Initializes builder.
+  InferenceOptions options;
+  options.priority1 = InferencePriority::MIN_LATENCY;
+  options.priority2 = InferencePriority::MIN_MEMORY_USAGE;
+  options.priority3 = InferencePriority::MAX_PRECISION;
+  options.usage = InferenceUsage::SUSTAINED_SPEED;
+  RETURN_IF_ERROR(
+      inf_env->NewInferenceBuilder(options, std::move(graph_cl), &builder));
+
+  // Sets the input/output object defs for the builder.
+  ObjectDef obj_def;
+  obj_def.data_type = DataType::FLOAT32;
+  obj_def.data_layout = DataLayout::BHWC;
+  obj_def.object_type = ObjectType::CPU_MEMORY;
+  obj_def.user_provided = true;
+  for (int i = 0; i < in_refs.size(); ++i) {
+    RETURN_IF_ERROR(builder->SetInputObjectDef(i, obj_def));
+  }
+  for (int i = 0; i < out_refs.size(); ++i) {
+    RETURN_IF_ERROR(builder->SetOutputObjectDef(i, obj_def));
+  }
+
+  std::unique_ptr<::tflite::gpu::InferenceRunner> runner;
+  // Builds runner.
+  RETURN_IF_ERROR(builder->Build(&runner));
+
+  // Sets the input/output objects.
+  for (int i = 0; i < in_refs.size(); ++i) {
+    TfLiteTensor* tensor_ptr = cpu_inference->tensor(in_refs[i]);
+    // The GPU runner reads directly from the CPU interpreter's input buffers.
+    RETURN_IF_ERROR(runner->SetInputObject(
+        i, CpuMemory{tensor_ptr->data.data, tensor_ptr->bytes}));
+  }
+
+  std::vector<std::vector<float>> output_tensors(out_refs.size());
+  for (int i = 0; i < out_refs.size(); ++i) {
+    TfLiteTensor* tensor_ptr = cpu_inference->tensor(out_refs[i]);
+    output_tensors[i].resize(tensor_ptr->bytes / 4);
+    RETURN_IF_ERROR(runner->SetOutputObject(
+        i, CpuMemory{output_tensors[i].data(), tensor_ptr->bytes}));
+  }
+
+  RETURN_IF_ERROR(runner->Run());
+
+  CompareCPUGPUResults(cpu_inference.get(), out_refs, output_tensors, 1e-4f);
+
+  return absl::OkStatus();
+}
+
+}  // namespace cl
+}  // namespace gpu
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  if (argc <= 1) {
+    std::cerr << "Expected model path as command-line argument.";
+    return -1;
+  }
+
+  auto load_status = tflite::gpu::cl::LoadOpenCL();
+  if (!load_status.ok()) {
+    std::cerr << load_status.message();
+    return -1;
+  }
+
+  auto run_status = tflite::gpu::cl::RunModelSampleWithInternalAPI(argv[1]);
+  if (!run_status.ok()) {
+    std::cerr << run_status.message();
+    return -1;
+  }
+
+  return EXIT_SUCCESS;
+}
diff --git a/tensorflow/lite/delegates/gpu/cl/testing/run_internal_api_samples.sh b/tensorflow/lite/delegates/gpu/cl/testing/run_internal_api_samples.sh
new file mode 100755
index 00000000000..21900c55875
--- /dev/null
+++ b/tensorflow/lite/delegates/gpu/cl/testing/run_internal_api_samples.sh
@@ -0,0 +1,101 @@
+#!/bin/bash
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+shopt -s expand_aliases # to use the ADB command alias defined below
+
+description="Example of internal API usage:
+How to use:
+[-h or --help, print instructions]
+[-m or --model_path, path to the model in .tflite format]
+[-d or --device, select device](optional, if you have multiple connected devices)"
+
+model_path=""
+alias ADB='adb'
+host=""
+
+while [[ "$1" != "" ]]; do
+  case $1 in
+    -m | --model_path)
+      shift
+      model_path=$1
+      ;;
+    -d | --device)
+      shift
+      if [[ "$1" == "HOST" ]]
+      then
+      host="HOST"
+      fi
+      alias ADB='adb -s '$1''
+      ;;
+    -h | --help)
+      echo "$description"
+      exit
+      ;;
+  esac
+  shift
+done
+
+if [ "$model_path" = "" ]
+then
+echo "No model provided."
+echo "$description"
+exit
+fi
+
+SHELL_DIR=$(dirname "$0")
+BINARY_NAME=internal_api_samples
+
+if [[ "$host" == "HOST" ]]
+then
+bazel build -c opt --copt -DCL_DELEGATE_NO_GL //"$SHELL_DIR":"$BINARY_NAME"
+chmod +x bazel-bin/"$SHELL_DIR"/"$BINARY_NAME"
+./bazel-bin/"$SHELL_DIR"/"$BINARY_NAME" "$model_path"
+exit
+fi
+
+model_name=${model_path##*/} # finds last token after '/'
+
+OPENCL_DIR=/data/local/tmp/internal_api_samples/
+
+ADB shell mkdir -p $OPENCL_DIR
+
+ADB push "$model_path" "$OPENCL_DIR"
+
+declare -a BUILD_CONFIG
+abi_version=$(ADB shell getprop ro.product.cpu.abi | tr -d '\r')
+if [[ "$abi_version" == "armeabi-v7a" ]]; then
+# 32-bit ARM
+BUILD_CONFIG=( --config=android_arm -c opt --copt=-fPIE --linkopt=-pie )
+elif [[ "$abi_version" == "arm64-v8a" ]]; then
+# 64-bit ARM
+BUILD_CONFIG=( --config=android_arm64 -c opt )
+elif [[ "$abi_version" == "x86_64" ]]; then
+# x86_64
+BUILD_CONFIG=( --config=android_x86_64 -c opt )
+else
+echo "Error: Unknown processor ABI"
+exit 1
+fi
+
+bazel build "${BUILD_CONFIG[@]}" --copt -DCL_DELEGATE_NO_GL //$SHELL_DIR:$BINARY_NAME
+
+ADB push bazel-bin/$SHELL_DIR/$BINARY_NAME $OPENCL_DIR
+
+ADB shell chmod +x $OPENCL_DIR/$BINARY_NAME
+ADB shell "cd $OPENCL_DIR && ./$BINARY_NAME $model_name"
+
+# clean up files from device
+ADB shell rm -rf $OPENCL_DIR
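+
+# Example invocations (sketch only; the model path and device serial below are
+# placeholders, not files shipped with this change):
+#   ./run_internal_api_samples.sh --model_path /path/to/model.tflite
+#   ./run_internal_api_samples.sh -m /path/to/model.tflite -d <device_serial>
+#   ./run_internal_api_samples.sh -m /path/to/model.tflite -d HOST   # host build; needs a host OpenCL driver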