Added generic, API-neutral Arguments for a better (more generic) task description.

Added MetalArguments for compute task.
Showed example of usage in FullyConnected.

PiperOrigin-RevId: 336932028
Change-Id: I86e9c4fd967a69ad18cf08b4e0c6acef2f3a681e
This commit is contained in:
Raman Sarokin 2020-10-13 12:32:59 -07:00 committed by TensorFlower Gardener
parent 21b7abe1b6
commit d92c3720a0
10 changed files with 405 additions and 27 deletions

View File

@ -38,6 +38,15 @@ cc_library(
],
)
# API-neutral container of named kernel arguments (arguments.h / arguments.cc).
cc_library(
    name = "arguments",
    srcs = ["arguments.cc"],
    hdrs = ["arguments.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/common:status",
        # arguments.cc includes "absl/strings/ascii.h" (absl::ascii_isalnum),
        # so the dependency has to be declared explicitly.
        "@com_google_absl//absl/strings",
    ],
)
objc_library(
name = "buffer_convert",
srcs = ["buffer_convert.mm"],
@ -134,6 +143,7 @@ objc_library(
deps = [
":common",
":compute_task_descriptor",
":metal_arguments",
":runtime_options",
"//tensorflow/lite/delegates/gpu/common:model",
"//tensorflow/lite/delegates/gpu/common:shape",
@ -149,6 +159,7 @@ objc_library(
hdrs = ["compute_task_descriptor.h"],
copts = DEFAULT_COPTS,
deps = [
":arguments",
"//tensorflow/lite/delegates/gpu/common:model",
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:types",
@ -213,6 +224,20 @@ ios_unit_test(
deps = [":inference_context_test_lib"],
)
# Metal-side counterpart of ":arguments" (metal_arguments.h / metal_arguments.mm).
objc_library(
name = "metal_arguments",
srcs = ["metal_arguments.mm"],
hdrs = ["metal_arguments.h"],
copts = DEFAULT_COPTS,
# metal_arguments.h imports <Metal/Metal.h>.
sdk_frameworks = ["Metal"],
deps = [
":arguments",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:util",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "runtime_options",
hdrs = ["runtime_options.h"],

View File

@ -0,0 +1,64 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/arguments.h"
#include "absl/strings/ascii.h"
namespace tflite {
namespace gpu {
namespace metal {
namespace {
// Returns true when `symbol` can be part of an identifier-like word: an ASCII
// letter, an ASCII digit or an underscore.
bool IsWordSymbol(char symbol) {
  if (symbol == '_') {
    return true;
  }
  return absl::ascii_isalnum(symbol);
}
// Returns true if `word` occurs in `text` as a standalone token, i.e. at least
// one occurrence is neither preceded nor followed by a word symbol (see
// IsWordSymbol). Occurrences that are part of a longer identifier, such as
// "args.size" inside "args.size_x", do not count.
bool HasWord(const std::string& word, const std::string& text) {
  size_t pos = text.find(word);
  while (pos != std::string::npos) {
    const size_t end = pos + word.size();
    // A match touching either end of `text` is treated as if it were bounded
    // by '.', which is not a word symbol.
    const char prev = pos == 0 ? '.' : text[pos - 1];
    const char next = end < text.size() ? text[end] : '.';
    // Logical (short-circuit) AND: both neighbours must be non-word symbols.
    if (!IsWordSymbol(prev) && !IsWordSymbol(next)) {
      return true;
    }
    // Continue right after the start of this match so overlapping candidates
    // are still examined.
    pos = text.find(word, pos + 1);
  }
  return false;
}
} // namespace
// Static
// Out-of-line definition of the static constexpr member declared (with its
// initializer) in the class body. Before C++17 such a definition is needed
// when the member is odr-used, e.g. when it is passed to a function taking a
// pointer, as in the `kArgsPrefix + name` expressions below.
constexpr char Arguments::kArgsPrefix[];
// Stores `value` under `name` in the float argument table. A missing entry is
// created; an existing entry only has its value overwritten (its `active` flag
// is left as is).
void Arguments::AddFloat(const std::string& name, float value) {
  auto& entry = float_values_[name];
  entry.value = value;
}
// Stores `value` under `name` in the int argument table. A missing entry is
// created; an existing entry only has its value overwritten (its `active` flag
// is left as is).
void Arguments::AddInt(const std::string& name, int value) {
  auto& entry = int_values_[name];
  entry.value = value;
}
void Arguments::GetActiveArguments(const std::string& code) {
for (auto& float_val : float_values_) {
float_val.second.active = HasWord(kArgsPrefix + float_val.first, code);
}
for (auto& int_val : int_values_) {
int_val.second.active = HasWord(kArgsPrefix + int_val.first, code);
}
}
} // namespace metal
} // namespace gpu
} // namespace tflite

View File

@ -0,0 +1,77 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_ARGUMENTS_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_ARGUMENTS_H_
#include <map>
#include <string>
#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
namespace gpu {
namespace metal {
// API-neutral description of the scalar arguments of a compute task.
//
// Arguments are registered by name with AddFloat()/AddInt() and are referred to
// from kernel source as "args.<name>". The backend-specific MetalArguments
// class (a friend) reads the stored values and decides how to pass them to the
// kernel.
class Arguments {
 public:
  Arguments() = default;

  // Move only
  Arguments(Arguments&& args) = default;
  Arguments& operator=(Arguments&& args) = default;
  Arguments(const Arguments&) = delete;
  Arguments& operator=(const Arguments&) = delete;

  // Registers a float argument called `name`, or overwrites the value of an
  // already registered one.
  void AddFloat(const std::string& name, float value = 0.0f);

  // Registers an int argument called `name`, or overwrites the value of an
  // already registered one.
  void AddInt(const std::string& name, int value = 0);

 private:
  friend class MetalArguments;

  // Updates the `active` flag of every argument according to whether
  // "args.<name>" is used in `code`.
  void GetActiveArguments(const std::string& code);

  // Prefix under which arguments are referenced in kernel source.
  static constexpr char kArgsPrefix[] = "args.";

  struct IntValue {
    // Explicit default so that a default-constructed entry never holds an
    // indeterminate value.
    int value = 0;

    // Many arguments are generated automatically and never used. This flag is
    // true if the argument was used in the kernel code; it is filled in by
    // GetActiveArguments().
    bool active = false;
  };
  std::map<std::string, IntValue> int_values_;

  struct FloatValue {
    // Explicit default so that a default-constructed entry never holds an
    // indeterminate value.
    float value = 0.0f;

    // Many arguments are generated automatically and never used. This flag is
    // true if the argument was used in the kernel code; it is filled in by
    // GetActiveArguments().
    bool active = false;
  };
  std::map<std::string, FloatValue> float_values_;
};
// Abstract interface for updating the value of a named argument. Each setter
// reports its outcome through the returned absl::Status.
class ArgumentsSetter {
 public:
  // Virtual so implementations can be destroyed through this interface.
  virtual ~ArgumentsSetter() = default;

  // Sets the int argument called `name` to `value`.
  virtual absl::Status SetInt(const std::string& name, int value) = 0;

  // Sets the float argument called `name` to `value`.
  virtual absl::Status SetFloat(const std::string& name, float value) = 0;
};
} // namespace metal
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_ARGUMENTS_H_

View File

@ -535,6 +535,7 @@ ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) {
uniform_index++;
fused_descriptor->uniform_buffers.push_back({"", buffer.data_function});
}
fused_descriptor->args = std::move(desc->args);
if (desc->is_linkable) {
call_code +=
@ -546,8 +547,8 @@ ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) {
ComputeTaskDescriptorPtr non_linkable = sequence.front();
fused_descriptor->shader_source =
absl::Substitute(non_linkable->shader_source, function_code,
buffer_declarations, call_code);
absl::Substitute(non_linkable->shader_source, function_code + "$0",
buffer_declarations + "$1", call_code);
std::vector<ValueId> alias;
alias.reserve(chain.size() - 1);
for (int i = 0; i < chain.size() - 1; i++) {

View File

@ -158,6 +158,7 @@ static std::vector<ComputeTaskDescriptorPtr> Add2Linkable(int id, ValueId input_
ValueId input_id2, ValueId output_id) {
std::vector<ComputeTaskDescriptorPtr> descriptors;
descriptors.push_back(ComputeTaskDescriptorPtr(new ComputeTaskDescriptor({
{}, // args
id,
true, // linkable
true, // associative_op

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include <tuple>
#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
@ -70,11 +71,15 @@ struct UniformBuffer {
uint3 _groupsCount;
DispatchParamsFunction _resizeFunction;
std::string _description;
tflite::gpu::metal::MetalArguments _metal_args;
}
- (absl::Status)compileWithDevice:(id<MTLDevice>)device
taskDescriptor:(ComputeTaskDescriptorPtr)desc
runtimeOptions:(const RuntimeOptions&)options {
size_t offset = desc->input_buffers.size() + desc->uniform_buffers.size()
+ desc->immutable_buffers.size() + 1;
RETURN_IF_ERROR(_metal_args.Init(offset, &desc->args, &desc->shader_source));
NSString* barrier;
// simdgroup_barrier is supported on macOS 10.13+ and Metal shading language version 2.0
if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) {
@ -251,6 +256,7 @@ struct UniformBuffer {
[encoder setBytes:uniform.data.data() length:uniform.data.size() atIndex:bindIndex];
bindIndex++;
}
_metal_args.Encode(encoder, bindIndex);
MTLSize groupsCount = MTLSizeMake(_groupsCount.x, _groupsCount.y, _groupsCount.z);
MTLSize groupsSize = MTLSizeMake(_groupsSize.x, _groupsSize.y, _groupsSize.z);

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/metal/arguments.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
namespace tflite {
@ -79,6 +80,7 @@ struct ComputeTaskDescriptor {
UniformsFunction data_function;
};
Arguments args;
// Unique ID to match the graph compilation errors.
int id;
bool is_linkable;

View File

@ -51,13 +51,6 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info,
#include <metal_stdlib>
using namespace metal;
struct uniforms {
uint src_depth;
uint dst_channels;
uint out_channels;
uint dummy;
};
$$0
kernel void ComputeFunction(
$$1
@ -71,11 +64,11 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info,
float summa = 0.0f;
threadgroup FLT4 local_vector[32];
for (int j = 0; j < $0; ++j) {
local_vector[tid_index] = j * 32 + tid_index >= params.src_depth ?
local_vector[tid_index] = j * 32 + tid_index >= args.src_slices ?
FLT4(0.0f) : vector[j * 32 + tid_index];
$1(mem_flags::mem_threadgroup);
for (uint i = 0, counter = j * 32 + tid.y * 8; i < 8; ++i, ++counter) {
summa += dot(local_vector[tid.y * 8 + i], matrix[counter * params.dst_channels + ugid.x]);
summa += dot(local_vector[tid.y * 8 + i], matrix[counter * args.dst_channels_alignedx8 + ugid.x]);
}
$1(mem_flags::mem_none);
}
@ -87,10 +80,10 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info,
for (uint i = 0; i < $0; ++i, ++counter) {
)";
if (src_depth % 4 != 0) {
code << " if (counter >= params.src_depth) continue;" << std::endl;
code << " if (counter >= args.src_slices) continue;" << std::endl;
}
code << " summa += dot(vector[counter], matrix[counter * "
"params.dst_channels + ugid.x]);"
"args.dst_channels_alignedx8 + ugid.x]);"
<< std::endl;
code << " }" << std::endl;
}
@ -106,7 +99,7 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info,
temp[tid.x][0] = summa;
}
$1(mem_flags::mem_threadgroup);
if (tid.y == 0 && tid.x % 4 == 0 && ugid.x < params.out_channels) {
if (tid.y == 0 && tid.x % 4 == 0 && ugid.x < args.dst_channels) {
const int linear_index = ugid.x / 4;
FLT4 value = FLT4(temp[tid.x][0], temp[tid.x + 1][0], temp[tid.x + 2][0], temp[tid.x + 3][0]) +
biases[linear_index];
@ -132,6 +125,11 @@ std::vector<ComputeTaskDescriptorPtr> FullyConnected(
desc->shader_source = GetFullyConnectedCode(device_info, attr.weights.shape.i,
attr.weights.shape.o);
desc->args.AddInt("dst_channels", attr.weights.shape.o);
desc->args.AddInt("src_slices", DivideRoundUp(attr.weights.shape.i, 4));
desc->args.AddInt("dst_channels_alignedx8",
AlignByN(attr.weights.shape.o, 8));
desc->input_buffers = {
{input_id, "device FLT4* const vector"},
};
@ -174,19 +172,6 @@ std::vector<ComputeTaskDescriptorPtr> FullyConnected(
attr.weights.shape.o)},
};
desc->uniform_buffers = {
{"constant uniforms& params",
[attr](const std::map<ValueId, BHWC>& buffers) {
std::vector<uint32_t> uniform_params{
static_cast<uint32_t>(DivideRoundUp(attr.weights.shape.i, 4)),
static_cast<uint32_t>(AlignByN(attr.weights.shape.o, 8)),
static_cast<uint32_t>(attr.weights.shape.o),
static_cast<uint32_t>(0),
};
return GetByteBuffer(uniform_params);
}},
};
desc->resize_function = [attr](const std::map<ValueId, BHWC>& buffers) {
const uint3 groups_size{8, 4, 1};
const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 8);

View File

@ -0,0 +1,80 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_ARGUMENTS_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_ARGUMENTS_H_
#import <Metal/Metal.h>
#include <map>
#include <string>
#include <vector>
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/metal/arguments.h"
namespace tflite {
namespace gpu {
namespace metal {
// Metal implementation of ArgumentsSetter.
//
// Init() copies the values of an Arguments object, packs the ones that are
// actually used by the shader into a byte buffer (`const_data_`) and rewrites
// the shader source accordingly. Encode() binds that byte buffer to a compute
// command encoder.
class MetalArguments : public ArgumentsSetter {
 public:
  MetalArguments() = default;

  // Prepares the arguments for the shader in `*code`.
  //   buffer_offset - buffer index used in the generated `[[buffer(N)]]`
  //                   binding of the uniforms struct.
  //   args          - source of argument names and values; its `active` flags
  //                   are refreshed from `*code`.
  //   code          - shader source, modified in place: "args.<name>" becomes
  //                   "U.<name>" and the $0 / $1 placeholders are substituted
  //                   with the struct definition and the buffer declaration.
  absl::Status Init(int buffer_offset, Arguments* args, std::string* code);

  // Move only
  MetalArguments(MetalArguments&& args) = default;
  MetalArguments& operator=(MetalArguments&& args) = default;
  MetalArguments(const MetalArguments&) = delete;
  MetalArguments& operator=(const MetalArguments&) = delete;

  // Update the value of a known argument; return a NotFound error for an
  // unknown name.
  absl::Status SetInt(const std::string& name, int value) override;
  absl::Status SetFloat(const std::string& name, float value) override;

  // Binds the packed argument bytes at index `buffer_offset` of `encoder`.
  // Does nothing when there is no argument data.
  void Encode(id<MTLComputeCommandEncoder> encoder, int buffer_offset) const;

 private:
  static constexpr char kArgsPrefix[] = "args.";

  struct IntValue {
    // Explicit default so that an entry never holds an indeterminate value.
    int value = 0;

    // Many arguments are generated automatically and never used; only active
    // ones are stored in `const_data_`, which reduces the amount of data
    // transferred.
    bool active = false;

    // Byte offset of this value inside `const_data_`. Stays at the maximum
    // uint32_t value (the "no offset" sentinel) while the argument is inactive.
    uint32_t bytes_offset = static_cast<uint32_t>(-1);
  };
  std::map<std::string, IntValue> int_values_;

  struct FloatValue {
    // Explicit default so that an entry never holds an indeterminate value.
    float value = 0.0f;

    // Many arguments are generated automatically and never used; only active
    // ones are stored in `const_data_`, which reduces the amount of data
    // transferred.
    bool active = false;

    // Byte offset of this value inside `const_data_`. Stays at the maximum
    // uint32_t value (the "no offset" sentinel) while the argument is inactive.
    uint32_t bytes_offset = static_cast<uint32_t>(-1);
  };
  std::map<std::string, FloatValue> float_values_;

  // Packed storage of all active values, 4 bytes per value, padded to a
  // multiple of 4 values.
  std::vector<uint8_t> const_data_;
};
} // namespace metal
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_ARGUMENTS_H_

View File

@ -0,0 +1,137 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
#include <string>
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
namespace tflite {
namespace gpu {
namespace metal {
namespace {
// Tells whether `symbol` may belong to an identifier: ASCII alphanumeric
// characters and '_' qualify, everything else does not.
bool IsWordSymbol(char symbol) {
  const bool is_underscore = symbol == '_';
  return is_underscore || absl::ascii_isalnum(symbol);
}
// Replaces, in place, every standalone occurrence of `old_word` in `*str` with
// `new_word`. An occurrence is standalone when neither the character before it
// nor the character after it is a word symbol (see IsWordSymbol); the ends of
// the string count as non-word boundaries. Occurrences embedded in a longer
// identifier are left untouched.
void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
                     std::string* str) {
  for (size_t position = str->find(old_word); position != std::string::npos;) {
    const size_t match_end = position + old_word.size();
    const char before = position == 0 ? '.' : (*str)[position - 1];
    const char after = match_end < str->size() ? (*str)[match_end] : '.';
    const bool standalone = !IsWordSymbol(before) && !IsWordSymbol(after);
    if (standalone) {
      str->replace(position, old_word.size(), new_word);
      // Resume after the inserted text so the replacement is never rescanned.
      position = str->find(old_word, position + new_word.size());
    } else {
      // Part of a longer word: skip this candidate and keep searching.
      position = str->find(old_word, position + 1);
    }
  }
}
} // namespace
// Static
// Out-of-line definition of the static constexpr member declared (with its
// initializer) in the class body. Before C++17 such a definition is needed
// when the member is odr-used, e.g. in the `kArgsPrefix + name` expressions
// inside Init().
constexpr char MetalArguments::kArgsPrefix[];
// Prepares the shader arguments.
//
// Steps:
//  1. Refreshes the `active` flags of `args` from the shader source `*code`.
//  2. Copies every argument (value + active flag) into this object. Each
//     active argument gets a 4-byte slot in the uniforms struct and its
//     "args.<name>" references in `*code` are rewritten to "U.<name>".
//  3. If at least one argument is active, builds the `uniforms_buffer` struct
//     definition (padded with dummy ints to a multiple of 4 fields), the
//     `constant uniforms_buffer& U[[buffer(buffer_offset)]]` declaration, and
//     fills `const_data_` with the current values.
//  4. Substitutes the struct definition for $0 and the declaration for $1 in
//     `*code` (both are empty when no argument is active).
//
// Always returns OkStatus.
absl::Status MetalArguments::Init(int buffer_offset, Arguments* args, std::string* code) {
  args->GetActiveArguments(*code);
  std::string struct_desc = "struct uniforms_buffer {\n";
  std::string struct_decl;
  int pos = 0;  // Number of 4-byte slots handed out so far.
  for (auto& fvalue : args->float_values_) {
    auto& new_val = float_values_[fvalue.first];
    new_val.value = fvalue.second.value;
    new_val.active = fvalue.second.active;
    if (fvalue.second.active) {
      new_val.bytes_offset = pos * 4;
      pos++;
      struct_desc += "  float " + fvalue.first + ";\n";
      ReplaceAllWords(kArgsPrefix + fvalue.first, "U." + fvalue.first, code);
    }
  }
  for (auto& ivalue : args->int_values_) {
    auto& new_val = int_values_[ivalue.first];
    new_val.value = ivalue.second.value;
    new_val.active = ivalue.second.active;
    if (ivalue.second.active) {
      new_val.bytes_offset = pos * 4;
      pos++;
      struct_desc += "  int " + ivalue.first + ";\n";
      ReplaceAllWords(kArgsPrefix + ivalue.first, "U." + ivalue.first, code);
    }
  }
  if (pos != 0) {
    struct_decl = "constant uniforms_buffer& U[[buffer(" + std::to_string(buffer_offset) + ")]],\n";
    // Pad the struct to a multiple of 4 fields with dummy members.
    int aligned_pos = AlignByN(pos, 4);
    for (int i = pos; i < aligned_pos; i++) {
      struct_desc += "  int dummy" + std::to_string(i - pos) + ";\n";
    }
    struct_desc += "};";
    const_data_.resize(aligned_pos * 4);
    // Only active values own a slot in const_data_. Inactive ones keep the
    // default bytes_offset (uint32_t(-1)), so writing them would access memory
    // far outside the buffer.
    for (auto& it : float_values_) {
      if (!it.second.active) {
        continue;
      }
      float* ptr = reinterpret_cast<float*>(&const_data_[it.second.bytes_offset]);
      *ptr = it.second.value;
    }
    for (auto& it : int_values_) {
      if (!it.second.active) {
        continue;
      }
      int32_t* ptr = reinterpret_cast<int32_t*>(&const_data_[it.second.bytes_offset]);
      *ptr = it.second.value;
    }
  } else {
    // No active arguments: emit neither the struct nor the buffer binding.
    struct_desc = "";
    struct_decl = "";
  }
  *code = absl::Substitute(*code, struct_desc, struct_decl);
  return absl::OkStatus();
}
// Updates the int argument `name` to `value`. For an active argument the
// packed copy inside `const_data_` is refreshed as well. Returns NotFoundError
// when no int argument with that name is known.
absl::Status MetalArguments::SetInt(const std::string& name, int value) {
  const auto entry = int_values_.find(name);
  if (entry == int_values_.end()) {
    return absl::NotFoundError(
        absl::StrCat("No int argument with name - ", name));
  }
  auto& argument = entry->second;
  argument.value = value;
  if (argument.active) {
    *reinterpret_cast<int32_t*>(&const_data_[argument.bytes_offset]) = value;
  }
  return absl::OkStatus();
}
// Updates the float argument `name` to `value`. For an active argument the
// packed copy inside `const_data_` is refreshed as well. Returns NotFoundError
// when no float argument with that name is known.
absl::Status MetalArguments::SetFloat(const std::string& name, float value) {
  const auto entry = float_values_.find(name);
  if (entry == float_values_.end()) {
    return absl::NotFoundError(
        absl::StrCat("No float argument with name - ", name));
  }
  auto& argument = entry->second;
  argument.value = value;
  if (argument.active) {
    *reinterpret_cast<float*>(&const_data_[argument.bytes_offset]) = value;
  }
  return absl::OkStatus();
}
// Passes the packed argument bytes to `encoder` at buffer index
// `buffer_offset`. Nothing is bound when there is no argument data.
void MetalArguments::Encode(id<MTLComputeCommandEncoder> encoder, int buffer_offset) const {
  if (const_data_.empty()) {
    return;
  }
  [encoder setBytes:const_data_.data() length:const_data_.size() atIndex:buffer_offset];
}
} // namespace metal
} // namespace gpu
} // namespace tflite