Added generic, API-neutral Arguments for a better (generic) task description.
Added MetalArguments for compute task. Showed example of usage in FullyConnected. PiperOrigin-RevId: 336932028 Change-Id: I86e9c4fd967a69ad18cf08b4e0c6acef2f3a681e
This commit is contained in:
parent
21b7abe1b6
commit
d92c3720a0
@ -38,6 +38,15 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# API-neutral description of compute task arguments (see arguments.h).
cc_library(
    name = "arguments",
    srcs = ["arguments.cc"],
    hdrs = ["arguments.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:status",
        # arguments.cc includes absl/strings/ascii.h directly, so depend on it
        # explicitly instead of relying on a transitive dependency.
        "@com_google_absl//absl/strings",
    ],
)
|
||||||
|
|
||||||
objc_library(
|
objc_library(
|
||||||
name = "buffer_convert",
|
name = "buffer_convert",
|
||||||
srcs = ["buffer_convert.mm"],
|
srcs = ["buffer_convert.mm"],
|
||||||
@ -134,6 +143,7 @@ objc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":common",
|
":common",
|
||||||
":compute_task_descriptor",
|
":compute_task_descriptor",
|
||||||
|
":metal_arguments",
|
||||||
":runtime_options",
|
":runtime_options",
|
||||||
"//tensorflow/lite/delegates/gpu/common:model",
|
"//tensorflow/lite/delegates/gpu/common:model",
|
||||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||||
@ -149,6 +159,7 @@ objc_library(
|
|||||||
hdrs = ["compute_task_descriptor.h"],
|
hdrs = ["compute_task_descriptor.h"],
|
||||||
copts = DEFAULT_COPTS,
|
copts = DEFAULT_COPTS,
|
||||||
deps = [
|
deps = [
|
||||||
|
":arguments",
|
||||||
"//tensorflow/lite/delegates/gpu/common:model",
|
"//tensorflow/lite/delegates/gpu/common:model",
|
||||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||||
"//tensorflow/lite/delegates/gpu/common:types",
|
"//tensorflow/lite/delegates/gpu/common:types",
|
||||||
@ -213,6 +224,20 @@ ios_unit_test(
|
|||||||
deps = [":inference_context_test_lib"],
|
deps = [":inference_context_test_lib"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
objc_library(
|
||||||
|
name = "metal_arguments",
|
||||||
|
srcs = ["metal_arguments.mm"],
|
||||||
|
hdrs = ["metal_arguments.h"],
|
||||||
|
copts = DEFAULT_COPTS,
|
||||||
|
sdk_frameworks = ["Metal"],
|
||||||
|
deps = [
|
||||||
|
":arguments",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:status",
|
||||||
|
"//tensorflow/lite/delegates/gpu/common:util",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "runtime_options",
|
name = "runtime_options",
|
||||||
hdrs = ["runtime_options.h"],
|
hdrs = ["runtime_options.h"],
|
||||||
|
64
tensorflow/lite/delegates/gpu/metal/arguments.cc
Normal file
64
tensorflow/lite/delegates/gpu/metal/arguments.cc
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/metal/arguments.h"
|
||||||
|
|
||||||
|
#include "absl/strings/ascii.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace metal {
|
||||||
|
namespace {
|
||||||
|
bool IsWordSymbol(char symbol) {
|
||||||
|
return absl::ascii_isalnum(symbol) || symbol == '_';
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HasWord(const std::string& word, const std::string& text) {
|
||||||
|
size_t pos = text.find(word);
|
||||||
|
while (pos != std::string::npos) {
|
||||||
|
char prev = pos == 0 ? '.' : text[pos - 1];
|
||||||
|
char next = pos + word.size() < text.size() ? text[pos + word.size()] : '.';
|
||||||
|
if (!IsWordSymbol(prev) & !IsWordSymbol(next)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
pos = text.find(word, pos + 1);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Static
|
||||||
|
constexpr char Arguments::kArgsPrefix[];
|
||||||
|
|
||||||
|
void Arguments::AddFloat(const std::string& name, float value) {
|
||||||
|
float_values_[name].value = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Arguments::AddInt(const std::string& name, int value) {
|
||||||
|
int_values_[name].value = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Arguments::GetActiveArguments(const std::string& code) {
|
||||||
|
for (auto& float_val : float_values_) {
|
||||||
|
float_val.second.active = HasWord(kArgsPrefix + float_val.first, code);
|
||||||
|
}
|
||||||
|
for (auto& int_val : int_values_) {
|
||||||
|
int_val.second.active = HasWord(kArgsPrefix + int_val.first, code);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace metal
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
77
tensorflow/lite/delegates/gpu/metal/arguments.h
Normal file
77
tensorflow/lite/delegates/gpu/metal/arguments.h
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_ARGUMENTS_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_ARGUMENTS_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace metal {
|
||||||
|
|
||||||
|
// API-neutral container of named scalar (int / float) arguments of a task.
//
// Arguments are registered by name with AddInt() / AddFloat(). Whether an
// argument is actually referenced by the kernel source (as
// "<kArgsPrefix><name>") is determined later by GetActiveArguments(), which
// is only reachable by the backend class MetalArguments (declared a friend).
class Arguments {
 public:
  Arguments() = default;

  // Move only
  Arguments(Arguments&& args) = default;
  Arguments& operator=(Arguments&& args) = default;
  Arguments(const Arguments&) = delete;
  Arguments& operator=(const Arguments&) = delete;

  // Registers the float argument `name`, or overwrites its value if it was
  // already registered.
  void AddFloat(const std::string& name, float value = 0.0f);
  // Registers the int argument `name`, or overwrites its value if it was
  // already registered.
  void AddInt(const std::string& name, int value = 0);

 private:
  friend class MetalArguments;

  // Sets the `active` flag of every registered argument according to whether
  // `code` references it.
  void GetActiveArguments(const std::string& code);

  // Prefix under which arguments are referenced in kernel code.
  static constexpr char kArgsPrefix[] = "args.";

  struct IntValue {
    // Explicitly initialized so a default-constructed entry is never
    // indeterminate.
    int value = 0;

    // Many arguments are generated automatically and never used. This flag is
    // true if the argument is used in the kernel code; it is only meaningful
    // after a GetActiveArguments() call.
    bool active = false;
  };
  std::map<std::string, IntValue> int_values_;

  struct FloatValue {
    // Explicitly initialized so a default-constructed entry is never
    // indeterminate.
    float value = 0.0f;

    // Many arguments are generated automatically and never used. This flag is
    // true if the argument is used in the kernel code; it is only meaningful
    // after a GetActiveArguments() call.
    bool active = false;
  };
  std::map<std::string, FloatValue> float_values_;
};
|
||||||
|
|
||||||
|
class ArgumentsSetter {
|
||||||
|
public:
|
||||||
|
virtual absl::Status SetInt(const std::string& name, int value) = 0;
|
||||||
|
virtual absl::Status SetFloat(const std::string& name, float value) = 0;
|
||||||
|
virtual ~ArgumentsSetter() = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace metal
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_ARGUMENTS_H_
|
@ -535,6 +535,7 @@ ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) {
|
|||||||
uniform_index++;
|
uniform_index++;
|
||||||
fused_descriptor->uniform_buffers.push_back({"", buffer.data_function});
|
fused_descriptor->uniform_buffers.push_back({"", buffer.data_function});
|
||||||
}
|
}
|
||||||
|
fused_descriptor->args = std::move(desc->args);
|
||||||
|
|
||||||
if (desc->is_linkable) {
|
if (desc->is_linkable) {
|
||||||
call_code +=
|
call_code +=
|
||||||
@ -546,8 +547,8 @@ ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) {
|
|||||||
|
|
||||||
ComputeTaskDescriptorPtr non_linkable = sequence.front();
|
ComputeTaskDescriptorPtr non_linkable = sequence.front();
|
||||||
fused_descriptor->shader_source =
|
fused_descriptor->shader_source =
|
||||||
absl::Substitute(non_linkable->shader_source, function_code,
|
absl::Substitute(non_linkable->shader_source, function_code + "$0",
|
||||||
buffer_declarations, call_code);
|
buffer_declarations + "$1", call_code);
|
||||||
std::vector<ValueId> alias;
|
std::vector<ValueId> alias;
|
||||||
alias.reserve(chain.size() - 1);
|
alias.reserve(chain.size() - 1);
|
||||||
for (int i = 0; i < chain.size() - 1; i++) {
|
for (int i = 0; i < chain.size() - 1; i++) {
|
||||||
|
@ -158,6 +158,7 @@ static std::vector<ComputeTaskDescriptorPtr> Add2Linkable(int id, ValueId input_
|
|||||||
ValueId input_id2, ValueId output_id) {
|
ValueId input_id2, ValueId output_id) {
|
||||||
std::vector<ComputeTaskDescriptorPtr> descriptors;
|
std::vector<ComputeTaskDescriptorPtr> descriptors;
|
||||||
descriptors.push_back(ComputeTaskDescriptorPtr(new ComputeTaskDescriptor({
|
descriptors.push_back(ComputeTaskDescriptorPtr(new ComputeTaskDescriptor({
|
||||||
|
{}, // args
|
||||||
id,
|
id,
|
||||||
true, // linkable
|
true, // linkable
|
||||||
true, // associative_op
|
true, // associative_op
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
@ -70,11 +71,15 @@ struct UniformBuffer {
|
|||||||
uint3 _groupsCount;
|
uint3 _groupsCount;
|
||||||
DispatchParamsFunction _resizeFunction;
|
DispatchParamsFunction _resizeFunction;
|
||||||
std::string _description;
|
std::string _description;
|
||||||
|
tflite::gpu::metal::MetalArguments _metal_args;
|
||||||
}
|
}
|
||||||
|
|
||||||
- (absl::Status)compileWithDevice:(id<MTLDevice>)device
|
- (absl::Status)compileWithDevice:(id<MTLDevice>)device
|
||||||
taskDescriptor:(ComputeTaskDescriptorPtr)desc
|
taskDescriptor:(ComputeTaskDescriptorPtr)desc
|
||||||
runtimeOptions:(const RuntimeOptions&)options {
|
runtimeOptions:(const RuntimeOptions&)options {
|
||||||
|
size_t offset = desc->input_buffers.size() + desc->uniform_buffers.size()
|
||||||
|
+ desc->immutable_buffers.size() + 1;
|
||||||
|
RETURN_IF_ERROR(_metal_args.Init(offset, &desc->args, &desc->shader_source));
|
||||||
NSString* barrier;
|
NSString* barrier;
|
||||||
// simdgroup_barrier is supported on macOS 10.13+ and Metal shading language version 2.0
|
// simdgroup_barrier is supported on macOS 10.13+ and Metal shading language version 2.0
|
||||||
if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) {
|
if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) {
|
||||||
@ -251,6 +256,7 @@ struct UniformBuffer {
|
|||||||
[encoder setBytes:uniform.data.data() length:uniform.data.size() atIndex:bindIndex];
|
[encoder setBytes:uniform.data.data() length:uniform.data.size() atIndex:bindIndex];
|
||||||
bindIndex++;
|
bindIndex++;
|
||||||
}
|
}
|
||||||
|
_metal_args.Encode(encoder, bindIndex);
|
||||||
|
|
||||||
MTLSize groupsCount = MTLSizeMake(_groupsCount.x, _groupsCount.y, _groupsCount.z);
|
MTLSize groupsCount = MTLSizeMake(_groupsCount.x, _groupsCount.y, _groupsCount.z);
|
||||||
MTLSize groupsSize = MTLSizeMake(_groupsSize.x, _groupsSize.y, _groupsSize.z);
|
MTLSize groupsSize = MTLSizeMake(_groupsSize.x, _groupsSize.y, _groupsSize.z);
|
||||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/metal/arguments.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
@ -79,6 +80,7 @@ struct ComputeTaskDescriptor {
|
|||||||
UniformsFunction data_function;
|
UniformsFunction data_function;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Arguments args;
|
||||||
// Unique ID to match the graph compilation errors.
|
// Unique ID to match the graph compilation errors.
|
||||||
int id;
|
int id;
|
||||||
bool is_linkable;
|
bool is_linkable;
|
||||||
|
@ -51,13 +51,6 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info,
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
struct uniforms {
|
|
||||||
uint src_depth;
|
|
||||||
uint dst_channels;
|
|
||||||
uint out_channels;
|
|
||||||
uint dummy;
|
|
||||||
};
|
|
||||||
|
|
||||||
$$0
|
$$0
|
||||||
kernel void ComputeFunction(
|
kernel void ComputeFunction(
|
||||||
$$1
|
$$1
|
||||||
@ -71,11 +64,11 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info,
|
|||||||
float summa = 0.0f;
|
float summa = 0.0f;
|
||||||
threadgroup FLT4 local_vector[32];
|
threadgroup FLT4 local_vector[32];
|
||||||
for (int j = 0; j < $0; ++j) {
|
for (int j = 0; j < $0; ++j) {
|
||||||
local_vector[tid_index] = j * 32 + tid_index >= params.src_depth ?
|
local_vector[tid_index] = j * 32 + tid_index >= args.src_slices ?
|
||||||
FLT4(0.0f) : vector[j * 32 + tid_index];
|
FLT4(0.0f) : vector[j * 32 + tid_index];
|
||||||
$1(mem_flags::mem_threadgroup);
|
$1(mem_flags::mem_threadgroup);
|
||||||
for (uint i = 0, counter = j * 32 + tid.y * 8; i < 8; ++i, ++counter) {
|
for (uint i = 0, counter = j * 32 + tid.y * 8; i < 8; ++i, ++counter) {
|
||||||
summa += dot(local_vector[tid.y * 8 + i], matrix[counter * params.dst_channels + ugid.x]);
|
summa += dot(local_vector[tid.y * 8 + i], matrix[counter * args.dst_channels_alignedx8 + ugid.x]);
|
||||||
}
|
}
|
||||||
$1(mem_flags::mem_none);
|
$1(mem_flags::mem_none);
|
||||||
}
|
}
|
||||||
@ -87,10 +80,10 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info,
|
|||||||
for (uint i = 0; i < $0; ++i, ++counter) {
|
for (uint i = 0; i < $0; ++i, ++counter) {
|
||||||
)";
|
)";
|
||||||
if (src_depth % 4 != 0) {
|
if (src_depth % 4 != 0) {
|
||||||
code << " if (counter >= params.src_depth) continue;" << std::endl;
|
code << " if (counter >= args.src_slices) continue;" << std::endl;
|
||||||
}
|
}
|
||||||
code << " summa += dot(vector[counter], matrix[counter * "
|
code << " summa += dot(vector[counter], matrix[counter * "
|
||||||
"params.dst_channels + ugid.x]);"
|
"args.dst_channels_alignedx8 + ugid.x]);"
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
code << " }" << std::endl;
|
code << " }" << std::endl;
|
||||||
}
|
}
|
||||||
@ -106,7 +99,7 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info,
|
|||||||
temp[tid.x][0] = summa;
|
temp[tid.x][0] = summa;
|
||||||
}
|
}
|
||||||
$1(mem_flags::mem_threadgroup);
|
$1(mem_flags::mem_threadgroup);
|
||||||
if (tid.y == 0 && tid.x % 4 == 0 && ugid.x < params.out_channels) {
|
if (tid.y == 0 && tid.x % 4 == 0 && ugid.x < args.dst_channels) {
|
||||||
const int linear_index = ugid.x / 4;
|
const int linear_index = ugid.x / 4;
|
||||||
FLT4 value = FLT4(temp[tid.x][0], temp[tid.x + 1][0], temp[tid.x + 2][0], temp[tid.x + 3][0]) +
|
FLT4 value = FLT4(temp[tid.x][0], temp[tid.x + 1][0], temp[tid.x + 2][0], temp[tid.x + 3][0]) +
|
||||||
biases[linear_index];
|
biases[linear_index];
|
||||||
@ -132,6 +125,11 @@ std::vector<ComputeTaskDescriptorPtr> FullyConnected(
|
|||||||
desc->shader_source = GetFullyConnectedCode(device_info, attr.weights.shape.i,
|
desc->shader_source = GetFullyConnectedCode(device_info, attr.weights.shape.i,
|
||||||
attr.weights.shape.o);
|
attr.weights.shape.o);
|
||||||
|
|
||||||
|
desc->args.AddInt("dst_channels", attr.weights.shape.o);
|
||||||
|
desc->args.AddInt("src_slices", DivideRoundUp(attr.weights.shape.i, 4));
|
||||||
|
desc->args.AddInt("dst_channels_alignedx8",
|
||||||
|
AlignByN(attr.weights.shape.o, 8));
|
||||||
|
|
||||||
desc->input_buffers = {
|
desc->input_buffers = {
|
||||||
{input_id, "device FLT4* const vector"},
|
{input_id, "device FLT4* const vector"},
|
||||||
};
|
};
|
||||||
@ -174,19 +172,6 @@ std::vector<ComputeTaskDescriptorPtr> FullyConnected(
|
|||||||
attr.weights.shape.o)},
|
attr.weights.shape.o)},
|
||||||
};
|
};
|
||||||
|
|
||||||
desc->uniform_buffers = {
|
|
||||||
{"constant uniforms& params",
|
|
||||||
[attr](const std::map<ValueId, BHWC>& buffers) {
|
|
||||||
std::vector<uint32_t> uniform_params{
|
|
||||||
static_cast<uint32_t>(DivideRoundUp(attr.weights.shape.i, 4)),
|
|
||||||
static_cast<uint32_t>(AlignByN(attr.weights.shape.o, 8)),
|
|
||||||
static_cast<uint32_t>(attr.weights.shape.o),
|
|
||||||
static_cast<uint32_t>(0),
|
|
||||||
};
|
|
||||||
return GetByteBuffer(uniform_params);
|
|
||||||
}},
|
|
||||||
};
|
|
||||||
|
|
||||||
desc->resize_function = [attr](const std::map<ValueId, BHWC>& buffers) {
|
desc->resize_function = [attr](const std::map<ValueId, BHWC>& buffers) {
|
||||||
const uint3 groups_size{8, 4, 1};
|
const uint3 groups_size{8, 4, 1};
|
||||||
const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 8);
|
const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 8);
|
||||||
|
80
tensorflow/lite/delegates/gpu/metal/metal_arguments.h
Normal file
80
tensorflow/lite/delegates/gpu/metal/metal_arguments.h
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_ARGUMENTS_H_
|
||||||
|
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_ARGUMENTS_H_
|
||||||
|
|
||||||
|
#import <Metal/Metal.h>
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/metal/arguments.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace metal {
|
||||||
|
|
||||||
|
class MetalArguments : public ArgumentsSetter {
|
||||||
|
public:
|
||||||
|
MetalArguments() = default;
|
||||||
|
|
||||||
|
absl::Status Init(int buffer_offset, Arguments* args, std::string* code);
|
||||||
|
|
||||||
|
// Move only
|
||||||
|
MetalArguments(MetalArguments&& args) = default;
|
||||||
|
MetalArguments& operator=(MetalArguments&& args) = default;
|
||||||
|
MetalArguments(const MetalArguments&) = delete;
|
||||||
|
MetalArguments& operator=(const MetalArguments&) = delete;
|
||||||
|
|
||||||
|
absl::Status SetInt(const std::string& name, int value) override;
|
||||||
|
absl::Status SetFloat(const std::string& name, float value) override;
|
||||||
|
|
||||||
|
void Encode(id<MTLComputeCommandEncoder> encoder, int buffer_offset) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
static constexpr char kArgsPrefix[] = "args.";
|
||||||
|
struct IntValue {
|
||||||
|
int value;
|
||||||
|
|
||||||
|
// many arguments generated automatically and not used
|
||||||
|
// to reduce amount of data transferred we adding this optimization
|
||||||
|
bool active = false;
|
||||||
|
|
||||||
|
// offset to shared storage.
|
||||||
|
uint32_t bytes_offset = -1;
|
||||||
|
};
|
||||||
|
std::map<std::string, IntValue> int_values_;
|
||||||
|
|
||||||
|
struct FloatValue {
|
||||||
|
float value;
|
||||||
|
|
||||||
|
// many arguments generated automatically and not used
|
||||||
|
// to reduce amount of data transferred we adding this optimization
|
||||||
|
bool active = false;
|
||||||
|
|
||||||
|
// offset to shared storage.
|
||||||
|
uint32_t bytes_offset = -1;
|
||||||
|
};
|
||||||
|
std::map<std::string, FloatValue> float_values_;
|
||||||
|
std::vector<uint8_t> const_data_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace metal
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_ARGUMENTS_H_
|
137
tensorflow/lite/delegates/gpu/metal/metal_arguments.mm
Normal file
137
tensorflow/lite/delegates/gpu/metal/metal_arguments.mm
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/strings/substitute.h"
|
||||||
|
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace gpu {
|
||||||
|
namespace metal {
|
||||||
|
namespace {
|
||||||
|
bool IsWordSymbol(char symbol) {
|
||||||
|
return absl::ascii_isalnum(symbol) || symbol == '_';
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
|
||||||
|
std::string* str) {
|
||||||
|
size_t position = str->find(old_word);
|
||||||
|
while (position != std::string::npos) {
|
||||||
|
char prev = position == 0 ? '.' : (*str)[position - 1];
|
||||||
|
char next = position + old_word.size() < str->size()
|
||||||
|
? (*str)[position + old_word.size()]
|
||||||
|
: '.';
|
||||||
|
if (IsWordSymbol(prev) || IsWordSymbol(next)) {
|
||||||
|
position = str->find(old_word, position + 1);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
str->replace(position, old_word.size(), new_word);
|
||||||
|
position = str->find(old_word, position + new_word.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Static
|
||||||
|
constexpr char MetalArguments::kArgsPrefix[];
|
||||||
|
|
||||||
|
// Initializes this object from `args` and patches the kernel source `code`.
//
// Steps:
//  1. Asks `args` to mark which arguments are referenced by `*code`.
//  2. Copies every argument (value + active flag) into this object. Each
//     active argument gets a 4-byte slot in `const_data_`, a field in the
//     generated `uniforms_buffer` struct, and its "args.<name>" references in
//     `*code` are rewritten to "U.<name>".
//  3. If at least one argument is active, pads the struct with dummy ints up
//     to the slot count returned by AlignByN(pos, 4) and fills `const_data_`
//     with the current values of the active arguments.
//  4. Substitutes the struct definition for `$0` and the kernel parameter
//     declaration (bound at `buffer_offset`) for `$1` in `*code`; both are
//     empty when no argument is active.
//
// Always returns OkStatus.
absl::Status MetalArguments::Init(int buffer_offset, Arguments* args, std::string* code) {
  args->GetActiveArguments(*code);
  std::string struct_desc = "struct uniforms_buffer {\n";
  std::string struct_decl;
  int pos = 0;  // Number of 4-byte slots assigned so far.
  for (auto& fvalue : args->float_values_) {
    auto& new_val = float_values_[fvalue.first];
    new_val.value = fvalue.second.value;
    new_val.active = fvalue.second.active;
    if (fvalue.second.active) {
      new_val.bytes_offset = pos * 4;
      pos++;
      struct_desc += " float " + fvalue.first + ";\n";
      ReplaceAllWords(kArgsPrefix + fvalue.first, "U." + fvalue.first, code);
    }
  }
  for (auto& ivalue : args->int_values_) {
    auto& new_val = int_values_[ivalue.first];
    new_val.value = ivalue.second.value;
    new_val.active = ivalue.second.active;
    if (ivalue.second.active) {
      new_val.bytes_offset = pos * 4;
      pos++;
      struct_desc += " int " + ivalue.first + ";\n";
      ReplaceAllWords(kArgsPrefix + ivalue.first, "U." + ivalue.first, code);
    }
  }
  if (pos != 0) {
    struct_decl = "constant uniforms_buffer& U[[buffer(" + std::to_string(buffer_offset) + ")]],\n";
    int aligned_pos = AlignByN(pos, 4);
    for (int i = pos; i < aligned_pos; i++) {
      struct_desc += " int dummy" + std::to_string(i - pos) + ";\n";
    }
    struct_desc += "};";
    const_data_.resize(aligned_pos * 4);
    // Only active arguments own a slot. Inactive ones still carry the invalid
    // sentinel offset, so writing them would go far out of bounds.
    for (auto& it : float_values_) {
      if (!it.second.active) {
        continue;
      }
      float* ptr = reinterpret_cast<float*>(&const_data_[it.second.bytes_offset]);
      *ptr = it.second.value;
    }
    for (auto& it : int_values_) {
      if (!it.second.active) {
        continue;
      }
      int32_t* ptr = reinterpret_cast<int32_t*>(&const_data_[it.second.bytes_offset]);
      *ptr = it.second.value;
    }
  } else {
    struct_desc = "";
    struct_decl = "";
  }
  *code = absl::Substitute(*code, struct_desc, struct_decl);
  return absl::OkStatus();
}
|
||||||
|
|
||||||
|
// Updates the int argument `name` to `value`.
// Returns NotFoundError when no int argument with that name is stored. For an
// active argument the packed copy in `const_data_` is refreshed as well.
absl::Status MetalArguments::SetInt(const std::string& name, int value) {
  const auto entry = int_values_.find(name);
  if (entry == int_values_.end()) {
    return absl::NotFoundError(
        absl::StrCat("No int argument with name - ", name));
  }
  IntValue& stored = entry->second;
  stored.value = value;
  if (!stored.active) {
    // Inactive arguments have no slot in the packed buffer.
    return absl::OkStatus();
  }
  *reinterpret_cast<int32_t*>(&const_data_[stored.bytes_offset]) = value;
  return absl::OkStatus();
}
|
||||||
|
// Updates the float argument `name` to `value`.
// Returns NotFoundError when no float argument with that name is stored. For
// an active argument the packed copy in `const_data_` is refreshed as well.
absl::Status MetalArguments::SetFloat(const std::string& name, float value) {
  const auto entry = float_values_.find(name);
  if (entry == float_values_.end()) {
    return absl::NotFoundError(
        absl::StrCat("No float argument with name - ", name));
  }
  FloatValue& stored = entry->second;
  stored.value = value;
  if (!stored.active) {
    // Inactive arguments have no slot in the packed buffer.
    return absl::OkStatus();
  }
  *reinterpret_cast<float*>(&const_data_[stored.bytes_offset]) = value;
  return absl::OkStatus();
}
|
||||||
|
|
||||||
|
// Passes the packed argument bytes to `encoder` at index `buffer_offset`.
// Does nothing when there is no packed data.
void MetalArguments::Encode(id<MTLComputeCommandEncoder> encoder, int buffer_offset) const {
  if (const_data_.empty()) {
    return;
  }
  [encoder setBytes:const_data_.data() length:const_data_.size() atIndex:buffer_offset];
}
|
||||||
|
|
||||||
|
} // namespace metal
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace tflite
|
Loading…
Reference in New Issue
Block a user