Added generic API neutral Arguments for better(generic) task description.
Added MetalArguments for compute task. Showed example of usage in FullyConnected. PiperOrigin-RevId: 336932028 Change-Id: I86e9c4fd967a69ad18cf08b4e0c6acef2f3a681e
This commit is contained in:
parent
21b7abe1b6
commit
d92c3720a0
@ -38,6 +38,15 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "arguments",
|
||||
srcs = ["arguments.cc"],
|
||||
hdrs = ["arguments.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "buffer_convert",
|
||||
srcs = ["buffer_convert.mm"],
|
||||
@ -134,6 +143,7 @@ objc_library(
|
||||
deps = [
|
||||
":common",
|
||||
":compute_task_descriptor",
|
||||
":metal_arguments",
|
||||
":runtime_options",
|
||||
"//tensorflow/lite/delegates/gpu/common:model",
|
||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||
@ -149,6 +159,7 @@ objc_library(
|
||||
hdrs = ["compute_task_descriptor.h"],
|
||||
copts = DEFAULT_COPTS,
|
||||
deps = [
|
||||
":arguments",
|
||||
"//tensorflow/lite/delegates/gpu/common:model",
|
||||
"//tensorflow/lite/delegates/gpu/common:shape",
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
@ -213,6 +224,20 @@ ios_unit_test(
|
||||
deps = [":inference_context_test_lib"],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "metal_arguments",
|
||||
srcs = ["metal_arguments.mm"],
|
||||
hdrs = ["metal_arguments.h"],
|
||||
copts = DEFAULT_COPTS,
|
||||
sdk_frameworks = ["Metal"],
|
||||
deps = [
|
||||
":arguments",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "runtime_options",
|
||||
hdrs = ["runtime_options.h"],
|
||||
|
64
tensorflow/lite/delegates/gpu/metal/arguments.cc
Normal file
64
tensorflow/lite/delegates/gpu/metal/arguments.cc
Normal file
@ -0,0 +1,64 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/arguments.h"
|
||||
|
||||
#include "absl/strings/ascii.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
namespace {
|
||||
bool IsWordSymbol(char symbol) {
|
||||
return absl::ascii_isalnum(symbol) || symbol == '_';
|
||||
}
|
||||
|
||||
bool HasWord(const std::string& word, const std::string& text) {
|
||||
size_t pos = text.find(word);
|
||||
while (pos != std::string::npos) {
|
||||
char prev = pos == 0 ? '.' : text[pos - 1];
|
||||
char next = pos + word.size() < text.size() ? text[pos + word.size()] : '.';
|
||||
if (!IsWordSymbol(prev) & !IsWordSymbol(next)) {
|
||||
return true;
|
||||
}
|
||||
pos = text.find(word, pos + 1);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Static
|
||||
constexpr char Arguments::kArgsPrefix[];
|
||||
|
||||
void Arguments::AddFloat(const std::string& name, float value) {
|
||||
float_values_[name].value = value;
|
||||
}
|
||||
|
||||
void Arguments::AddInt(const std::string& name, int value) {
|
||||
int_values_[name].value = value;
|
||||
}
|
||||
|
||||
void Arguments::GetActiveArguments(const std::string& code) {
|
||||
for (auto& float_val : float_values_) {
|
||||
float_val.second.active = HasWord(kArgsPrefix + float_val.first, code);
|
||||
}
|
||||
for (auto& int_val : int_values_) {
|
||||
int_val.second.active = HasWord(kArgsPrefix + int_val.first, code);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
77
tensorflow/lite/delegates/gpu/metal/arguments.h
Normal file
77
tensorflow/lite/delegates/gpu/metal/arguments.h
Normal file
@ -0,0 +1,77 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_ARGUMENTS_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_ARGUMENTS_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
class Arguments {
|
||||
public:
|
||||
Arguments() = default;
|
||||
|
||||
// Move only
|
||||
Arguments(Arguments&& args) = default;
|
||||
Arguments& operator=(Arguments&& args) = default;
|
||||
Arguments(const Arguments&) = delete;
|
||||
Arguments& operator=(const Arguments&) = delete;
|
||||
|
||||
void AddFloat(const std::string& name, float value = 0.0f);
|
||||
void AddInt(const std::string& name, int value = 0);
|
||||
|
||||
private:
|
||||
friend class MetalArguments;
|
||||
void GetActiveArguments(const std::string& code);
|
||||
|
||||
static constexpr char kArgsPrefix[] = "args.";
|
||||
struct IntValue {
|
||||
int value;
|
||||
|
||||
// many arguments generated automatically and not used
|
||||
// this flag active if argument was used in kernel_code
|
||||
// Will be filled after GetActiveArguments call
|
||||
bool active = false;
|
||||
};
|
||||
std::map<std::string, IntValue> int_values_;
|
||||
|
||||
struct FloatValue {
|
||||
float value;
|
||||
|
||||
// many arguments generated automatically and not used
|
||||
// this flag active if argument was used in kernel_code
|
||||
// Will be filled after GetActiveArguments call
|
||||
bool active = false;
|
||||
};
|
||||
std::map<std::string, FloatValue> float_values_;
|
||||
};
|
||||
|
||||
class ArgumentsSetter {
|
||||
public:
|
||||
virtual absl::Status SetInt(const std::string& name, int value) = 0;
|
||||
virtual absl::Status SetFloat(const std::string& name, float value) = 0;
|
||||
virtual ~ArgumentsSetter() = default;
|
||||
};
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_ARGUMENTS_H_
|
@ -535,6 +535,7 @@ ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) {
|
||||
uniform_index++;
|
||||
fused_descriptor->uniform_buffers.push_back({"", buffer.data_function});
|
||||
}
|
||||
fused_descriptor->args = std::move(desc->args);
|
||||
|
||||
if (desc->is_linkable) {
|
||||
call_code +=
|
||||
@ -546,8 +547,8 @@ ComputeTaskDescriptorPtr FuseChain(const FusionSequence& chain) {
|
||||
|
||||
ComputeTaskDescriptorPtr non_linkable = sequence.front();
|
||||
fused_descriptor->shader_source =
|
||||
absl::Substitute(non_linkable->shader_source, function_code,
|
||||
buffer_declarations, call_code);
|
||||
absl::Substitute(non_linkable->shader_source, function_code + "$0",
|
||||
buffer_declarations + "$1", call_code);
|
||||
std::vector<ValueId> alias;
|
||||
alias.reserve(chain.size() - 1);
|
||||
for (int i = 0; i < chain.size() - 1; i++) {
|
||||
|
@ -158,6 +158,7 @@ static std::vector<ComputeTaskDescriptorPtr> Add2Linkable(int id, ValueId input_
|
||||
ValueId input_id2, ValueId output_id) {
|
||||
std::vector<ComputeTaskDescriptorPtr> descriptors;
|
||||
descriptors.push_back(ComputeTaskDescriptorPtr(new ComputeTaskDescriptor({
|
||||
{}, // args
|
||||
id,
|
||||
true, // linkable
|
||||
true, // associative_op
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
@ -70,11 +71,15 @@ struct UniformBuffer {
|
||||
uint3 _groupsCount;
|
||||
DispatchParamsFunction _resizeFunction;
|
||||
std::string _description;
|
||||
tflite::gpu::metal::MetalArguments _metal_args;
|
||||
}
|
||||
|
||||
- (absl::Status)compileWithDevice:(id<MTLDevice>)device
|
||||
taskDescriptor:(ComputeTaskDescriptorPtr)desc
|
||||
runtimeOptions:(const RuntimeOptions&)options {
|
||||
size_t offset = desc->input_buffers.size() + desc->uniform_buffers.size()
|
||||
+ desc->immutable_buffers.size() + 1;
|
||||
RETURN_IF_ERROR(_metal_args.Init(offset, &desc->args, &desc->shader_source));
|
||||
NSString* barrier;
|
||||
// simdgroup_barrier is supported on macOS 10.13+ and Metal shading language version 2.0
|
||||
if (@available(macOS 10.13, iOS 10.0, tvOS 10.0, *)) {
|
||||
@ -251,6 +256,7 @@ struct UniformBuffer {
|
||||
[encoder setBytes:uniform.data.data() length:uniform.data.size() atIndex:bindIndex];
|
||||
bindIndex++;
|
||||
}
|
||||
_metal_args.Encode(encoder, bindIndex);
|
||||
|
||||
MTLSize groupsCount = MTLSizeMake(_groupsCount.x, _groupsCount.y, _groupsCount.z);
|
||||
MTLSize groupsSize = MTLSizeMake(_groupsSize.x, _groupsSize.y, _groupsSize.z);
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/shape.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/arguments.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -79,6 +80,7 @@ struct ComputeTaskDescriptor {
|
||||
UniformsFunction data_function;
|
||||
};
|
||||
|
||||
Arguments args;
|
||||
// Unique ID to match the graph compilation errors.
|
||||
int id;
|
||||
bool is_linkable;
|
||||
|
@ -51,13 +51,6 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info,
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
struct uniforms {
|
||||
uint src_depth;
|
||||
uint dst_channels;
|
||||
uint out_channels;
|
||||
uint dummy;
|
||||
};
|
||||
|
||||
$$0
|
||||
kernel void ComputeFunction(
|
||||
$$1
|
||||
@ -71,11 +64,11 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info,
|
||||
float summa = 0.0f;
|
||||
threadgroup FLT4 local_vector[32];
|
||||
for (int j = 0; j < $0; ++j) {
|
||||
local_vector[tid_index] = j * 32 + tid_index >= params.src_depth ?
|
||||
local_vector[tid_index] = j * 32 + tid_index >= args.src_slices ?
|
||||
FLT4(0.0f) : vector[j * 32 + tid_index];
|
||||
$1(mem_flags::mem_threadgroup);
|
||||
for (uint i = 0, counter = j * 32 + tid.y * 8; i < 8; ++i, ++counter) {
|
||||
summa += dot(local_vector[tid.y * 8 + i], matrix[counter * params.dst_channels + ugid.x]);
|
||||
summa += dot(local_vector[tid.y * 8 + i], matrix[counter * args.dst_channels_alignedx8 + ugid.x]);
|
||||
}
|
||||
$1(mem_flags::mem_none);
|
||||
}
|
||||
@ -87,10 +80,10 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info,
|
||||
for (uint i = 0; i < $0; ++i, ++counter) {
|
||||
)";
|
||||
if (src_depth % 4 != 0) {
|
||||
code << " if (counter >= params.src_depth) continue;" << std::endl;
|
||||
code << " if (counter >= args.src_slices) continue;" << std::endl;
|
||||
}
|
||||
code << " summa += dot(vector[counter], matrix[counter * "
|
||||
"params.dst_channels + ugid.x]);"
|
||||
"args.dst_channels_alignedx8 + ugid.x]);"
|
||||
<< std::endl;
|
||||
code << " }" << std::endl;
|
||||
}
|
||||
@ -106,7 +99,7 @@ std::string GetFullyConnectedCode(const DeviceInfo& device_info,
|
||||
temp[tid.x][0] = summa;
|
||||
}
|
||||
$1(mem_flags::mem_threadgroup);
|
||||
if (tid.y == 0 && tid.x % 4 == 0 && ugid.x < params.out_channels) {
|
||||
if (tid.y == 0 && tid.x % 4 == 0 && ugid.x < args.dst_channels) {
|
||||
const int linear_index = ugid.x / 4;
|
||||
FLT4 value = FLT4(temp[tid.x][0], temp[tid.x + 1][0], temp[tid.x + 2][0], temp[tid.x + 3][0]) +
|
||||
biases[linear_index];
|
||||
@ -132,6 +125,11 @@ std::vector<ComputeTaskDescriptorPtr> FullyConnected(
|
||||
desc->shader_source = GetFullyConnectedCode(device_info, attr.weights.shape.i,
|
||||
attr.weights.shape.o);
|
||||
|
||||
desc->args.AddInt("dst_channels", attr.weights.shape.o);
|
||||
desc->args.AddInt("src_slices", DivideRoundUp(attr.weights.shape.i, 4));
|
||||
desc->args.AddInt("dst_channels_alignedx8",
|
||||
AlignByN(attr.weights.shape.o, 8));
|
||||
|
||||
desc->input_buffers = {
|
||||
{input_id, "device FLT4* const vector"},
|
||||
};
|
||||
@ -174,19 +172,6 @@ std::vector<ComputeTaskDescriptorPtr> FullyConnected(
|
||||
attr.weights.shape.o)},
|
||||
};
|
||||
|
||||
desc->uniform_buffers = {
|
||||
{"constant uniforms& params",
|
||||
[attr](const std::map<ValueId, BHWC>& buffers) {
|
||||
std::vector<uint32_t> uniform_params{
|
||||
static_cast<uint32_t>(DivideRoundUp(attr.weights.shape.i, 4)),
|
||||
static_cast<uint32_t>(AlignByN(attr.weights.shape.o, 8)),
|
||||
static_cast<uint32_t>(attr.weights.shape.o),
|
||||
static_cast<uint32_t>(0),
|
||||
};
|
||||
return GetByteBuffer(uniform_params);
|
||||
}},
|
||||
};
|
||||
|
||||
desc->resize_function = [attr](const std::map<ValueId, BHWC>& buffers) {
|
||||
const uint3 groups_size{8, 4, 1};
|
||||
const int dst_channels_aligned = AlignByN(attr.weights.shape.o, 8);
|
||||
|
80
tensorflow/lite/delegates/gpu/metal/metal_arguments.h
Normal file
80
tensorflow/lite/delegates/gpu/metal/metal_arguments.h
Normal file
@ -0,0 +1,80 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_ARGUMENTS_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_ARGUMENTS_H_
|
||||
|
||||
#import <Metal/Metal.h>
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/arguments.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
class MetalArguments : public ArgumentsSetter {
|
||||
public:
|
||||
MetalArguments() = default;
|
||||
|
||||
absl::Status Init(int buffer_offset, Arguments* args, std::string* code);
|
||||
|
||||
// Move only
|
||||
MetalArguments(MetalArguments&& args) = default;
|
||||
MetalArguments& operator=(MetalArguments&& args) = default;
|
||||
MetalArguments(const MetalArguments&) = delete;
|
||||
MetalArguments& operator=(const MetalArguments&) = delete;
|
||||
|
||||
absl::Status SetInt(const std::string& name, int value) override;
|
||||
absl::Status SetFloat(const std::string& name, float value) override;
|
||||
|
||||
void Encode(id<MTLComputeCommandEncoder> encoder, int buffer_offset) const;
|
||||
|
||||
private:
|
||||
static constexpr char kArgsPrefix[] = "args.";
|
||||
struct IntValue {
|
||||
int value;
|
||||
|
||||
// many arguments generated automatically and not used
|
||||
// to reduce amount of data transferred we adding this optimization
|
||||
bool active = false;
|
||||
|
||||
// offset to shared storage.
|
||||
uint32_t bytes_offset = -1;
|
||||
};
|
||||
std::map<std::string, IntValue> int_values_;
|
||||
|
||||
struct FloatValue {
|
||||
float value;
|
||||
|
||||
// many arguments generated automatically and not used
|
||||
// to reduce amount of data transferred we adding this optimization
|
||||
bool active = false;
|
||||
|
||||
// offset to shared storage.
|
||||
uint32_t bytes_offset = -1;
|
||||
};
|
||||
std::map<std::string, FloatValue> float_values_;
|
||||
std::vector<uint8_t> const_data_;
|
||||
};
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_ARGUMENTS_H_
|
137
tensorflow/lite/delegates/gpu/metal/metal_arguments.mm
Normal file
137
tensorflow/lite/delegates/gpu/metal/metal_arguments.mm
Normal file
@ -0,0 +1,137 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/delegates/gpu/metal/metal_arguments.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
namespace {
|
||||
bool IsWordSymbol(char symbol) {
|
||||
return absl::ascii_isalnum(symbol) || symbol == '_';
|
||||
}
|
||||
|
||||
void ReplaceAllWords(const std::string& old_word, const std::string& new_word,
|
||||
std::string* str) {
|
||||
size_t position = str->find(old_word);
|
||||
while (position != std::string::npos) {
|
||||
char prev = position == 0 ? '.' : (*str)[position - 1];
|
||||
char next = position + old_word.size() < str->size()
|
||||
? (*str)[position + old_word.size()]
|
||||
: '.';
|
||||
if (IsWordSymbol(prev) || IsWordSymbol(next)) {
|
||||
position = str->find(old_word, position + 1);
|
||||
continue;
|
||||
}
|
||||
str->replace(position, old_word.size(), new_word);
|
||||
position = str->find(old_word, position + new_word.size());
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Static
|
||||
constexpr char MetalArguments::kArgsPrefix[];
|
||||
|
||||
absl::Status MetalArguments::Init(int buffer_offset, Arguments* args, std::string* code) {
|
||||
args->GetActiveArguments(*code);
|
||||
std::string struct_desc = "struct uniforms_buffer {\n";
|
||||
std::string struct_decl;
|
||||
int pos = 0;
|
||||
for (auto& fvalue : args->float_values_) {
|
||||
auto& new_val = float_values_[fvalue.first];
|
||||
new_val.value = fvalue.second.value;
|
||||
new_val.active = fvalue.second.active;
|
||||
if (fvalue.second.active) {
|
||||
new_val.bytes_offset = pos * 4;
|
||||
pos++;
|
||||
struct_desc += " float " + fvalue.first + ";\n";
|
||||
ReplaceAllWords(kArgsPrefix + fvalue.first, "U." + fvalue.first, code);
|
||||
}
|
||||
}
|
||||
for (auto& ivalue : args->int_values_) {
|
||||
auto& new_val = int_values_[ivalue.first];
|
||||
new_val.value = ivalue.second.value;
|
||||
new_val.active = ivalue.second.active;
|
||||
if (ivalue.second.active) {
|
||||
new_val.bytes_offset = pos * 4;
|
||||
pos++;
|
||||
struct_desc += " int " + ivalue.first + ";\n";
|
||||
ReplaceAllWords(kArgsPrefix + ivalue.first, "U." + ivalue.first, code);
|
||||
}
|
||||
}
|
||||
if (pos != 0) {
|
||||
struct_decl = "constant uniforms_buffer& U[[buffer(" + std::to_string(buffer_offset) + ")]],\n";
|
||||
int aligned_pos = AlignByN(pos, 4);
|
||||
for (int i = pos; i < aligned_pos; i++) {
|
||||
struct_desc += " int dummy" + std::to_string(i - pos) + ";\n";
|
||||
}
|
||||
struct_desc += "};";
|
||||
const_data_.resize(aligned_pos * 4);
|
||||
for (auto& it : float_values_) {
|
||||
float* ptr = reinterpret_cast<float*>(&const_data_[it.second.bytes_offset]);
|
||||
*ptr = it.second.value;
|
||||
}
|
||||
for (auto& it : int_values_) {
|
||||
int32_t* ptr = reinterpret_cast<int32_t*>(&const_data_[it.second.bytes_offset]);
|
||||
*ptr = it.second.value;
|
||||
}
|
||||
} else {
|
||||
struct_desc = "";
|
||||
struct_decl = "";
|
||||
}
|
||||
*code = absl::Substitute(*code, struct_desc, struct_decl);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status MetalArguments::SetInt(const std::string& name, int value) {
|
||||
auto it = int_values_.find(name);
|
||||
if (it == int_values_.end()) {
|
||||
return absl::NotFoundError(
|
||||
absl::StrCat("No int argument with name - ", name));
|
||||
}
|
||||
it->second.value = value;
|
||||
if (it->second.active) {
|
||||
int32_t* ptr = reinterpret_cast<int32_t*>(&const_data_[it->second.bytes_offset]);
|
||||
*ptr = value;
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status MetalArguments::SetFloat(const std::string& name, float value) {
|
||||
auto it = float_values_.find(name);
|
||||
if (it == float_values_.end()) {
|
||||
return absl::NotFoundError(
|
||||
absl::StrCat("No float argument with name - ", name));
|
||||
}
|
||||
it->second.value = value;
|
||||
if (it->second.active) {
|
||||
float* ptr = reinterpret_cast<float*>(&const_data_[it->second.bytes_offset]);
|
||||
*ptr = value;
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
void MetalArguments::Encode(id<MTLComputeCommandEncoder> encoder, int buffer_offset) const {
|
||||
if (!const_data_.empty()) {
|
||||
[encoder setBytes:const_data_.data() length:const_data_.size() atIndex:buffer_offset];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
Loading…
Reference in New Issue
Block a user