TFLite GPU: Replace tflite::gpu::Status with absl::Status.
PiperOrigin-RevId: 302720429
Change-Id: I5b7987e677dad4a335ab4dae9480cba8779706ca
parent cd661aab96
commit f19161ecb7
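
The change is mechanical but wide: every `tflite::gpu::Status`, bare `OkStatus()`, and bare error constructor (`InvalidArgumentError`, `UnknownError`, ...) becomes its `absl::`-qualified counterpart, and lines that no longer fit are re-wrapped. A minimal sketch of the before/after pattern (the function and its name are illustrative, not from the patch):

    #include <string>

    #include "absl/status/status.h"

    // Before: a `Status LoadModel(...)` returning OkStatus() or
    // InvalidArgumentError(...), declared in TFLite GPU's own status header.
    // After the migration the same function is spelled:
    absl::Status LoadModel(const std::string& path) {
      if (path.empty()) {
        return absl::InvalidArgumentError("Model path is empty.");
      }
      return absl::OkStatus();  // default-constructed "everything went fine"
    }
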
@@ -12,12 +12,6 @@ exports_files([
     "metal_delegate.h",
 ])
 
-# Primary purpose of this config is to replace ::util::Status with our custom
-# light implementation ::tflite::gpu::StatusLite to reduce binary size. Besides
-# that, certain features that were hard to communicate without full open source
-# were hidden away too such as compiled models, serialization, and metadata.
-# While the latter will be fully available with the open source release, the
-# former will have to stay until absl::Status is released.
 config_setting(
     name = "tflite_gpu_binary_release",
     values = {"copt": "-DTFLITE_GPU_BINARY_RELEASE"},
@@ -220,7 +220,8 @@ class InferenceBuilder {
 
   // Sets new shape for the input if underlying implementation and graph
   // structure allows dynamic tensors.
-  virtual Status SetInputShape(int index, const Dimensions& dimensions) = 0;
+  virtual absl::Status SetInputShape(int index,
+                                     const Dimensions& dimensions) = 0;
 
   // Updates object definitions for the given index. Implementation may allow
   // to use different layouts and/or data type conversions between objects
@@ -229,21 +230,21 @@ class InferenceBuilder {
   // A user, however, has an input in DataType::FLOAT16, DataLayout::PHWC4.
   // An implementation may allow this transformation to happen automatically
   // under the hood.
-  virtual Status SetInputObjectDef(int index, ObjectDef def) = 0;
-  virtual Status SetOutputObjectDef(int index, ObjectDef def) = 0;
-  virtual Status SetAllInputObjectDefsTo(ObjectDef def) {
+  virtual absl::Status SetInputObjectDef(int index, ObjectDef def) = 0;
+  virtual absl::Status SetOutputObjectDef(int index, ObjectDef def) = 0;
+  virtual absl::Status SetAllInputObjectDefsTo(ObjectDef def) {
     auto input_defs = inputs();
     for (int i = 0; i < input_defs.size(); ++i) {
       RETURN_IF_ERROR(SetInputObjectDef(i, def));
     }
-    return OkStatus();
+    return absl::OkStatus();
   }
-  virtual Status SetAllOutputObjectDefsTo(ObjectDef def) {
+  virtual absl::Status SetAllOutputObjectDefsTo(ObjectDef def) {
     auto output_defs = outputs();
     for (int i = 0; i < output_defs.size(); ++i) {
       RETURN_IF_ERROR(SetOutputObjectDef(i, def));
     }
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   // Creates new instance of the inference runner. InferenceBuilder stays valid
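
The loops above lean on `RETURN_IF_ERROR`, which survives the migration unchanged because it only requires the expression to produce a status with an `ok()` accessor. TFLite GPU ships its own macro; a minimal illustrative definition compatible with `absl::Status` (an assumption, not the project's actual macro) would be:

    #define RETURN_IF_ERROR(expr)             \
      do {                                    \
        const absl::Status _status = (expr);  \
        if (!_status.ok()) return _status;    \
      } while (0)
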
@@ -251,7 +252,7 @@ class InferenceBuilder {
   //
   // This method may take significant time to prepare new inference runner. For
   // example, it may require to compile OpenGL shaders.
-  virtual Status Build(std::unique_ptr<InferenceRunner>* runner) = 0;
+  virtual absl::Status Build(std::unique_ptr<InferenceRunner>* runner) = 0;
 };
 
 // Runs prepared inference. Every object marked as external needs to be set
@@ -268,12 +269,12 @@ class InferenceRunner {
   // Setters allow to set or change external object for the given index. Note,
   // object need to match object definition set before in InferenceBuilder.
 
-  virtual Status GetInputObject(int index, TensorObject* object) = 0;
-  virtual Status GetOutputObject(int index, TensorObject* object) = 0;
-  virtual Status SetInputObject(int index, TensorObject object) = 0;
-  virtual Status SetOutputObject(int index, TensorObject object) = 0;
+  virtual absl::Status GetInputObject(int index, TensorObject* object) = 0;
+  virtual absl::Status GetOutputObject(int index, TensorObject* object) = 0;
+  virtual absl::Status SetInputObject(int index, TensorObject object) = 0;
+  virtual absl::Status SetOutputObject(int index, TensorObject object) = 0;
 
-  virtual Status Run() = 0;
+  virtual absl::Status Run() = 0;
 };
 
 // Encapsulated compilation/runtime tradeoffs.
@@ -54,22 +54,22 @@ class NoopTensorTie : public TensorTie {
     return def.external_def == def.internal_def;
   }
 
-  Status SetExternalObject(TensorObject obj) final {
+  absl::Status SetExternalObject(TensorObject obj) final {
     if (!def().external_def.object_def.user_provided) {
-      return InvalidArgumentError("Tensor object is readonly.");
+      return absl::InvalidArgumentError("Tensor object is readonly.");
     }
     if (!IsValid(def().external_def, obj)) {
-      return InvalidArgumentError("Given object is not valid");
+      return absl::InvalidArgumentError("Given object is not valid");
     }
     obj_ = obj;
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   TensorObject GetExternalObject() final { return obj_; }
 
-  Status CopyToExternalObject() final { return OkStatus(); }
+  absl::Status CopyToExternalObject() final { return absl::OkStatus(); }
 
-  Status CopyFromExternalObject() final { return OkStatus(); }
+  absl::Status CopyFromExternalObject() final { return absl::OkStatus(); }
 
  private:
   TensorObject obj_;
@@ -93,45 +93,45 @@ class DefaultTensorTie : public TensorTie {
            converter_builder.IsSupported(def.external_def, def.internal_def);
   }
 
-  static Status New(const TensorTieDef& def, TensorObject internal_object,
-                    TensorObjectConverterBuilder* converter_builder,
-                    Environment* env, std::unique_ptr<TensorTie>* tie) {
+  static absl::Status New(const TensorTieDef& def, TensorObject internal_object,
+                          TensorObjectConverterBuilder* converter_builder,
+                          Environment* env, std::unique_ptr<TensorTie>* tie) {
     auto tie_impl = absl::make_unique<DefaultTensorTie>(def, internal_object);
     RETURN_IF_ERROR(tie_impl->Init(converter_builder, env));
     *tie = std::move(tie_impl);
-    return OkStatus();
+    return absl::OkStatus();
   }
 
-  Status CopyToExternalObject() final {
+  absl::Status CopyToExternalObject() final {
     if (!converter_to_) {
-      return UnavailableError("Conversion is not available");
+      return absl::UnavailableError("Conversion is not available");
     }
     return converter_to_->Convert(internal_obj_, GetExternalObject());
   }
 
-  Status CopyFromExternalObject() final {
+  absl::Status CopyFromExternalObject() final {
     if (!converter_from_) {
-      return UnavailableError("Conversion is not available");
+      return absl::UnavailableError("Conversion is not available");
     }
     return converter_from_->Convert(GetExternalObject(), internal_obj_);
   }
 
-  Status SetExternalObject(TensorObject obj) final {
+  absl::Status SetExternalObject(TensorObject obj) final {
     if (!def().external_def.object_def.user_provided) {
-      return InvalidArgumentError("External object is read-only");
+      return absl::InvalidArgumentError("External object is read-only");
     }
     if (!IsValid(def().external_def, obj)) {
-      return InvalidArgumentError("Given object is not valid");
+      return absl::InvalidArgumentError("Given object is not valid");
     }
     external_obj_ = obj;
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   TensorObject GetExternalObject() final { return external_obj_; }
 
  private:
-  Status Init(TensorObjectConverterBuilder* converter_builder,
-              Environment* env) {
+  absl::Status Init(TensorObjectConverterBuilder* converter_builder,
+                    Environment* env) {
     RETURN_IF_ERROR(converter_builder->MakeConverter(
         def().internal_def, def().external_def, &converter_to_));
     RETURN_IF_ERROR(converter_builder->MakeConverter(
@@ -139,10 +139,10 @@ class DefaultTensorTie : public TensorTie {
     return MaybeAllocateExternalObject(env);
   }
 
-  Status MaybeAllocateExternalObject(Environment* env) {
+  absl::Status MaybeAllocateExternalObject(Environment* env) {
     const TensorObjectDef& d = def().external_def;
     if (d.object_def.user_provided) {
-      return OkStatus();
+      return absl::OkStatus();
     }
     switch (d.object_def.object_type) {
       case ObjectType::CPU_MEMORY: {
@@ -170,9 +170,9 @@ class DefaultTensorTie : public TensorTie {
         break;
       }
       default:
-        return InternalError("Unexpected object type");
+        return absl::InternalError("Unexpected object type");
     }
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   const TensorObject internal_obj_;
@@ -198,26 +198,26 @@ class TwoStepTensorTie : public TensorTie {
            DefaultTensorTie::IsSupported(defs.second, converter_builder);
   }
 
-  static Status New(const TensorTieDef& def, TensorObject internal_object,
-                    TensorObjectConverterBuilder* converter_builder,
-                    Environment* env, std::unique_ptr<TensorTie>* tie) {
+  static absl::Status New(const TensorTieDef& def, TensorObject internal_object,
+                          TensorObjectConverterBuilder* converter_builder,
+                          Environment* env, std::unique_ptr<TensorTie>* tie) {
     auto tie_impl = absl::make_unique<TwoStepTensorTie>(def);
     RETURN_IF_ERROR(tie_impl->Init(internal_object, converter_builder, env));
     *tie = std::move(tie_impl);
-    return OkStatus();
+    return absl::OkStatus();
   }
 
-  Status CopyToExternalObject() final {
+  absl::Status CopyToExternalObject() final {
     RETURN_IF_ERROR(inner_tie_->CopyToExternalObject());
     return outer_tie_->CopyToExternalObject();
   }
 
-  Status CopyFromExternalObject() final {
+  absl::Status CopyFromExternalObject() final {
     RETURN_IF_ERROR(outer_tie_->CopyFromExternalObject());
     return inner_tie_->CopyFromExternalObject();
   }
 
-  Status SetExternalObject(TensorObject obj) final {
+  absl::Status SetExternalObject(TensorObject obj) final {
     return outer_tie_->SetExternalObject(obj);
   }
 
@@ -241,9 +241,9 @@ class TwoStepTensorTie : public TensorTie {
     return std::make_pair(outer_def, inner_def);
   }
 
-  Status Init(TensorObject internal_object,
-              TensorObjectConverterBuilder* converter_builder,
-              Environment* env) {
+  absl::Status Init(TensorObject internal_object,
+                    TensorObjectConverterBuilder* converter_builder,
+                    Environment* env) {
     auto defs = MakeOuterInnerDefs(def());
     RETURN_IF_ERROR(DefaultTensorTie::New(defs.second, internal_object,
                                           converter_builder, env, &inner_tie_));
@@ -274,27 +274,27 @@ class GlBufferHolder : public TensorTie {
     return DefaultTensorTie::IsSupported(MakeClDef(def), converter_builder);
   }
 
-  static Status New(const TensorTieDef& def, TensorObject internal_object,
-                    TensorObjectConverterBuilder* converter_builder,
-                    GlInteropFabric* gl_interop_fabric, Environment* env,
-                    std::unique_ptr<TensorTie>* tie) {
+  static absl::Status New(const TensorTieDef& def, TensorObject internal_object,
+                          TensorObjectConverterBuilder* converter_builder,
+                          GlInteropFabric* gl_interop_fabric, Environment* env,
+                          std::unique_ptr<TensorTie>* tie) {
     auto tie_impl =
         absl::make_unique<GlBufferHolder>(def, gl_interop_fabric, env);
     RETURN_IF_ERROR(DefaultTensorTie::New(MakeClDef(def), internal_object,
                                           converter_builder, env,
                                           &tie_impl->tie_));
     *tie = std::move(tie_impl);
-    return OkStatus();
+    return absl::OkStatus();
   }
 
-  Status SetExternalObject(TensorObject obj) final {
+  absl::Status SetExternalObject(TensorObject obj) final {
     auto ssbo = absl::get_if<OpenGlBuffer>(&obj);
     if (!ssbo) {
-      return InvalidArgumentError("Missing OpenGL SSBO");
+      return absl::InvalidArgumentError("Missing OpenGL SSBO");
     }
     auto old_ssbo = absl::get_if<OpenGlBuffer>(&external_obj_);
     if (old_ssbo && ssbo->id == old_ssbo->id) {
-      return OkStatus();
+      return absl::OkStatus();
     }
     if (cl_object_.memory()) {
       gl_interop_fabric_->UnregisterMemory(cl_object_.memory());
@@ -304,16 +304,18 @@ class GlBufferHolder : public TensorTie {
     external_obj_ = obj;
     RETURN_IF_ERROR(tie_->SetExternalObject(OpenClBuffer{cl_object_.memory()}));
     gl_interop_fabric_->RegisterMemory(cl_object_.memory());
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   TensorObject GetExternalObject() final { return external_obj_; }
 
-  Status CopyFromExternalObject() final {
+  absl::Status CopyFromExternalObject() final {
     return tie_->CopyFromExternalObject();
   }
 
-  Status CopyToExternalObject() final { return tie_->CopyToExternalObject(); }
+  absl::Status CopyToExternalObject() final {
+    return tie_->CopyToExternalObject();
+  }
 
  private:
   static TensorTieDef MakeClDef(const TensorTieDef& def) {
@@ -358,20 +360,20 @@ class TensorTieFactory {
            TwoStepTensorTie::IsSupported(def, *converter_builder_));
   }
 
-  Status NewTensorTie(const TensorTieDef& def,
-                      std::unique_ptr<TensorTie>* tie) {
+  absl::Status NewTensorTie(const TensorTieDef& def,
+                            std::unique_ptr<TensorTie>* tie) {
     TensorObject internal_object = TensorToObj(*context_.GetTensor(def.id));
     auto converter = converter_builder_.get();
     if (NoopTensorTie::IsSupported(def)) {
       *tie = absl::make_unique<NoopTensorTie>(def, internal_object);
-      return OkStatus();
+      return absl::OkStatus();
     }
     if (DefaultTensorTie::IsSupported(def, *converter)) {
       return DefaultTensorTie::New(def, internal_object, converter, &env_, tie);
     }
     if (GlBufferHolder::IsSupported(def, *converter)) {
       if (!gl_interop_fabric_) {
-        return InvalidArgumentError(
+        return absl::InvalidArgumentError(
             "GL object is used but InferenceEnvironmentOptions does not have "
            "EGL display and context set.");
       }
@@ -381,7 +383,7 @@ class TensorTieFactory {
     if (TwoStepTensorTie::IsSupported(def, *converter)) {
       return TwoStepTensorTie::New(def, internal_object, converter, &env_, tie);
     }
-    return UnimplementedError("Unsupported tensor tie definition.");
+    return absl::UnimplementedError("Unsupported tensor tie definition.");
   }
 
  private:
@@ -400,9 +402,9 @@ class InferenceRunnerImpl : public InferenceRunner {
         context_(std::move(context)),
         gl_interop_fabric_(std::move(gl_interop_fabric)) {}
 
-  Status Initialize(const std::vector<TensorTieDef>& inputs,
-                    const std::vector<TensorTieDef>& outputs,
-                    TensorTieFactory* factory) {
+  absl::Status Initialize(const std::vector<TensorTieDef>& inputs,
+                          const std::vector<TensorTieDef>& outputs,
+                          TensorTieFactory* factory) {
     RETURN_IF_ERROR(LinkTensors(inputs, factory, &inputs_));
     return LinkTensors(outputs, factory, &outputs_);
   }
@@ -415,37 +417,37 @@ class InferenceRunnerImpl : public InferenceRunner {
     return GetExternalDefinitions(outputs_);
   }
 
-  Status GetInputObject(int index, TensorObject* object) override {
+  absl::Status GetInputObject(int index, TensorObject* object) override {
     if (index < 0 || index >= inputs_.size()) {
-      return OutOfRangeError("Index is out of range");
+      return absl::OutOfRangeError("Index is out of range");
     }
     *object = inputs_[index]->GetExternalObject();
-    return OkStatus();
+    return absl::OkStatus();
   }
 
-  Status GetOutputObject(int index, TensorObject* object) override {
+  absl::Status GetOutputObject(int index, TensorObject* object) override {
     if (index < 0 || index >= outputs_.size()) {
-      return OutOfRangeError("Index is out of range");
+      return absl::OutOfRangeError("Index is out of range");
     }
     *object = outputs_[index]->GetExternalObject();
-    return OkStatus();
+    return absl::OkStatus();
   }
 
-  Status SetInputObject(int index, TensorObject object) override {
+  absl::Status SetInputObject(int index, TensorObject object) override {
     if (index < 0 || index >= inputs_.size()) {
-      return OutOfRangeError("Index is out of range");
+      return absl::OutOfRangeError("Index is out of range");
     }
     return inputs_[index]->SetExternalObject(object);
   }
 
-  Status SetOutputObject(int index, TensorObject object) override {
+  absl::Status SetOutputObject(int index, TensorObject object) override {
     if (index < 0 || index >= outputs_.size()) {
-      return OutOfRangeError("Index is out of range");
+      return absl::OutOfRangeError("Index is out of range");
     }
     return outputs_[index]->SetExternalObject(object);
   }
 
-  Status Run() override {
+  absl::Status Run() override {
     if (gl_interop_fabric_) {
       RETURN_IF_ERROR(gl_interop_fabric_->Start());
     }
@@ -460,20 +462,20 @@ class InferenceRunnerImpl : public InferenceRunner {
     if (gl_interop_fabric_) {
       RETURN_IF_ERROR(gl_interop_fabric_->Finish());
     }
-    return OkStatus();
+    return absl::OkStatus();
   }
 
  private:
-  static Status LinkTensors(const std::vector<TensorTieDef>& defs,
-                            TensorTieFactory* factory,
-                            std::vector<std::unique_ptr<TensorTie>>* objects) {
+  static absl::Status LinkTensors(
+      const std::vector<TensorTieDef>& defs, TensorTieFactory* factory,
+      std::vector<std::unique_ptr<TensorTie>>* objects) {
     objects->reserve(defs.size());
     for (auto& def : defs) {
       std::unique_ptr<TensorTie> object;
       RETURN_IF_ERROR(factory->NewTensorTie(def, &object));
       objects->push_back(std::move(object));
     }
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   static std::vector<TensorObjectDef> GetExternalDefinitions(
@@ -511,9 +513,9 @@ class InferenceBuilderImpl : public InferenceBuilder {
   explicit InferenceBuilderImpl(Environment* environment)
       : environment_(environment) {}
 
-  Status Initialize(const InferenceOptions& options,
-                    const InferenceEnvironmentOptions& env_options,
-                    const GraphFloat32& graph) {
+  absl::Status Initialize(const InferenceOptions& options,
+                          const InferenceEnvironmentOptions& env_options,
+                          const GraphFloat32& graph) {
     context_ = absl::make_unique<InferenceContext>();
     InferenceContext::CreateInferenceInfo create_info;
     create_info.precision = GetPrecision(options);
@@ -533,7 +535,7 @@ class InferenceBuilderImpl : public InferenceBuilder {
 
     inputs_ = LinkTensors(graph, graph.inputs());
     outputs_ = LinkTensors(graph, graph.outputs());
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   std::vector<TensorObjectDef> inputs() const override {
@@ -544,40 +546,42 @@ class InferenceBuilderImpl : public InferenceBuilder {
     return GetExternalDefinitions(outputs_);
   }
 
-  Status SetInputShape(int index, const Dimensions& dimensions) override {
+  absl::Status SetInputShape(int index, const Dimensions& dimensions) override {
     if (index < 0 || index >= inputs_.size()) {
-      return OutOfRangeError("Index is out of range");
+      return absl::OutOfRangeError("Index is out of range");
     }
-    return UnimplementedError("Changing input shapes is not supported");
+    return absl::UnimplementedError("Changing input shapes is not supported");
   }
 
-  Status SetInputObjectDef(int index, ObjectDef new_def) override {
+  absl::Status SetInputObjectDef(int index, ObjectDef new_def) override {
     if (index < 0 || index >= inputs_.size()) {
-      return OutOfRangeError("Index is out of range");
+      return absl::OutOfRangeError("Index is out of range");
     }
     auto def = inputs_[index];
     def.external_def.object_def = new_def;
     if (!tie_factory_->IsSupported(def)) {
-      return InvalidArgumentError("New object definition is not supported.");
+      return absl::InvalidArgumentError(
+          "New object definition is not supported.");
     }
     inputs_[index] = def;
-    return OkStatus();
+    return absl::OkStatus();
   }
 
-  Status SetOutputObjectDef(int index, ObjectDef new_def) override {
+  absl::Status SetOutputObjectDef(int index, ObjectDef new_def) override {
     if (index < 0 || index >= outputs_.size()) {
-      return OutOfRangeError("Index is out of range");
+      return absl::OutOfRangeError("Index is out of range");
     }
     auto def = outputs_[index];
     def.external_def.object_def = new_def;
     if (!tie_factory_->IsSupported(def)) {
-      return InvalidArgumentError("New object definition is not supported.");
+      return absl::InvalidArgumentError(
+          "New object definition is not supported.");
     }
     outputs_[index] = def;
-    return OkStatus();
+    return absl::OkStatus();
   }
 
-  Status Build(std::unique_ptr<InferenceRunner>* runner) override {
+  absl::Status Build(std::unique_ptr<InferenceRunner>* runner) override {
     if (gl_interop_fabric_ && !HasGlObjects()) {
       // destroy interop layer when there are no GL objects to avoid
       // extra synchronization cost.
@@ -588,7 +592,7 @@ class InferenceBuilderImpl : public InferenceBuilder {
     RETURN_IF_ERROR(
         runner_impl->Initialize(inputs_, outputs_, tie_factory_.get()));
     *runner = std::move(runner_impl);
-    return OkStatus();
+    return absl::OkStatus();
   }
 
  private:
@@ -696,7 +700,7 @@ class InferenceEnvironmentImpl : public InferenceEnvironment {
   explicit InferenceEnvironmentImpl(const InferenceEnvironmentOptions& options)
       : options_(options) {}
 
-  Status Init() {
+  absl::Status Init() {
     RETURN_IF_ERROR(LoadOpenCL());
     properties_.is_opencl_available = true;
 
@@ -716,13 +720,13 @@ class InferenceEnvironmentImpl : public InferenceEnvironment {
     properties_.is_cl_to_gl_fast_sync_supported =
         IsEglSyncFromClEventSupported();
     if (options_.IsGlAware() && !properties_.is_gl_sharing_supported) {
-      return UnavailableError("GL sharing is not supported");
+      return absl::UnavailableError("GL sharing is not supported");
     }
 
     CLContext context;
     if (options_.context) {
       if (options_.IsGlAware()) {
-        return InvalidArgumentError(
+        return absl::InvalidArgumentError(
             "OpenCL context and EGL parameters are set in the same time.");
       }
       context = CLContext(options_.context, /* has_ownership = */ false);
@@ -754,11 +758,11 @@ class InferenceEnvironmentImpl : public InferenceEnvironment {
     return environment_.Init();
   }
 
-  Status NewInferenceBuilder(const InferenceOptions& options,
-                             GraphFloat32 model,
-                             std::unique_ptr<InferenceBuilder>* builder) final {
+  absl::Status NewInferenceBuilder(
+      const InferenceOptions& options, GraphFloat32 model,
+      std::unique_ptr<InferenceBuilder>* builder) final {
     if (!IsValid(options)) {
-      return InvalidArgumentError("InferenceOptions are invalid.");
+      return absl::InvalidArgumentError("InferenceOptions are invalid.");
     }
     InferenceOptions resolved_options = options;
     ResolveAutoPriority(&resolved_options);
@@ -776,7 +780,7 @@ class InferenceEnvironmentImpl : public InferenceEnvironment {
     RETURN_IF_ERROR(
         builder_impl->Initialize(resolved_options, options_, model));
     *builder = std::move(builder_impl);
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   std::vector<uint8_t> GetSerializedBinaryCache() const final {
@@ -800,18 +804,18 @@ class InferenceEnvironmentImpl : public InferenceEnvironment {
 
 }  // namespace
 
-Status NewInferenceEnvironment(
+absl::Status NewInferenceEnvironment(
     const InferenceEnvironmentOptions& options,
     std::unique_ptr<InferenceEnvironment>* environment,
     InferenceEnvironmentProperties* properties) {
   auto env_impl = absl::make_unique<InferenceEnvironmentImpl>(options);
-  Status status = env_impl->Init();
+  absl::Status status = env_impl->Init();
   if (properties) {
     *properties = env_impl->properties();
   }
   RETURN_IF_ERROR(status);
   *environment = std::move(env_impl);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace cl
@@ -70,7 +70,7 @@ class InferenceEnvironment {
  public:
   virtual ~InferenceEnvironment() {}
 
-  virtual Status NewInferenceBuilder(
+  virtual absl::Status NewInferenceBuilder(
       const InferenceOptions& options, GraphFloat32 model,
       std::unique_ptr<InferenceBuilder>* builder) = 0;
 
@@ -112,7 +112,7 @@ struct InferenceEnvironmentOptions {
 
 // Creates new OpenCL environment that needs to stay around until all inference
 // runners are destroyed.
-Status NewInferenceEnvironment(
+absl::Status NewInferenceEnvironment(
     const InferenceEnvironmentOptions& options,
     std::unique_ptr<InferenceEnvironment>* environment,
     InferenceEnvironmentProperties* properties /* optional */);
@@ -21,8 +21,10 @@ namespace tflite {
 namespace gpu {
 namespace cl {
 namespace {
-Status CreateBuffer(size_t size_in_bytes, bool gpu_read_only, const void* data,
-                    CLContext* context, Buffer* result) {
+absl::Status CreateBuffer(size_t size_in_bytes, bool gpu_read_only,
+                          const void* data, CLContext* context,
+                          Buffer* result) {
   cl_mem_flags flags = gpu_read_only ? CL_MEM_READ_ONLY : CL_MEM_READ_WRITE;
   if (data != nullptr) {
     flags |= CL_MEM_COPY_HOST_PTR;
@@ -31,14 +33,14 @@ Status CreateBuffer(size_t size_in_bytes, bool gpu_read_only, const void* data,
   cl_mem buffer = clCreateBuffer(context->context(), flags, size_in_bytes,
                                  const_cast<void*>(data), &error_code);
   if (!buffer) {
-    return UnknownError(
+    return absl::UnknownError(
         absl::StrCat("Failed to allocate device memory with clCreateBuffer",
                      CLErrorCodeToString(error_code)));
   }
 
   *result = Buffer(buffer, size_in_bytes);
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 }  // namespace
@@ -69,18 +71,18 @@ void Buffer::Release() {
   }
 }
 
-Status CreateReadOnlyBuffer(size_t size_in_bytes, CLContext* context,
-                            Buffer* result) {
+absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, CLContext* context,
+                                  Buffer* result) {
   return CreateBuffer(size_in_bytes, true, nullptr, context, result);
 }
 
-Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data,
-                            CLContext* context, Buffer* result) {
+absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data,
+                                  CLContext* context, Buffer* result) {
   return CreateBuffer(size_in_bytes, true, data, context, result);
 }
 
-Status CreateReadWriteBuffer(size_t size_in_bytes, CLContext* context,
-                             Buffer* result) {
+absl::Status CreateReadWriteBuffer(size_t size_in_bytes, CLContext* context,
+                                   Buffer* result) {
   return CreateBuffer(size_in_bytes, false, nullptr, context, result);
 }
 
@@ -51,11 +51,11 @@ class Buffer {
   // Writes data to a buffer. Data should point to a region that
   // has exact size in bytes as size_in_bytes(constructor parameter).
   template <typename T>
-  Status WriteData(CLCommandQueue* queue, const absl::Span<T> data);
+  absl::Status WriteData(CLCommandQueue* queue, const absl::Span<T> data);
 
   // Reads data from Buffer into CPU memory.
   template <typename T>
-  Status ReadData(CLCommandQueue* queue, std::vector<T>* result) const;
+  absl::Status ReadData(CLCommandQueue* queue, std::vector<T>* result) const;
 
  private:
   void Release();
@@ -64,29 +64,31 @@ class Buffer {
   size_t size_;
 };
 
-Status CreateReadOnlyBuffer(size_t size_in_bytes, CLContext* context,
-                            Buffer* result);
+absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, CLContext* context,
+                                  Buffer* result);
 
-Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data,
-                            CLContext* context, Buffer* result);
+absl::Status CreateReadOnlyBuffer(size_t size_in_bytes, const void* data,
+                                  CLContext* context, Buffer* result);
 
-Status CreateReadWriteBuffer(size_t size_in_bytes, CLContext* context,
-                             Buffer* result);
+absl::Status CreateReadWriteBuffer(size_t size_in_bytes, CLContext* context,
+                                   Buffer* result);
 
 template <typename T>
-Status Buffer::WriteData(CLCommandQueue* queue, const absl::Span<T> data) {
+absl::Status Buffer::WriteData(CLCommandQueue* queue,
+                               const absl::Span<T> data) {
   if (size_ != sizeof(T) * data.size()) {
-    return InvalidArgumentError(
+    return absl::InvalidArgumentError(
         "absl::Span<T> data size is different from buffer allocated size.");
   }
   RETURN_IF_ERROR(queue->EnqueueWriteBuffer(buffer_, size_, data.data()));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 template <typename T>
-Status Buffer::ReadData(CLCommandQueue* queue, std::vector<T>* result) const {
+absl::Status Buffer::ReadData(CLCommandQueue* queue,
+                              std::vector<T>* result) const {
   if (size_ % sizeof(T) != 0) {
-    return UnknownError("Wrong element size(typename T is not correct?");
+    return absl::UnknownError("Wrong element size(typename T is not correct?");
   }
 
   const int elements_count = size_ / sizeof(T);
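
For callers of the templated WriteData/ReadData above, only the spelling of the return type changes. A hypothetical call site (the buffer, queue, and error handling are assumptions for illustration, not part of the patch):

    std::vector<float> host_data;
    absl::Status s = buffer.ReadData(queue, &host_data);  // blocking read
    if (!s.ok()) {
      std::cerr << "ReadData failed: " << s.message() << std::endl;
    }
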
@@ -56,8 +56,9 @@ void CLCommandQueue::Release() {
   }
 }
 
-Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid,
-                                        int3 work_group_size, CLEvent* event) {
+absl::Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid,
+                                              int3 work_group_size,
+                                              CLEvent* event) {
   std::vector<size_t> local(3);
   std::vector<size_t> global(3);
   for (int i = 0; i < 3; ++i) {
@@ -72,30 +73,31 @@ Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid,
     *event = CLEvent(resulting_event);
   }
   if (error_code != CL_SUCCESS) {
-    return UnknownError(absl::StrCat("Failed to clEnqueueNDRangeKernel - ",
-                                     CLErrorCodeToString(error_code)));
+    return absl::UnknownError(
+        absl::StrCat("Failed to clEnqueueNDRangeKernel - ",
+                     CLErrorCodeToString(error_code)));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid,
-                                        int3 work_group_size) {
+absl::Status CLCommandQueue::DispatchImplicit(const CLKernel& kernel, int3 grid,
+                                              int3 work_group_size) {
   return DispatchImplicit(kernel, grid, work_group_size, nullptr);
 }
 
-Status CLCommandQueue::EnqueueEvent(CLEvent* event) {
+absl::Status CLCommandQueue::EnqueueEvent(CLEvent* event) {
   cl_event resulting_event;
   const int error_code = clEnqueueMarker(queue_, &resulting_event);
   *event = CLEvent(resulting_event);
   if (error_code != CL_SUCCESS) {
-    return UnknownError(absl::StrCat("Failed to clEnqueueMarker - ",
-                                     CLErrorCodeToString(error_code)));
+    return absl::UnknownError(absl::StrCat("Failed to clEnqueueMarker - ",
+                                           CLErrorCodeToString(error_code)));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CLCommandQueue::EnqueueWriteImage(cl_mem memory, int3 region,
-                                         const void* data) {
+absl::Status CLCommandQueue::EnqueueWriteImage(cl_mem memory, int3 region,
+                                               const void* data) {
   const size_t origin[] = {0, 0, 0};
   const size_t r[] = {static_cast<size_t>(region.x),
                       static_cast<size_t>(region.y),
@@ -103,16 +105,16 @@ Status CLCommandQueue::EnqueueWriteImage(cl_mem memory, int3 region,
   auto error_code = clEnqueueWriteImage(queue_, memory, CL_TRUE, origin, r, 0,
                                         0, data, 0, nullptr, nullptr);
   if (error_code != CL_SUCCESS) {
-    return UnknownError(
+    return absl::UnknownError(
         absl::StrCat("Failed to upload data to GPU (clEnqueueWriteImage) - ",
                      CLErrorCodeToString(error_code)));
   }
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CLCommandQueue::EnqueueReadImage(cl_mem memory, int3 region,
-                                        void* data) {
+absl::Status CLCommandQueue::EnqueueReadImage(cl_mem memory, int3 region,
+                                              void* data) {
   const size_t origin[] = {0, 0, 0};
   const size_t r[] = {static_cast<size_t>(region.x),
                       static_cast<size_t>(region.y),
@@ -120,45 +122,47 @@ Status CLCommandQueue::EnqueueReadImage(cl_mem memory, int3 region,
   auto error_code = clEnqueueReadImage(queue_, memory, CL_TRUE, origin, r, 0, 0,
                                        data, 0, nullptr, nullptr);
   if (error_code != CL_SUCCESS) {
-    return UnknownError(
+    return absl::UnknownError(
         absl::StrCat("Failed to read data from GPU (clEnqueueReadImage) - ",
                      CLErrorCodeToString(error_code)));
   }
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CLCommandQueue::EnqueueWriteBuffer(cl_mem memory, size_t size_in_bytes,
-                                          const void* data) {
+absl::Status CLCommandQueue::EnqueueWriteBuffer(cl_mem memory,
+                                                size_t size_in_bytes,
+                                                const void* data) {
   auto error_code = clEnqueueWriteBuffer(
       queue_, memory, CL_TRUE, 0, size_in_bytes, data, 0, nullptr, nullptr);
   if (error_code != CL_SUCCESS) {
-    return UnknownError(
+    return absl::UnknownError(
         absl::StrCat("Failed to upload data to GPU (clEnqueueWriteBuffer) - ",
                      CLErrorCodeToString(error_code)));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CLCommandQueue::EnqueueReadBuffer(cl_mem memory, size_t size_in_bytes,
-                                         void* data) {
+absl::Status CLCommandQueue::EnqueueReadBuffer(cl_mem memory,
+                                               size_t size_in_bytes,
+                                               void* data) {
   auto error_code = clEnqueueReadBuffer(
       queue_, memory, CL_TRUE, 0, size_in_bytes, data, 0, nullptr, nullptr);
   if (error_code != CL_SUCCESS) {
-    return UnknownError(
+    return absl::UnknownError(
        absl::StrCat("Failed to read data from GPU (clEnqueueReadBuffer) - ",
                      CLErrorCodeToString(error_code)));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CLCommandQueue::WaitForCompletion() {
+absl::Status CLCommandQueue::WaitForCompletion() {
   auto error_code = clFinish(queue_);
   if (error_code != CL_SUCCESS) {
-    return UnknownError(
+    return absl::UnknownError(
         absl::StrCat("Failed to clFinish - ", CLErrorCodeToString(error_code)));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 ProfilingCommandQueue::ProfilingCommandQueue(cl_command_queue queue)
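
Every OpenCL call in this file follows the same wrap-the-error-code shape. Factored out and stripped of class context, the pattern is roughly the following sketch (CLErrorCodeToString is the project's helper; a raw integer code is printed here to stay self-contained):

    #include "absl/status/status.h"
    #include "absl/strings/str_cat.h"

    // Sketch of the recurring pattern: translate a cl_int error code into an
    // absl::Status carrying a human-readable message.
    absl::Status StatusFromClError(const char* op, int error_code) {
      if (error_code == 0 /* CL_SUCCESS */) return absl::OkStatus();
      return absl::UnknownError(
          absl::StrCat("Failed to ", op, " - error code ", error_code));
    }
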
@@ -187,14 +191,14 @@ void ProfilingCommandQueue::SetEventsLabel(const std::string& name) {
 
 void ProfilingCommandQueue::ResetMeasurements() { events_.clear(); }
 
-Status ProfilingCommandQueue::DispatchImplicit(const CLKernel& kernel,
-                                               int3 grid,
-                                               int3 work_group_size) {
+absl::Status ProfilingCommandQueue::DispatchImplicit(const CLKernel& kernel,
+                                                     int3 grid,
+                                                     int3 work_group_size) {
   events_.push_back(CLEvent());
   RETURN_IF_ERROR(CLCommandQueue::DispatchImplicit(
       kernel, grid, work_group_size, &events_[events_.size() - 1]));
   events_.back().SetName(current_label_);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 ProfilingInfo ProfilingCommandQueue::GetProfilingInfo() const {
@@ -208,7 +212,7 @@ ProfilingInfo ProfilingCommandQueue::GetProfilingInfo() const {
   return result;
 }
 
-Status ProfilingCommandQueue::GetBestWorkGroupIndex(
+absl::Status ProfilingCommandQueue::GetBestWorkGroupIndex(
     const CLKernel& kernel, const DeviceInfo& device_info, const int3& grid,
     const std::vector<int3>& work_group_sizes, int* index) {
   // Some Adreno 3xx can have wrong numbers for some events
@@ -268,20 +272,22 @@ Status ProfilingCommandQueue::GetBestWorkGroupIndex(
 
   *index = minimum_index;
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CreateCLCommandQueue(const CLDevice& device, const CLContext& context,
-                            CLCommandQueue* result) {
+absl::Status CreateCLCommandQueue(const CLDevice& device,
+                                  const CLContext& context,
+                                  CLCommandQueue* result) {
   int error_code;
   cl_command_queue queue =
       clCreateCommandQueue(context.context(), device.id(), 0, &error_code);
   if (!queue) {
-    return UnknownError(absl::StrCat("Failed to create a command queue - ",
-                                     CLErrorCodeToString(error_code)));
+    return absl::UnknownError(
+        absl::StrCat("Failed to create a command queue - ",
+                     CLErrorCodeToString(error_code)));
   }
   *result = CLCommandQueue(queue, true);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 double ProfilingCommandQueue::GetQueueExecutionTimeMs() const {
@@ -300,19 +306,20 @@ double ProfilingCommandQueue::GetSumOfEventsTimeMs() const {
   return sum;
 }
 
-Status CreateProfilingCommandQueue(const CLDevice& device,
-                                   const CLContext& context,
-                                   ProfilingCommandQueue* result) {
+absl::Status CreateProfilingCommandQueue(const CLDevice& device,
+                                         const CLContext& context,
+                                         ProfilingCommandQueue* result) {
   int error_code;
   cl_command_queue queue = clCreateCommandQueue(
       context.context(), device.id(), CL_QUEUE_PROFILING_ENABLE, &error_code);
   if (!queue) {
-    return UnknownError(absl::StrCat("Failed to create a command queue - ",
-                                     CLErrorCodeToString(error_code)));
+    return absl::UnknownError(
+        absl::StrCat("Failed to create a command queue - ",
+                     CLErrorCodeToString(error_code)));
  }
 
   *result = ProfilingCommandQueue(queue);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 absl::Duration ProfilingInfo::GetTotalTime() const {
|
@ -74,22 +74,23 @@ class CLCommandQueue {
|
|||||||
|
|
||||||
cl_command_queue queue() const { return queue_; }
|
cl_command_queue queue() const { return queue_; }
|
||||||
|
|
||||||
virtual Status DispatchImplicit(const CLKernel& kernel, int3 grid,
|
virtual absl::Status DispatchImplicit(const CLKernel& kernel, int3 grid,
|
||||||
int3 work_group_size);
|
int3 work_group_size);
|
||||||
|
|
||||||
Status EnqueueEvent(CLEvent* event);
|
absl::Status EnqueueEvent(CLEvent* event);
|
||||||
|
|
||||||
Status DispatchImplicit(const CLKernel& kernel, int3 grid,
|
absl::Status DispatchImplicit(const CLKernel& kernel, int3 grid,
|
||||||
int3 work_group_size, CLEvent* event);
|
int3 work_group_size, CLEvent* event);
|
||||||
|
|
||||||
Status EnqueueWriteImage(cl_mem memory, int3 region, const void* data);
|
absl::Status EnqueueWriteImage(cl_mem memory, int3 region, const void* data);
|
||||||
Status EnqueueReadImage(cl_mem memory, int3 region, void* data);
|
absl::Status EnqueueReadImage(cl_mem memory, int3 region, void* data);
|
||||||
|
|
||||||
Status EnqueueWriteBuffer(cl_mem memory, size_t size_in_bytes,
|
absl::Status EnqueueWriteBuffer(cl_mem memory, size_t size_in_bytes,
|
||||||
const void* data);
|
const void* data);
|
||||||
Status EnqueueReadBuffer(cl_mem memory, size_t size_in_bytes, void* data);
|
absl::Status EnqueueReadBuffer(cl_mem memory, size_t size_in_bytes,
|
||||||
|
void* data);
|
||||||
|
|
||||||
Status WaitForCompletion();
|
absl::Status WaitForCompletion();
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void Release();
|
void Release();
|
||||||
@@ -109,14 +110,15 @@ class ProfilingCommandQueue : public CLCommandQueue {
   ProfilingCommandQueue(const ProfilingCommandQueue&) = delete;
   ProfilingCommandQueue& operator=(const ProfilingCommandQueue&) = delete;
 
-  Status DispatchImplicit(const CLKernel& kernel, int3 grid,
-                          int3 work_group_size) override;
+  absl::Status DispatchImplicit(const CLKernel& kernel, int3 grid,
+                                int3 work_group_size) override;
 
   // will write index for fastest work_group among work_group_sizes
-  Status GetBestWorkGroupIndex(const CLKernel& kernel,
-                               const DeviceInfo& device_info, const int3& grid,
-                               const std::vector<int3>& work_group_sizes,
-                               int* index);
+  absl::Status GetBestWorkGroupIndex(const CLKernel& kernel,
+                                     const DeviceInfo& device_info,
+                                     const int3& grid,
+                                     const std::vector<int3>& work_group_sizes,
+                                     int* index);
 
   // call ResetMeasurements() to start new seriese of measurements
   void ResetMeasurements();
@ -139,12 +141,13 @@ class ProfilingCommandQueue : public CLCommandQueue {
|
|||||||
std::string current_label_;
|
std::string current_label_;
|
||||||
};
|
};
|
||||||
|
|
||||||
Status CreateCLCommandQueue(const CLDevice& device, const CLContext& context,
|
absl::Status CreateCLCommandQueue(const CLDevice& device,
|
||||||
CLCommandQueue* result);
|
const CLContext& context,
|
||||||
|
CLCommandQueue* result);
|
||||||
|
|
||||||
Status CreateProfilingCommandQueue(const CLDevice& device,
|
absl::Status CreateProfilingCommandQueue(const CLDevice& device,
|
||||||
const CLContext& context,
|
const CLContext& context,
|
||||||
ProfilingCommandQueue* result);
|
ProfilingCommandQueue* result);
|
||||||
|
|
||||||
} // namespace cl
|
} // namespace cl
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
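
Since absl::Status keeps the same ok()/error contract as the old type, RETURN_IF_ERROR continues to compose these calls unchanged. A sketch of a caller under the new signatures (the grid and work-group values below are placeholders, not code from this commit):

    absl::Status RunKernelOnce(const CLDevice& device, const CLContext& context,
                               const CLKernel& kernel) {
      CLCommandQueue queue;
      RETURN_IF_ERROR(CreateCLCommandQueue(device, context, &queue));
      // Implicit dispatch: the grid is covered by whole work groups.
      RETURN_IF_ERROR(
          queue.DispatchImplicit(kernel, int3(256, 256, 1), int3(8, 8, 1)));
      return queue.WaitForCompletion();
    }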
@@ -43,19 +43,21 @@ std::vector<cl_image_format> GetSupportedImage2DFormats(cl_context context,
   return result;
 }
 
-Status CreateCLContext(const CLDevice& device,
-                       cl_context_properties* properties, CLContext* result) {
+absl::Status CreateCLContext(const CLDevice& device,
+                             cl_context_properties* properties,
+                             CLContext* result) {
   int error_code;
   cl_device_id device_id = device.id();
   cl_context context =
       clCreateContext(properties, 1, &device_id, nullptr, nullptr, &error_code);
   if (!context) {
-    return UnknownError(absl::StrCat("Failed to create a compute context - ",
-                                     CLErrorCodeToString(error_code)));
+    return absl::UnknownError(
+        absl::StrCat("Failed to create a compute context - ",
+                     CLErrorCodeToString(error_code)));
   }
 
   *result = CLContext(context, true);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace
@@ -99,15 +101,16 @@ bool CLContext::IsFloatTexture2DSupported(int num_channels, DataType data_type,
   return false;
 }
 
-Status CreateCLContext(const CLDevice& device, CLContext* result) {
+absl::Status CreateCLContext(const CLDevice& device, CLContext* result) {
   return CreateCLContext(device, nullptr, result);
 }
 
-Status CreateCLGLContext(const CLDevice& device,
-                         cl_context_properties egl_context,
-                         cl_context_properties egl_display, CLContext* result) {
+absl::Status CreateCLGLContext(const CLDevice& device,
+                               cl_context_properties egl_context,
+                               cl_context_properties egl_display,
+                               CLContext* result) {
   if (!device.SupportsExtension("cl_khr_gl_sharing")) {
-    return UnavailableError("Device doesn't support CL-GL sharing.");
+    return absl::UnavailableError("Device doesn't support CL-GL sharing.");
   }
   cl_context_properties platform =
       reinterpret_cast<cl_context_properties>(device.platform());
@@ -51,10 +51,11 @@ class CLContext {
   bool has_ownership_ = false;
 };
 
-Status CreateCLContext(const CLDevice& device, CLContext* result);
-Status CreateCLGLContext(const CLDevice& device,
-                         cl_context_properties egl_context,
-                         cl_context_properties egl_display, CLContext* result);
+absl::Status CreateCLContext(const CLDevice& device, CLContext* result);
+absl::Status CreateCLGLContext(const CLDevice& device,
+                               cl_context_properties egl_context,
+                               cl_context_properties egl_display,
+                               CLContext* result);
 
 }  // namespace cl
 }  // namespace gpu
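
For reference, CreateCLGLContext is typically invoked with the caller's current EGL handles cast to cl_context_properties, roughly like this (a sketch; obtaining the handles via eglGetCurrent* is an assumption, not shown in this commit):

    CLContext context;
    RETURN_IF_ERROR(CreateCLGLContext(
        device,
        reinterpret_cast<cl_context_properties>(eglGetCurrentContext()),
        reinterpret_cast<cl_context_properties>(eglGetCurrentDisplay()),
        &context));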
@@ -516,11 +516,11 @@ void CLDevice::DisableOneLayerTextureArray() {
   info_.adreno_info.support_one_layer_texture_array = false;
 }
 
-Status CreateDefaultGPUDevice(CLDevice* result) {
+absl::Status CreateDefaultGPUDevice(CLDevice* result) {
   cl_uint num_platforms;
   clGetPlatformIDs(0, nullptr, &num_platforms);
   if (num_platforms == 0) {
-    return UnknownError("No supported OpenCL platform.");
+    return absl::UnknownError("No supported OpenCL platform.");
   }
   std::vector<cl_platform_id> platforms(num_platforms);
   clGetPlatformIDs(num_platforms, platforms.data(), nullptr);
@@ -529,7 +529,7 @@ Status CreateDefaultGPUDevice(CLDevice* result) {
   cl_uint num_devices;
   clGetDeviceIDs(platform_id, CL_DEVICE_TYPE_GPU, 0, nullptr, &num_devices);
   if (num_devices == 0) {
-    return UnknownError("No GPU on current platform.");
+    return absl::UnknownError("No GPU on current platform.");
   }
 
   std::vector<cl_device_id> devices(num_devices);
@@ -537,7 +537,7 @@ Status CreateDefaultGPUDevice(CLDevice* result) {
                  nullptr);
 
   *result = CLDevice(devices[0], platform_id);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace cl
@@ -191,7 +191,7 @@ class CLDevice {
   DeviceInfo info_;
 };
 
-Status CreateDefaultGPUDevice(CLDevice* result);
+absl::Status CreateDefaultGPUDevice(CLDevice* result);
 
 template <typename T>
 T GetDeviceInfo(cl_device_id id, cl_device_info info) {
@@ -204,12 +204,12 @@ T GetDeviceInfo(cl_device_id id, cl_device_info info) {
 }
 
 template <typename T>
-Status GetDeviceInfo(cl_device_id id, cl_device_info info, T* result) {
+absl::Status GetDeviceInfo(cl_device_id id, cl_device_info info, T* result) {
   cl_int error = clGetDeviceInfo(id, info, sizeof(T), result, nullptr);
   if (error != CL_SUCCESS) {
-    return InvalidArgumentError(CLErrorCodeToString(error));
+    return absl::InvalidArgumentError(CLErrorCodeToString(error));
  }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace cl
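
The status-returning GetDeviceInfo overload deduces the payload size from T, so typed queries compose directly with RETURN_IF_ERROR. A sketch (the query constant is just an example):

    // CL_DEVICE_MAX_COMPUTE_UNITS is defined by OpenCL to return cl_uint.
    cl_uint compute_units;
    RETURN_IF_ERROR(GetDeviceInfo(id, CL_DEVICE_MAX_COMPUTE_UNITS,
                                  &compute_units));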
@@ -27,11 +27,12 @@ namespace cl {
 
 // @return if error_code is success, then return OK status. Otherwise translates
 // error code into a message.
-inline Status GetOpenCLError(cl_int error_code) {
+inline absl::Status GetOpenCLError(cl_int error_code) {
   if (error_code == CL_SUCCESS) {
-    return OkStatus();
+    return absl::OkStatus();
   }
-  return InternalError("OpenCL error: " + CLErrorCodeToString(error_code));
+  return absl::InternalError("OpenCL error: " +
+                             CLErrorCodeToString(error_code));
 }
 
 }  // namespace cl
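
GetOpenCLError is the generic bridge for raw CL return codes that have no dedicated wrapper. A sketch of its intended use ('queue' is a CLCommandQueue placeholder):

    // Propagate a raw clFinish() result as a status.
    RETURN_IF_ERROR(GetOpenCLError(clFinish(queue.queue())));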
@@ -25,34 +25,34 @@ namespace gpu {
 namespace cl {
 namespace {
 
-Status GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id,
-                                 int* result) {
+absl::Status GetKernelMaxWorkGroupSize(cl_kernel kernel, cl_device_id device_id,
+                                       int* result) {
   size_t max_work_group_size;
   cl_int error_code =
       clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_WORK_GROUP_SIZE,
                                sizeof(size_t), &max_work_group_size, nullptr);
   if (error_code != CL_SUCCESS) {
-    return UnknownError(
+    return absl::UnknownError(
         absl::StrCat("Failed to get info CL_KERNEL_WORK_GROUP_SIZE ",
                      CLErrorCodeToString(error_code)));
   }
   *result = static_cast<int>(max_work_group_size);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status GetKernelPrivateMemorySize(cl_kernel kernel, cl_device_id device_id,
-                                  int* result) {
+absl::Status GetKernelPrivateMemorySize(cl_kernel kernel,
+                                        cl_device_id device_id, int* result) {
   cl_ulong private_mem_size;
   cl_int error_code =
       clGetKernelWorkGroupInfo(kernel, device_id, CL_KERNEL_PRIVATE_MEM_SIZE,
                                sizeof(cl_ulong), &private_mem_size, nullptr);
   if (error_code != CL_SUCCESS) {
-    return UnknownError(
+    return absl::UnknownError(
        absl::StrCat("Failed to get info CL_KERNEL_PRIVATE_MEM_SIZE ",
                      CLErrorCodeToString(error_code)));
   }
   *result = static_cast<int>(private_mem_size);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace
@@ -82,17 +82,17 @@ CLKernel& CLKernel::operator=(CLKernel&& kernel) {
 
 CLKernel::~CLKernel() { Release(); }
 
-Status CLKernel::ReInit() const {
+absl::Status CLKernel::ReInit() const {
   clReleaseKernel(kernel_);
   cl_kernel* kern_ptr = const_cast<cl_kernel*>(&kernel_);
   int error_code;
   *kern_ptr = clCreateKernel(program_, function_name_.c_str(), &error_code);
   if (!kernel_ || error_code != CL_SUCCESS) {
     *kern_ptr = nullptr;
-    return UnknownError(absl::StrCat("Failed to create ", function_name_,
-                                     CLErrorCodeToString(error_code)));
+    return absl::UnknownError(absl::StrCat("Failed to create ", function_name_,
+                                           CLErrorCodeToString(error_code)));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 void CLKernel::Release() {
@@ -103,16 +103,16 @@ void CLKernel::Release() {
   }
 }
 
-Status CLKernel::CreateFromProgram(const CLProgram& program,
-                                   const std::string& function_name) {
+absl::Status CLKernel::CreateFromProgram(const CLProgram& program,
+                                         const std::string& function_name) {
   int error_code;
   function_name_ = function_name;
   kernel_ =
       clCreateKernel(program.program(), function_name.c_str(), &error_code);
   if (!kernel_ || error_code != CL_SUCCESS) {
     kernel_ = nullptr;
-    return UnknownError(absl::StrCat("Failed to create ", function_name,
-                                     CLErrorCodeToString(error_code)));
+    return absl::UnknownError(absl::StrCat("Failed to create ", function_name,
+                                           CLErrorCodeToString(error_code)));
   }
 
   program_ = program.program();
@@ -122,64 +122,64 @@ Status CLKernel::CreateFromProgram(const CLProgram& program,
                                             &private_memory_size_));
   RETURN_IF_ERROR(GetKernelMaxWorkGroupSize(kernel_, program.GetDeviceId(),
                                             &max_work_group_size_));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CLKernel::SetMemory(int index, cl_mem memory) {
+absl::Status CLKernel::SetMemory(int index, cl_mem memory) {
   return SetBytes(index, &memory, sizeof(cl_mem));
 }
 
-Status CLKernel::SetMemoryAuto(cl_mem memory) {
+absl::Status CLKernel::SetMemoryAuto(cl_mem memory) {
   return SetBytesAuto(&memory, sizeof(cl_mem));
 }
 
-Status CLKernel::SetBytes(int index, const void* ptr, int length) const {
+absl::Status CLKernel::SetBytes(int index, const void* ptr, int length) const {
   const int error_code = clSetKernelArg(kernel_, index, length, ptr);
   if (error_code != CL_SUCCESS) {
-    return UnknownError(absl::StrCat("Failed to set kernel arguments - ",
-                                     CLErrorCodeToString(error_code)));
+    return absl::UnknownError(absl::StrCat("Failed to set kernel arguments - ",
+                                           CLErrorCodeToString(error_code)));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CLKernel::SetBytesAuto(const void* ptr, int length) {
+absl::Status CLKernel::SetBytesAuto(const void* ptr, int length) {
   const int error_code = clSetKernelArg(kernel_, binding_counter_, length, ptr);
   if (error_code != CL_SUCCESS) {
-    return UnknownError(absl::StrCat("Failed to set kernel arguments - ",
-                                     CLErrorCodeToString(error_code),
-                                     "(at index - ", binding_counter_, ")"));
+    return absl::UnknownError(absl::StrCat(
+        "Failed to set kernel arguments - ", CLErrorCodeToString(error_code),
+        "(at index - ", binding_counter_, ")"));
   }
   binding_counter_++;
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 template <>
-Status CLKernel::SetBytes<FLT>(int index, const FLT& value) const {
+absl::Status CLKernel::SetBytes<FLT>(int index, const FLT& value) const {
   return SetBytes(index, value.GetData(), value.GetSize());
 }
 
 template <>
-Status CLKernel::SetBytes<FLT2>(int index, const FLT2& value) const {
+absl::Status CLKernel::SetBytes<FLT2>(int index, const FLT2& value) const {
   return SetBytes(index, value.GetData(), value.GetSize());
 }
 
 template <>
-Status CLKernel::SetBytes<FLT4>(int index, const FLT4& value) const {
+absl::Status CLKernel::SetBytes<FLT4>(int index, const FLT4& value) const {
   return SetBytes(index, value.GetData(), value.GetSize());
 }
 
 template <>
-Status CLKernel::SetBytesAuto<FLT>(const FLT& value) {
+absl::Status CLKernel::SetBytesAuto<FLT>(const FLT& value) {
   return SetBytesAuto(value.GetData(), value.GetSize());
 }
 
 template <>
-Status CLKernel::SetBytesAuto<FLT2>(const FLT2& value) {
+absl::Status CLKernel::SetBytesAuto<FLT2>(const FLT2& value) {
   return SetBytesAuto(value.GetData(), value.GetSize());
 }
 
 template <>
-Status CLKernel::SetBytesAuto<FLT4>(const FLT4& value) {
+absl::Status CLKernel::SetBytesAuto<FLT4>(const FLT4& value) {
   return SetBytesAuto(value.GetData(), value.GetSize());
 }
 
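
SetBytesAuto and SetMemoryAuto bind arguments at consecutive indices tracked by binding_counter_, so calls must follow the kernel's argument order. A sketch (the buffer and constant names are placeholders):

    // Binds args 0, 1, 2 of the kernel, in declaration order.
    RETURN_IF_ERROR(kernel.SetMemoryAuto(src_buffer));
    RETURN_IF_ERROR(kernel.SetMemoryAuto(dst_buffer));
    RETURN_IF_ERROR(kernel.SetBytesAuto(multiplier));  // e.g. an FLT4 constant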
@@ -48,17 +48,17 @@ class CLKernel {
 
   cl_kernel kernel() const { return kernel_; }
 
-  Status CreateFromProgram(const CLProgram& program,
-                           const std::string& function_name);
+  absl::Status CreateFromProgram(const CLProgram& program,
+                                 const std::string& function_name);
 
-  Status SetMemory(int index, cl_mem memory);
-  Status SetMemoryAuto(cl_mem memory);
+  absl::Status SetMemory(int index, cl_mem memory);
+  absl::Status SetMemoryAuto(cl_mem memory);
   template <typename T>
-  Status SetBytes(int index, const T& value) const {
+  absl::Status SetBytes(int index, const T& value) const {
     return SetBytes(index, static_cast<const void*>(&value), sizeof(T));
   }
   template <typename T>
-  Status SetBytesAuto(const T& value) {
+  absl::Status SetBytesAuto(const T& value) {
     return SetBytesAuto(static_cast<const void*>(&value), sizeof(T));
   }
 
@@ -69,12 +69,12 @@ class CLKernel {
 
   // Do not use this function
   // workaround for Mali memory leak
-  Status ReInit() const;
+  absl::Status ReInit() const;
 
  private:
   void Release();
-  Status SetBytes(int index, const void* ptr, int length) const;
-  Status SetBytesAuto(const void* ptr, int length);
+  absl::Status SetBytes(int index, const void* ptr, int length) const;
+  absl::Status SetBytesAuto(const void* ptr, int length);
 
   int private_memory_size_;
   int max_work_group_size_;
@@ -87,22 +87,22 @@ class CLKernel {
 };
 
 template <>
-Status CLKernel::SetBytes<FLT>(int index, const FLT& value) const;
+absl::Status CLKernel::SetBytes<FLT>(int index, const FLT& value) const;
 
 template <>
-Status CLKernel::SetBytes<FLT2>(int index, const FLT2& value) const;
+absl::Status CLKernel::SetBytes<FLT2>(int index, const FLT2& value) const;
 
 template <>
-Status CLKernel::SetBytes<FLT4>(int index, const FLT4& value) const;
+absl::Status CLKernel::SetBytes<FLT4>(int index, const FLT4& value) const;
 
 template <>
-Status CLKernel::SetBytesAuto<FLT>(const FLT& value);
+absl::Status CLKernel::SetBytesAuto<FLT>(const FLT& value);
 
 template <>
-Status CLKernel::SetBytesAuto<FLT2>(const FLT2& value);
+absl::Status CLKernel::SetBytesAuto<FLT2>(const FLT2& value);
 
 template <>
-Status CLKernel::SetBytesAuto<FLT4>(const FLT4& value);
+absl::Status CLKernel::SetBytesAuto<FLT4>(const FLT4& value);
 
 }  // namespace cl
 }  // namespace gpu
@@ -49,28 +49,29 @@ std::string GetProgramBuildInfo(cl_program program, cl_device_id id,
   return result;
 }
 
-Status GetBinarySize(cl_program program, size_t* binary_size) {
+absl::Status GetBinarySize(cl_program program, size_t* binary_size) {
   cl_int error_code = clGetProgramInfo(program, CL_PROGRAM_BINARY_SIZES,
                                        sizeof(size_t), binary_size, nullptr);
   if (error_code != CL_SUCCESS) {
-    return UnknownError(absl::StrCat("Failed to get program binary size - ",
-                                     CLErrorCodeToString(error_code)));
+    return absl::UnknownError(
+        absl::StrCat("Failed to get program binary size - ",
+                     CLErrorCodeToString(error_code)));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status BuildProgram(cl_program program, const CLDevice& device,
-                    const std::string& compiler_options) {
+absl::Status BuildProgram(cl_program program, const CLDevice& device,
+                          const std::string& compiler_options) {
   const int error_code = clBuildProgram(
       program, 0, nullptr, compiler_options.c_str(), nullptr, nullptr);
   if (error_code != CL_SUCCESS) {
-    return UnknownError(absl::StrCat(
+    return absl::UnknownError(absl::StrCat(
         "Failed to build program executable - ",
         CLErrorCodeToString(error_code),
         GetProgramBuildInfo(program, device.id(), CL_PROGRAM_BUILD_LOG)));
   }
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 std::string CompilerOptionToString(const CLDevice& device,
@@ -133,7 +134,7 @@ void CLProgram::Release() {
   }
 }
 
-Status CLProgram::GetBinary(std::vector<uint8_t>* result) const {
+absl::Status CLProgram::GetBinary(std::vector<uint8_t>* result) const {
   size_t binary_size;
   RETURN_IF_ERROR(GetBinarySize(program_, &binary_size));
   result->resize(result->size() + binary_size);
@@ -141,35 +142,36 @@ Status CLProgram::GetBinary(std::vector<uint8_t>* result) const {
   cl_int error_code = clGetProgramInfo(program_, CL_PROGRAM_BINARIES,
                                        binary_size, &binary_ptr, nullptr);
   if (error_code != CL_SUCCESS) {
-    return UnknownError(absl::StrCat("Failed to get program binary - ",
-                                     CLErrorCodeToString(error_code)));
+    return absl::UnknownError(absl::StrCat("Failed to get program binary - ",
+                                           CLErrorCodeToString(error_code)));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CreateCLProgram(const std::string& code,
-                       const std::string& compiler_options,
-                       const CLContext& context, const CLDevice& device,
-                       CLProgram* result) {
+absl::Status CreateCLProgram(const std::string& code,
+                             const std::string& compiler_options,
+                             const CLContext& context, const CLDevice& device,
+                             CLProgram* result) {
   int error_code;
   const char* source = code.c_str();
 
   cl_program program = clCreateProgramWithSource(context.context(), 1, &source,
                                                  nullptr, &error_code);
   if (!program || error_code != CL_SUCCESS) {
-    return UnknownError(absl::StrCat("Failed to create compute program - ",
-                                     CLErrorCodeToString(error_code)));
+    return absl::UnknownError(
+        absl::StrCat("Failed to create compute program - ",
+                     CLErrorCodeToString(error_code)));
   }
 
   *result = CLProgram(program, device.id());
   RETURN_IF_ERROR(BuildProgram(program, device, compiler_options));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CreateCLProgramFromBinary(const CLContext& context,
-                                 const CLDevice& device,
-                                 absl::Span<const uint8_t> binary,
-                                 CLProgram* result) {
+absl::Status CreateCLProgramFromBinary(const CLContext& context,
+                                       const CLDevice& device,
+                                       absl::Span<const uint8_t> binary,
+                                       CLProgram* result) {
   cl_int binary_status;
   cl_int error_code;
   cl_device_id devices_list[] = {device.id()};
@@ -179,13 +181,13 @@ Status CreateCLProgramFromBinary(const CLContext& context,
       context.context(), 1, devices_list, &binary_size, &binary_pointer,
       &binary_status, &error_code);
   if (binary_status != CL_SUCCESS) {
-    return UnknownError(absl::StrCat(
+    return absl::UnknownError(absl::StrCat(
         "Something wrong with binary after clCreateProgramWithBinary - ",
         binary_status));
   }
   if (error_code != CL_SUCCESS) {
-    return UnknownError(absl::StrCat("Failed to create program - ",
-                                     CLErrorCodeToString(error_code)));
+    return absl::UnknownError(absl::StrCat("Failed to create program - ",
+                                           CLErrorCodeToString(error_code)));
   }
   *result = CLProgram(program, device.id());
   return BuildProgram(program, device, "");
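
Note that GetBinary appends to the output vector rather than overwriting it, which is what lets the serialized-binary-cache path concatenate several programs into one blob. A sketch of the round trip (cache bookkeeping elided; 'program' is a built CLProgram placeholder):

    // Serialize a built program, then rebuild it from the saved binary.
    std::vector<uint8_t> blob;
    RETURN_IF_ERROR(program.GetBinary(&blob));

    CLProgram restored;
    RETURN_IF_ERROR(CreateCLProgramFromBinary(
        context, device, absl::MakeConstSpan(blob), &restored));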
@@ -68,7 +68,7 @@ class CLProgram {
   // was created using clCreateProgramWithBinary.
   cl_device_id GetDeviceId() const { return device_id_; }
 
-  Status GetBinary(std::vector<uint8_t>* result) const;
+  absl::Status GetBinary(std::vector<uint8_t>* result) const;
 
  private:
   void Release();
@@ -79,15 +79,15 @@ class CLProgram {
   cl_device_id device_id_ = nullptr;
 };
 
-Status CreateCLProgram(const std::string& code,
-                       const std::string& compiler_options,
-                       const CLContext& context, const CLDevice& device,
-                       CLProgram* result);
+absl::Status CreateCLProgram(const std::string& code,
+                             const std::string& compiler_options,
+                             const CLContext& context, const CLDevice& device,
+                             CLProgram* result);
 
-Status CreateCLProgramFromBinary(const CLContext& context,
-                                 const CLDevice& device,
-                                 absl::Span<const uint8_t> binary,
-                                 CLProgram* result);
+absl::Status CreateCLProgramFromBinary(const CLContext& context,
+                                       const CLDevice& device,
+                                       absl::Span<const uint8_t> binary,
+                                       CLProgram* result);
 
 }  // namespace cl
 }  // namespace gpu
@@ -21,15 +21,15 @@ namespace tflite {
 namespace gpu {
 namespace cl {
 
-Status EglSync::NewFence(EGLDisplay display, EglSync* sync) {
+absl::Status EglSync::NewFence(EGLDisplay display, EglSync* sync) {
   EGLSyncKHR egl_sync;
   RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(eglCreateSyncKHR, &egl_sync, display,
                                       EGL_SYNC_FENCE_KHR, nullptr));
   if (egl_sync == EGL_NO_SYNC_KHR) {
-    return InternalError("Returned empty KHR EGL sync");
+    return absl::InternalError("Returned empty KHR EGL sync");
   }
   *sync = EglSync(display, egl_sync);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 EglSync& EglSync::operator=(EglSync&& sync) {
@@ -48,22 +48,23 @@ void EglSync::Invalidate() {
   }
 }
 
-Status EglSync::ServerWait() {
+absl::Status EglSync::ServerWait() {
   EGLint result;
   RETURN_IF_ERROR(
       TFLITE_GPU_CALL_EGL(eglWaitSyncKHR, &result, display_, sync_, 0));
-  return result == EGL_TRUE ? OkStatus() : InternalError("eglWaitSync failed");
+  return result == EGL_TRUE ? absl::OkStatus()
+                            : absl::InternalError("eglWaitSync failed");
 }
 
-Status EglSync::ClientWait() {
+absl::Status EglSync::ClientWait() {
   EGLint result;
   // TODO(akulik): make it active wait for better performance
   RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(eglClientWaitSyncKHR, &result, display_,
                                       sync_, EGL_SYNC_FLUSH_COMMANDS_BIT_KHR,
                                       EGL_FOREVER_KHR));
   return result == EGL_CONDITION_SATISFIED_KHR
-             ? OkStatus()
-             : InternalError("eglClientWaitSync failed");
+             ? absl::OkStatus()
+             : absl::InternalError("eglClientWaitSync failed");
 }
 
 }  // namespace cl
@@ -32,7 +32,7 @@ class EglSync {
   // flushed.
   //
   // Depends on EGL_KHR_fence_sync extension.
-  static Status NewFence(EGLDisplay display, EglSync* sync);
+  static absl::Status NewFence(EGLDisplay display, EglSync* sync);
 
   // Creates invalid object.
   EglSync() : EglSync(EGL_NO_DISPLAY, EGL_NO_SYNC_KHR) {}
@@ -50,10 +50,10 @@ class EglSync {
 
   // Causes GPU to block and wait until this sync has been signaled.
   // This call does not block and returns immediately.
-  Status ServerWait();
+  absl::Status ServerWait();
 
   // Causes CPU to block and wait until this sync has been signaled.
-  Status ClientWait();
+  absl::Status ClientWait();
 
   // Returns the EGLDisplay on which this instance was created.
   EGLDisplay display() const { return display_; }
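
The two waits differ only in who blocks: ServerWait queues the wait on the GPU and returns immediately, while ClientWait stalls the calling thread. A sketch of the CPU-side variant ('display' is a placeholder for the caller's EGLDisplay):

    EglSync sync;
    RETURN_IF_ERROR(EglSync::NewFence(display, &sync));  // after GL submission
    RETURN_IF_ERROR(sync.ClientWait());  // blocks until the GPU reaches it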
@@ -26,6 +26,7 @@ namespace tflite {
 namespace gpu {
 namespace cl {
 namespace {
 
 std::string GetKernelOneLayerTextureArray() {
   return R"(
 
@@ -43,12 +44,12 @@ __kernel void main_function(__write_only image2d_array_t dst) {
 // texture, we will get zeroes instead of actual values.
 // The same kernel will work, if we use texture array with more than one layer.
 // With help of this code we can detect this bug.
-Status CheckKernelSupportOfOneLayerTextureArray(Environment* env,
-                                                bool* result) {
+absl::Status CheckKernelSupportOfOneLayerTextureArray(Environment* env,
+                                                      bool* result) {
   // No bug on Adreno 6xx
   if (env->device().GetInfo().adreno_info.gpu_version >= 600) {
     *result = true;
-    return OkStatus();
+    return absl::OkStatus();
   }
   CLKernel kernel;
   RETURN_IF_ERROR(env->program_cache()->GetOrCreateCLKernel(
@@ -75,12 +76,12 @@ Status CheckKernelSupportOfOneLayerTextureArray(Environment* env,
       break;
     }
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CreateEnvironment(Environment* result, bool shared,
-                         cl_context_properties egl_context,
-                         cl_context_properties egl_display) {
+absl::Status CreateEnvironment(Environment* result, bool shared,
+                               cl_context_properties egl_context,
+                               cl_context_properties egl_display) {
   CLDevice gpu;
   RETURN_IF_ERROR(CreateDefaultGPUDevice(&gpu));
 
@@ -107,8 +108,9 @@ Status CreateEnvironment(Environment* result, bool shared,
     }
   }
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace
 
 Environment::Environment(CLDevice&& device, CLContext&& context,
@@ -137,7 +139,7 @@ Environment& Environment::operator=(Environment&& environment) {
   return *this;
 }
 
-Status Environment::Init() {
+absl::Status Environment::Init() {
   if (device().IsAdreno() && device().SupportsTextureArray()) {
     bool supports_one_layer;
     RETURN_IF_ERROR(
@@ -146,7 +148,7 @@ Status Environment::Init() {
       GetDevicePtr()->DisableOneLayerTextureArray();
     }
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 void Environment::SetHighPerformance() const {
@@ -266,7 +268,7 @@ TensorStorageType GetStorageTypeWithMinimalMemoryConsumption(
   return TensorStorageType::BUFFER;
 }
 
-Status CreateEnvironment(Environment* result) {
+absl::Status CreateEnvironment(Environment* result) {
   CLDevice gpu;
   RETURN_IF_ERROR(CreateDefaultGPUDevice(&gpu));
 
@@ -57,7 +57,7 @@ class Environment {
   std::vector<TensorStorageType> GetSupportedStorages() const;
   bool IsSupported(TensorStorageType storage_type) const;
 
-  Status Init();
+  absl::Status Init();
 
   void SetHighPerformance() const;
   void SetDefaultPerformance() const;
@@ -75,7 +75,7 @@ TensorStorageType GetFastestStorageType(const CLDevice& gpu);
 TensorStorageType GetStorageTypeWithMinimalMemoryConsumption(
     const CLDevice& gpu);
 
-Status CreateEnvironment(Environment* result);
+absl::Status CreateEnvironment(Environment* result);
 
 }  // namespace cl
 }  // namespace gpu
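
A sketch of typical startup under the new signature; Init runs the Adreno one-layer-texture-array probe shown above, so storage support should be queried only afterwards (the storage-type check is illustrative):

    Environment env;
    RETURN_IF_ERROR(CreateEnvironment(&env));
    if (!env.IsSupported(TensorStorageType::TEXTURE_ARRAY)) {
      // Fall back to a storage type the device does support.
    }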
@@ -41,10 +41,11 @@ PFNEGLCREATESYNCPROC g_eglCreateSync = nullptr;
 
 }  // namespace
 
-Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display,
-                                EglSync* sync) {
+absl::Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display,
+                                      EglSync* sync) {
   if (!IsEglSyncFromClEventSupported()) {
-    return UnimplementedError("CreateEglSyncFromClEvent is not supported");
+    return absl::UnimplementedError(
+        "CreateEglSyncFromClEvent is not supported");
   }
   EGLSync egl_sync;
   const EGLAttrib attributes[] = {EGL_CL_EVENT_HANDLE,
@@ -52,10 +53,10 @@ Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display,
   RETURN_IF_ERROR(TFLITE_GPU_CALL_EGL(g_eglCreateSync, &egl_sync, display,
                                       EGL_SYNC_CL_EVENT, attributes));
   if (egl_sync == EGL_NO_SYNC) {
-    return InternalError("Returned empty EGL sync");
+    return absl::InternalError("Returned empty EGL sync");
   }
   *sync = EglSync(display, egl_sync);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 bool IsEglSyncFromClEventSupported() {
@@ -73,52 +74,54 @@ bool IsEglSyncFromClEventSupported() {
   return supported;
 }
 
-Status CreateClEventFromEglSync(cl_context context, const EglSync& egl_sync,
-                                CLEvent* event) {
+absl::Status CreateClEventFromEglSync(cl_context context,
+                                      const EglSync& egl_sync, CLEvent* event) {
   cl_int error_code;
   cl_event new_event = clCreateEventFromEGLSyncKHR(
       context, egl_sync.sync(), egl_sync.display(), &error_code);
   if (error_code != CL_SUCCESS) {
-    return InternalError(
+    return absl::InternalError(
         absl::StrCat("Unable to create CL sync from EGL sync. ",
                      CLErrorCodeToString(error_code)));
   }
   *event = CLEvent(new_event);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 bool IsClEventFromEglSyncSupported(const CLDevice& device) {
   return device.SupportsExtension("cl_khr_egl_event");
 }
 
-Status CreateClMemoryFromGlBuffer(GLuint gl_ssbo_id, AccessType access_type,
-                                  CLContext* context, CLMemory* memory) {
+absl::Status CreateClMemoryFromGlBuffer(GLuint gl_ssbo_id,
+                                        AccessType access_type,
+                                        CLContext* context, CLMemory* memory) {
   cl_int error_code;
   auto mem = clCreateFromGLBuffer(context->context(), ToClMemFlags(access_type),
                                   gl_ssbo_id, &error_code);
   if (error_code != CL_SUCCESS) {
-    return InternalError(
+    return absl::InternalError(
        absl::StrCat("Unable to acquire CL buffer from GL buffer. ",
                      CLErrorCodeToString(error_code)));
   }
   *memory = CLMemory(mem, true);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CreateClMemoryFromGlTexture(GLenum texture_target, GLuint texture_id,
-                                   AccessType access_type, CLContext* context,
-                                   CLMemory* memory) {
+absl::Status CreateClMemoryFromGlTexture(GLenum texture_target,
+                                         GLuint texture_id,
+                                         AccessType access_type,
+                                         CLContext* context, CLMemory* memory) {
   cl_int error_code;
   auto mem =
       clCreateFromGLTexture(context->context(), ToClMemFlags(access_type),
                             texture_target, 0, texture_id, &error_code);
   if (error_code != CL_SUCCESS) {
-    return InternalError(
+    return absl::InternalError(
         absl::StrCat("Unable to create CL buffer from GL texture. ",
                      CLErrorCodeToString(error_code)));
   }
   *memory = CLMemory(mem, true);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 bool IsGlSharingSupported(const CLDevice& device) {
@@ -128,19 +131,18 @@ bool IsGlSharingSupported(const CLDevice& device) {
 
 AcquiredGlObjects::~AcquiredGlObjects() { Release({}, nullptr).IgnoreError(); }
 
-Status AcquiredGlObjects::Acquire(const std::vector<cl_mem>& memory,
-                                  cl_command_queue queue,
-                                  const std::vector<cl_event>& wait_events,
-                                  CLEvent* acquire_event,
-                                  AcquiredGlObjects* objects) {
+absl::Status AcquiredGlObjects::Acquire(
+    const std::vector<cl_mem>& memory, cl_command_queue queue,
+    const std::vector<cl_event>& wait_events, CLEvent* acquire_event,
+    AcquiredGlObjects* objects) {
   if (!memory.empty()) {
     cl_event new_event;
     cl_int error_code = clEnqueueAcquireGLObjects(
         queue, memory.size(), memory.data(), wait_events.size(),
         wait_events.data(), acquire_event ? &new_event : nullptr);
     if (error_code != CL_SUCCESS) {
-      return InternalError(absl::StrCat("Unable to acquire GL object. ",
-                                        CLErrorCodeToString(error_code)));
+      return absl::InternalError(absl::StrCat("Unable to acquire GL object. ",
+                                              CLErrorCodeToString(error_code)));
    }
     if (acquire_event) {
       *acquire_event = CLEvent(new_event);
@@ -148,19 +150,19 @@ Status AcquiredGlObjects::Acquire(const std::vector<cl_mem>& memory,
     clFlush(queue);
   }
   *objects = AcquiredGlObjects(memory, queue);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status AcquiredGlObjects::Release(const std::vector<cl_event>& wait_events,
-                                  CLEvent* release_event) {
+absl::Status AcquiredGlObjects::Release(
+    const std::vector<cl_event>& wait_events, CLEvent* release_event) {
   if (queue_ && !memory_.empty()) {
     cl_event new_event;
     cl_int error_code = clEnqueueReleaseGLObjects(
         queue_, memory_.size(), memory_.data(), wait_events.size(),
         wait_events.data(), release_event ? &new_event : nullptr);
     if (error_code != CL_SUCCESS) {
-      return InternalError(absl::StrCat("Unable to release GL object. ",
-                                        CLErrorCodeToString(error_code)));
+      return absl::InternalError(absl::StrCat("Unable to release GL object. ",
+                                              CLErrorCodeToString(error_code)));
     }
     if (release_event) {
       *release_event = CLEvent(new_event);
@@ -168,7 +170,7 @@ Status AcquiredGlObjects::Release(const std::vector<cl_event>& wait_events,
     clFlush(queue_);
     queue_ = nullptr;
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 GlInteropFabric::GlInteropFabric(EGLDisplay egl_display,
@@ -192,9 +194,9 @@ void GlInteropFabric::UnregisterMemory(cl_mem memory) {
   }
 }
 
-Status GlInteropFabric::Start() {
+absl::Status GlInteropFabric::Start() {
   if (!is_enabled()) {
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   // In GL-CL interoperability, we need to make sure GL finished processing of
@@ -235,9 +237,9 @@ Status GlInteropFabric::Start() {
                                      nullptr, &gl_objects_);
 }
 
-Status GlInteropFabric::Finish() {
+absl::Status GlInteropFabric::Finish() {
   if (!is_enabled()) {
-    return OkStatus();
+    return absl::OkStatus();
  }
   RETURN_IF_ERROR(gl_objects_.Release({}, &outbound_event_));
 
@@ -258,7 +260,7 @@ Status GlInteropFabric::Finish() {
   // This slow sync is the only working solution right now. We have to debug why
   // above version is not working fast and reliable.
   outbound_event_.Wait();
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace cl
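
The acquire/release pair brackets every CL use of a shared GL object. A sketch of the full sequence for a single buffer with no events (names and the CLMemory accessor are assumptions for illustration):

    CLMemory cl_buffer;
    RETURN_IF_ERROR(CreateClMemoryFromGlBuffer(gl_ssbo_id, AccessType::READ,
                                               &context, &cl_buffer));
    AcquiredGlObjects acquired;
    RETURN_IF_ERROR(AcquiredGlObjects::Acquire({cl_buffer.memory()},
                                               queue.queue(), /*wait_events=*/{},
                                               /*acquire_event=*/nullptr,
                                               &acquired));
    // ... enqueue kernels that read the shared buffer ...
    RETURN_IF_ERROR(acquired.Release(/*wait_events=*/{},
                                     /*release_event=*/nullptr));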
@@ -39,8 +39,8 @@ namespace cl {
 // returned sync and could be safely destroyed.
 //
 // Depends on EGL 1.5.
-Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display,
-                                EglSync* sync);
+absl::Status CreateEglSyncFromClEvent(cl_event event, EGLDisplay display,
+                                      EglSync* sync);
 
 // Returns true if 'CreateEglSyncFromClEvent' is supported.
 bool IsEglSyncFromClEventSupported();
@@ -48,20 +48,22 @@ bool IsEglSyncFromClEventSupported();
 // Creates CL event from EGL sync.
 // Created event could only be consumed by AcquiredGlObject::Acquire call as
 // a 'wait_event'.
-Status CreateClEventFromEglSync(cl_context context, const EglSync& egl_sync,
-                                CLEvent* event);
+absl::Status CreateClEventFromEglSync(cl_context context,
+                                      const EglSync& egl_sync, CLEvent* event);
 
 // Returns true if 'CreateClEventFromEglSync' is supported.
 bool IsClEventFromEglSyncSupported(const CLDevice& device);
 
 // Creates new CL memory object from OpenGL buffer.
-Status CreateClMemoryFromGlBuffer(GLuint gl_ssbo_id, AccessType access_type,
-                                  CLContext* context, CLMemory* memory);
+absl::Status CreateClMemoryFromGlBuffer(GLuint gl_ssbo_id,
+                                        AccessType access_type,
+                                        CLContext* context, CLMemory* memory);
 
 // Creates new CL memory object from OpenGL texture.
-Status CreateClMemoryFromGlTexture(GLenum texture_target, GLuint texture_id,
-                                   AccessType access_type, CLContext* context,
-                                   CLMemory* memory);
+absl::Status CreateClMemoryFromGlTexture(GLenum texture_target,
+                                         GLuint texture_id,
+                                         AccessType access_type,
+                                         CLContext* context, CLMemory* memory);
 
 // Returns true if GL objects could be shared with OpenCL context.
 bool IsGlSharingSupported(const CLDevice& device);
@@ -81,16 +83,16 @@ class AcquiredGlObjects {
   // CreateClMemoryFromGlBuffer or CreateClMemoryFromGlTexture calls.
   // If 'acquire_event' is not nullptr, it will be signalled once acquisition is
   // complete.
-  static Status Acquire(const std::vector<cl_mem>& memory,
-                        cl_command_queue queue,
-                        const std::vector<cl_event>& wait_events,
-                        CLEvent* acquire_event /* optional */,
-                        AcquiredGlObjects* objects);
+  static absl::Status Acquire(const std::vector<cl_mem>& memory,
+                              cl_command_queue queue,
+                              const std::vector<cl_event>& wait_events,
+                              CLEvent* acquire_event /* optional */,
+                              AcquiredGlObjects* objects);
 
   // Releases OpenCL memory back to OpenGL context. If 'release_event' is not
   // nullptr, it will be signalled once release is complete.
-  Status Release(const std::vector<cl_event>& wait_events,
-                 CLEvent* release_event /* optional */);
+  absl::Status Release(const std::vector<cl_event>& wait_events,
+                       CLEvent* release_event /* optional */);
 
  private:
   AcquiredGlObjects(const std::vector<cl_mem>& memory, cl_command_queue queue)
@@ -108,10 +110,10 @@ class GlInteropFabric {
 
   // Ensures proper GL->CL synchronization is in place before
   // GL objects that are mapped to CL objects are used.
-  Status Start();
+  absl::Status Start();
 
   // Puts appropriate CL->GL synchronization after all work is complete.
-  Status Finish();
+  absl::Status Finish();
 
   // Registers memory to be used from GL context. Such CL memory object must
   // be created with CreateClMemoryFromGlBuffer or CreateClMemoryFromGlTexture
@ -87,8 +87,8 @@ class Delegate {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Prepare(TfLiteContext* context,
|
absl::Status Prepare(TfLiteContext* context,
|
||||||
const TfLiteDelegateParams* delegate_params) {
|
const TfLiteDelegateParams* delegate_params) {
|
 // Extract TFLite delegate execution plan from the context and convert it
 // into FlowGraph32.
 GraphFloat32 graph;
@@ -98,7 +98,7 @@ class Delegate {
 NullTransformationReporter reporter;
 ModelTransformer transformer(&graph, &reporter);
 if (!ApplyGeneralTransformations(&transformer)) {
-return InternalError("Graph general transformations failed");
+return absl::InternalError("Graph general transformations failed");
 }

 InferenceEnvironmentOptions env_options;
@@ -108,7 +108,7 @@ class Delegate {
 options_.serialized_binary_cache_data,
 options_.serialized_binary_cache_size};
 InferenceEnvironmentProperties properties;
-Status status =
+absl::Status status =
 NewInferenceEnvironment(env_options, &environment_, &properties);
 if (!properties.is_opencl_available) {
 context->ReportError(context,
@@ -200,7 +200,7 @@ class Delegate {
 return builder->Build(&runner_);
 }

-Status SetInputsAndOutputs(TfLiteContext* context) {
+absl::Status SetInputsAndOutputs(TfLiteContext* context) {
 int i = 0;
 for (auto index : input_indices_) {
 RETURN_IF_ERROR(
@@ -211,10 +211,10 @@ class Delegate {
 RETURN_IF_ERROR(
 runner_->SetOutputObject(i++, GetTensorObject(index, context)));
 }
-return OkStatus();
+return absl::OkStatus();
 }

-Status Invoke(TfLiteContext* context) {
+absl::Status Invoke(TfLiteContext* context) {
 RETURN_IF_ERROR(SetInputsAndOutputs(context));
 return runner_->Run();
 }
@@ -310,7 +310,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
 const auto status = gpu_delegate->Prepare(context, params);
 if (!status.ok()) {
 context->ReportError(context, "TfLiteGpuDelegate Init: %s",
-status.error_message().c_str());
+std::string(status.message()).c_str());
 return nullptr;
 }
 return gpu_delegate;
@@ -335,7 +335,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
 const auto status = GetDelegate(node)->Invoke(context);
 if (!status.ok()) {
 context->ReportError(context, "TfLiteGpuDelegate Invoke: %s",
-status.error_message().c_str());
+std::string(status.message()).c_str());
 return kTfLiteError;
 }
 return kTfLiteOk;
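
Note on the two ReportError hunks above: the old tflite::gpu::Status exposed error_message() as a std::string, so .c_str() could be called on it directly, while absl::Status::message() returns an absl::string_view that is not guaranteed to be NUL-terminated. That is why each call site now copies the view into a std::string before handing it to the printf-style "%s". A minimal self-contained sketch of the pattern (the ReportIfError helper is illustrative, not part of the delegate):

#include <cstdio>
#include <string>

#include "absl/status/status.h"

// Sketch: forwarding an absl::Status message to a printf-style reporter.
// message() yields an absl::string_view; copying it into a std::string
// provides the NUL-terminated buffer that "%s" requires.
void ReportIfError(const absl::Status& status) {
  if (!status.ok()) {
    std::fprintf(stderr, "TfLiteGpuDelegate: %s\n",
                 std::string(status.message()).c_str());
  }
}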
@@ -169,9 +169,9 @@ CLNode& CLNode::operator=(CLNode&& node) {
 return *this;
 }

-Status InferenceContext::InitFromGraph(const CreateInferenceInfo& create_info,
-const GraphFloat32& graph,
+absl::Status InferenceContext::InitFromGraph(
+const CreateInferenceInfo& create_info, const GraphFloat32& graph,
 Environment* env) {
 CreationContext creation_context;
 creation_context.device = env->GetDevicePtr();
 creation_context.context = &env->context();
@@ -206,15 +206,15 @@ Status InferenceContext::InitFromGraph(const CreateInferenceInfo& create_info,
 tuning_parameters.tuning_type = TuningType::FAST;
 }
 RETURN_IF_ERROR(Tune(tuning_parameters));
-return OkStatus();
+return absl::OkStatus();
 }

-Status InferenceContext::InitFromGraphWithTransforms(
+absl::Status InferenceContext::InitFromGraphWithTransforms(
 const CreateInferenceInfo& create_info, GraphFloat32* graph,
 Environment* env) {
 RETURN_IF_ERROR(RunGraphTransforms(graph));
 RETURN_IF_ERROR(InitFromGraph(create_info, *graph, env));
-return OkStatus();
+return absl::OkStatus();
 }

 void InferenceContext::CopyInAndOutIds(const GraphFloat32& graph) {
@@ -258,7 +258,7 @@ void InferenceContext::ReserveGraphTensors(
 tensor_reserver_.SetNext(max_id + 1);
 }

-Status InferenceContext::ConvertOperations(
+absl::Status InferenceContext::ConvertOperations(
 const CreationContext& creation_context, const GraphFloat32& graph,
 ModelHints hints) {
 std::vector<Node*> graph_nodes = graph.nodes();
@@ -343,7 +343,7 @@ Status InferenceContext::ConvertOperations(
 }
 }

-return OkStatus();
+return absl::OkStatus();
 }

 void InferenceContext::Merge() {
@@ -424,15 +424,15 @@ void InferenceContext::GetUsages(
 }
 }

-Status InferenceContext::AllocateMemory(const CLDevice& device,
+absl::Status InferenceContext::AllocateMemory(const CLDevice& device,
 CLContext* context) {
 RETURN_IF_ERROR(AllocateMemoryForBuffers(device, context));
 RETURN_IF_ERROR(AllocateMemoryForStrongShapes(device, context));
-return OkStatus();
+return absl::OkStatus();
 }

-Status InferenceContext::AllocateMemoryForBuffers(const CLDevice& device,
+absl::Status InferenceContext::AllocateMemoryForBuffers(const CLDevice& device,
 CLContext* context) {
 std::map<ValueId, int2> buffer_usages;
 GetUsages(
 [](const TensorDescriptor& t) { return IsBufferBased(t.storage_type); },
@@ -480,11 +480,11 @@ Status InferenceContext::AllocateMemoryForBuffers(const CLDevice& device,
 created_tensors[tensor_index] = true;
 }
 }
-return OkStatus();
+return absl::OkStatus();
 }

-Status InferenceContext::AllocateMemoryForStrongShapes(const CLDevice& device,
-CLContext* context) {
+absl::Status InferenceContext::AllocateMemoryForStrongShapes(
+const CLDevice& device, CLContext* context) {
 std::map<ValueId, int2> usages;
 GetUsages(
 [](const TensorDescriptor& t) { return !IsBufferBased(t.storage_type); },
@@ -517,7 +517,7 @@ Status InferenceContext::AllocateMemoryForStrongShapes(const CLDevice& device,
 }
 }
 }
-return OkStatus();
+return absl::OkStatus();
 }

 void InferenceContext::BindMemoryToOperations() {
@@ -539,21 +539,22 @@ void InferenceContext::BindMemoryToOperations() {
 }
 }

-Status InferenceContext::Compile(const CreationContext& creation_context) {
+absl::Status InferenceContext::Compile(
+const CreationContext& creation_context) {
 for (auto& node : nodes_) {
 RETURN_IF_ERROR(node.operations[0]->Compile(creation_context));
 }
-return OkStatus();
+return absl::OkStatus();
 }

-Status InferenceContext::Tune(const TuningParameters& tuning_parameters) {
+absl::Status InferenceContext::Tune(const TuningParameters& tuning_parameters) {
 for (auto& node : nodes_) {
 RETURN_IF_ERROR(node.operations[0]->Tune(tuning_parameters));
 }
-return OkStatus();
+return absl::OkStatus();
 }

-Status InferenceContext::AddToQueue(CLCommandQueue* queue) {
+absl::Status InferenceContext::AddToQueue(CLCommandQueue* queue) {
 if (need_manual_release_) {
 if (prev_enqueue_start_point_.is_valid()) {
 prev_enqueue_start_point_.Wait();
@@ -571,11 +572,11 @@ Status InferenceContext::AddToQueue(CLCommandQueue* queue) {
 if (need_flush_) {
 clFlush(queue->queue());
 }
-return OkStatus();
+return absl::OkStatus();
 }

-Status InferenceContext::Profile(ProfilingCommandQueue* queue,
+absl::Status InferenceContext::Profile(ProfilingCommandQueue* queue,
 ProfilingInfo* result) {
 queue->ResetMeasurements();
 for (auto& node : nodes_) {
 queue->SetEventsLabel(node.name);
@@ -583,7 +584,7 @@ Status InferenceContext::Profile(ProfilingCommandQueue* queue,
 }
 RETURN_IF_ERROR(queue->WaitForCompletion());
 *result = queue->GetProfilingInfo();
-return OkStatus();
+return absl::OkStatus();
 }

 uint64_t InferenceContext::GetSizeOfMemoryAllocatedForIntermediateTensors()
@@ -608,13 +609,15 @@ Tensor* InferenceContext::GetTensor(ValueId id) {
 }
 }

-Status InferenceContext::SetInputTensor(ValueId id, const TensorFloat32& tensor,
-CLCommandQueue* queue) {
+absl::Status InferenceContext::SetInputTensor(ValueId id,
+const TensorFloat32& tensor,
+CLCommandQueue* queue) {
 return GetTensor(id)->WriteData(queue, tensor);
 }

-Status InferenceContext::GetOutputTensor(ValueId id, CLCommandQueue* queue,
-TensorFloat32* result) {
+absl::Status InferenceContext::GetOutputTensor(ValueId id,
+CLCommandQueue* queue,
+TensorFloat32* result) {
 const auto& gpu_tensor = *GetTensor(id);
 const auto dst_shape = BHWC(gpu_tensor.Batch(), gpu_tensor.Height(),
 gpu_tensor.Width(), gpu_tensor.Channels());
@@ -624,17 +627,17 @@ Status InferenceContext::GetOutputTensor(ValueId id, CLCommandQueue* queue,
 return gpu_tensor.ReadData(queue, result);
 }

-Status RunGraphTransforms(GraphFloat32* graph) {
+absl::Status RunGraphTransforms(GraphFloat32* graph) {
 auto merge_padding_transform = NewMergePaddingWithAdd();
 auto add_bias_transform = NewAddBias();
 ModelTransformer transformer(graph, /*reporter=*/nullptr);
 if (!transformer.Apply("add_bias", add_bias_transform.get())) {
-return InternalError("Invalid add_bias transform");
+return absl::InternalError("Invalid add_bias transform");
 }
 if (!transformer.Apply("merge_padding", merge_padding_transform.get())) {
-return InternalError("Invalid merge_padding transform");
+return absl::InternalError("Invalid merge_padding transform");
 }
-return OkStatus();
+return absl::OkStatus();
 }

 }  // namespace cl
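
Almost every body in this file is unchanged apart from the absl:: qualification, because the control flow is funneled through RETURN_IF_ERROR. The macro's in-tree definition is not shown in this diff; a representative absl::Status-compatible version (an assumption about its shape, not the exact in-tree spelling) is:

#include "absl/status/status.h"

// Representative RETURN_IF_ERROR: evaluate the expression once and
// propagate any non-OK absl::Status to the caller immediately.
#define RETURN_IF_ERROR(expr)             \
  do {                                    \
    const absl::Status _status = (expr);  \
    if (!_status.ok()) return _status;    \
  } while (0)

Since absl::OkStatus() is a drop-in equivalent of the old free OkStatus(), the mechanical `return OkStatus();` to `return absl::OkStatus();` rewrite is the only change most of these functions need.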
@@ -65,53 +65,55 @@ class InferenceContext {
 TensorStorageType storage_type;
 ModelHints hints;
 };
-Status InitFromGraph(const CreateInferenceInfo& create_info,
+absl::Status InitFromGraph(const CreateInferenceInfo& create_info,
 const GraphFloat32& graph, Environment* env);

 // Applies OpenCL-specific transformations to the graph before the
 // initialization. These transformations are either impossible or useless in
 // other backends.
-Status InitFromGraphWithTransforms(const CreateInferenceInfo& create_info,
-GraphFloat32* graph, Environment* env);
+absl::Status InitFromGraphWithTransforms(
+const CreateInferenceInfo& create_info, GraphFloat32* graph,
+Environment* env);

-Status AddToQueue(CLCommandQueue* queue);
-Status Profile(ProfilingCommandQueue* queue, ProfilingInfo* result);
+absl::Status AddToQueue(CLCommandQueue* queue);
+absl::Status Profile(ProfilingCommandQueue* queue, ProfilingInfo* result);
 // for profiling and memory statistics
 uint64_t GetSizeOfMemoryAllocatedForIntermediateTensors() const;

-Status SetInputTensor(ValueId id, const TensorFloat32& tensor,
+absl::Status SetInputTensor(ValueId id, const TensorFloat32& tensor,
 CLCommandQueue* queue);

 // It will work only with input/output tensor ids. For all other ids we don't
 // have any guarantees.
 Tensor* GetTensor(ValueId id);

-Status GetOutputTensor(ValueId id, CLCommandQueue* queue,
+absl::Status GetOutputTensor(ValueId id, CLCommandQueue* queue,
 TensorFloat32* result);

 private:
 void CopyInAndOutIds(const GraphFloat32& graph);
-Status ConvertOperations(const CreationContext& creation_context,
+absl::Status ConvertOperations(const CreationContext& creation_context,
 const GraphFloat32& graph, ModelHints hints);
 void CreateLinks();
 void ReserveGraphTensors(const CreateInferenceInfo& create_info,
 const CreationContext& creation_context,
 const GraphFloat32& graph);
 void Merge();
-Status AllocateMemory(const CLDevice& device, CLContext* context);
+absl::Status AllocateMemory(const CLDevice& device, CLContext* context);

-Status AllocateMemoryForBuffers(const CLDevice& device, CLContext* context);
+absl::Status AllocateMemoryForBuffers(const CLDevice& device,
+CLContext* context);

-Status AllocateMemoryForStrongShapes(const CLDevice& device,
+absl::Status AllocateMemoryForStrongShapes(const CLDevice& device,
 CLContext* context);

 // utility function
 void GetUsages(const std::function<bool(const TensorDescriptor&)>& functor,
 std::map<ValueId, int2>* usages);

 void BindMemoryToOperations();
-Status Compile(const CreationContext& creation_context);
-Status Tune(const TuningParameters& tuning_parameters);
+absl::Status Compile(const CreationContext& creation_context);
+absl::Status Tune(const TuningParameters& tuning_parameters);

 // performance hacks
 bool need_flush_ = false;
@@ -175,7 +177,7 @@ class InferenceContext {
 };

 // Runs OpenCL specific transforms for the graph.
-Status RunGraphTransforms(GraphFloat32* graph);
+absl::Status RunGraphTransforms(GraphFloat32* graph);

 }  // namespace cl
 }  // namespace gpu
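
For orientation, the public methods declared above compose as init-then-run. A hypothetical caller, with all inputs assumed to be prepared elsewhere (and CreateInferenceInfo assumed to be the nested config struct whose tail is visible in the first hunk):

// Hypothetical end-to-end use of the InferenceContext API above.
absl::Status RunOnce(const InferenceContext::CreateInferenceInfo& create_info,
                     GraphFloat32* graph, Environment* env,
                     CLCommandQueue* queue, ValueId input_id,
                     const TensorFloat32& input, ValueId output_id,
                     TensorFloat32* output) {
  InferenceContext context;
  // Init applies OpenCL-specific graph transforms, converts operations,
  // allocates memory, and compiles and tunes the kernels.
  RETURN_IF_ERROR(context.InitFromGraphWithTransforms(create_info, graph, env));
  RETURN_IF_ERROR(context.SetInputTensor(input_id, input, queue));
  RETURN_IF_ERROR(context.AddToQueue(queue));
  RETURN_IF_ERROR(context.GetOutputTensor(output_id, queue, output));
  return absl::OkStatus();
}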
@@ -143,17 +143,17 @@ std::string Add::GetArgsDeclaration() const {
 return args;
 }

-Status Add::BindArguments(CLKernel* kernel) {
+absl::Status Add::BindArguments(CLKernel* kernel) {
 for (int i = 1; i < src_depthes_.size(); ++i) {
 RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[i]->GetMemoryPtr()));
 }
 for (int i = 1; i < src_depthes_.size(); ++i) {
 RETURN_IF_ERROR(kernel->SetBytesAuto(src_[i]->GetWBatchedHSB()));
 }
-return OkStatus();
+return absl::OkStatus();
 }

-Status Add::Compile(const CreationContext& creation_context) {
+absl::Status Add::Compile(const CreationContext& creation_context) {
 const auto code = GetElementWiseCode(definition_, linked_operations_);
 return creation_context.cache->GetOrCreateCLKernel(
 code, "main_function", *creation_context.context,
@@ -36,7 +36,7 @@ class Add : public ElementwiseOperation {
 Add(const OperationDef& definition, const std::vector<int>& channels,
 int dst_channels);

-Status Compile(const CreationContext& creation_context) override;
+absl::Status Compile(const CreationContext& creation_context) override;

 // Move only
 Add(Add&& operation);
@@ -47,7 +47,7 @@ class Add : public ElementwiseOperation {
 void SetLinkIndex(int index) override;
 std::string GetCoreCode(const LinkingContext& context) const override;
 std::string GetArgsDeclaration() const override;
-Status BindArguments(CLKernel* kernel) override;
+absl::Status BindArguments(CLKernel* kernel) override;

 private:
 std::string GetElementWiseCode(
@@ -21,17 +21,17 @@ namespace tflite {
 namespace gpu {
 namespace cl {

-Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
+absl::Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
 const CreationContext& creation_context,
 GPUOperation* operation,
 const std::vector<BHWC>& dst_sizes,
 const std::vector<TensorFloat32*>& dst_cpu) {
 const OperationDef& op_def = operation->GetDefinition();
 std::vector<Tensor> src(src_cpu.size());
 for (int i = 0; i < src_cpu.size(); ++i) {
 auto src_shape = src_cpu[i].shape;
 if (src_shape.b != 1 && !op_def.IsBatchSupported()) {
-return InvalidArgumentError(
+return absl::InvalidArgumentError(
 "Layout doesn't have Batch dimension, but shape.b != 1");
 }
 RETURN_IF_ERROR(CreateTensor(*creation_context.context,
@@ -45,7 +45,7 @@ Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
 for (int i = 0; i < dst_cpu.size(); ++i) {
 auto dst_shape = dst_sizes[i];
 if (dst_shape.b != 1 && !op_def.IsBatchSupported()) {
-return InvalidArgumentError(
+return absl::InvalidArgumentError(
 "Layout doesn't have Batch dimension, but shape.b != 1");
 }
 RETURN_IF_ERROR(CreateTensor(*creation_context.context,
@@ -64,22 +64,22 @@ Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
 dst_cpu[i]->data = std::vector<float>(dst_sizes[i].DimensionsProduct(), 0);
 RETURN_IF_ERROR(dst[i].ReadData(creation_context.queue, dst_cpu[i]));
 }
-return OkStatus();
+return absl::OkStatus();
 }

-Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
+absl::Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
 const CreationContext& creation_context,
 GPUOperation* operation, const BHWC& dst_size,
 TensorFloat32* result) {
 return ExecuteGPUOperation(
 std::vector<TensorFloat32>{src_cpu}, creation_context, operation,
 std::vector<BHWC>{dst_size}, std::vector<TensorFloat32*>{result});
 }

-Status ExecuteGPUOperation(const TensorFloat32& src_cpu,
+absl::Status ExecuteGPUOperation(const TensorFloat32& src_cpu,
 const CreationContext& creation_context,
 GPUOperation* operation, const BHWC& dst_size,
 TensorFloat32* result) {
 return ExecuteGPUOperation(std::vector<TensorFloat32>{src_cpu},
 creation_context, operation, dst_size, result);
 }
@@ -51,21 +51,21 @@ class OpenCLOperationTest : public ::testing::Test {
 CreationContext creation_context_;
 };

-Status ExecuteGPUOperation(const TensorFloat32& src_cpu,
+absl::Status ExecuteGPUOperation(const TensorFloat32& src_cpu,
 const CreationContext& creation_context,
 GPUOperation* operation, const BHWC& dst_size,
 TensorFloat32* result);

-Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
+absl::Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
 const CreationContext& creation_context,
 GPUOperation* operation, const BHWC& dst_size,
 TensorFloat32* result);

-Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
+absl::Status ExecuteGPUOperation(const std::vector<TensorFloat32>& src_cpu,
 const CreationContext& creation_context,
 GPUOperation* operation,
 const std::vector<BHWC>& dst_sizes,
 const std::vector<TensorFloat32*>& dst_cpu);
 }  // namespace cl
 }  // namespace gpu
 }  // namespace tflite
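
These overloads keep status plumbing out of individual test bodies. A sketch of a test driving the single-tensor overload; the operation setup is deliberately hidden behind a hypothetical GetOperationUnderTest() helper, since each kernel has its own Create* factory:

// Sketch only: GetOperationUnderTest is a hypothetical stand-in for a
// per-kernel factory call; shapes and data are placeholders.
TEST_F(OpenCLOperationTest, ExecutesOperation) {
  TensorFloat32 src;
  src.shape = BHWC(1, 2, 2, 8);
  src.data.resize(src.shape.DimensionsProduct(), 1.0f);

  GPUOperation* op = GetOperationUnderTest();
  TensorFloat32 dst;
  const absl::Status status =
      ExecuteGPUOperation(src, creation_context_, op, BHWC(1, 2, 2, 8), &dst);
  // With absl::Status, the failure text comes from message() rather than
  // the old error_message().
  ASSERT_TRUE(status.ok()) << status.message();
}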
@@ -96,7 +96,7 @@ ConcatXY& ConcatXY::operator=(ConcatXY&& operation) {
 return *this;
 }

-Status ConcatXY::Compile(const CreationContext& creation_context) {
+absl::Status ConcatXY::Compile(const CreationContext& creation_context) {
 const auto code =
 GetConcatKernelCode(definition_, tensors_count_, linked_operations_);
 return creation_context.cache->GetOrCreateCLKernel(
@@ -104,7 +104,7 @@ Status ConcatXY::Compile(const CreationContext& creation_context) {
 *creation_context.device, &kernel_);
 }

-Status ConcatXY::BindArguments() {
+absl::Status ConcatXY::BindArguments() {
 kernel_.ResetBindingCounter();
 for (int i = 0; i < tensors_count_; ++i) {
 RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[i]->GetMemoryPtr()));
@@ -122,7 +122,7 @@ Status ConcatXY::BindArguments() {
 y_offset += attr_.axis == Axis::HEIGHT ? height : 0;
 }
 RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB()));
-return OkStatus();
+return absl::OkStatus();
 }

 int3 ConcatXY::GetGridSize() const {
@@ -140,12 +140,12 @@ int3 ConcatXY::GetGridSize() const {
 return int3(grid_x, grid_y, grid_z);
 }

-Status ConcatXY::Tune(const TuningParameters& params) {
+absl::Status ConcatXY::Tune(const TuningParameters& params) {
 RETURN_IF_ERROR(BindArguments());
 return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }

-Status ConcatXY::AddToQueue(CLCommandQueue* queue) {
+absl::Status ConcatXY::AddToQueue(CLCommandQueue* queue) {
 RETURN_IF_ERROR(BindArguments());
 return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -31,10 +31,10 @@ class ConcatXY : public GPUOperation {
 ConcatXY(const OperationDef& definition, const ConcatAttributes& attr,
 int tensors_count)
 : GPUOperation(definition), attr_(attr), tensors_count_(tensors_count) {}
-Status AddToQueue(CLCommandQueue* queue) override;
-Status Tune(const TuningParameters& params) override;
+absl::Status AddToQueue(CLCommandQueue* queue) override;
+absl::Status Tune(const TuningParameters& params) override;

-Status Compile(const CreationContext& creation_context) override;
+absl::Status Compile(const CreationContext& creation_context) override;

 // Move only
 ConcatXY(ConcatXY&& operation);
@@ -43,7 +43,7 @@ class ConcatXY : public GPUOperation {
 ConcatXY& operator=(const ConcatXY&) = delete;

 private:
-Status BindArguments();
+absl::Status BindArguments();
 int3 GetGridSize() const;

 ConcatAttributes attr_;
@@ -25,8 +25,8 @@ limitations under the License.
 namespace tflite {
 namespace gpu {
 namespace cl {

 namespace {

 bool IsAllChannelsX4(const std::vector<int>& channels) {
 for (int channel : channels) {
 if (channel % 4 != 0) {
@@ -146,6 +146,7 @@ std::string GetConcatKernelCode(
 c += "}\n";
 return c;
 }

 }  // namespace

 ConcatZ::ConcatZ(ConcatZ&& kernel)
@@ -164,7 +165,7 @@ ConcatZ& ConcatZ::operator=(ConcatZ&& kernel) {
 return *this;
 }

-Status ConcatZ::Compile(const CreationContext& creation_context) {
+absl::Status ConcatZ::Compile(const CreationContext& creation_context) {
 const auto code =
 GetConcatKernelCode(definition_, channels_, linked_operations_);
 std::vector<CompilerOptions> options;
@@ -186,7 +187,7 @@ Status ConcatZ::Compile(const CreationContext& creation_context) {
 *creation_context.device, &kernel_);
 }

-Status ConcatZ::BindArguments() {
+absl::Status ConcatZ::BindArguments() {
 kernel_.ResetBindingCounter();
 for (int i = 0; i < channels_.size(); ++i) {
 RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[i]->GetMemoryPtr()));
@@ -197,7 +198,7 @@ Status ConcatZ::BindArguments() {
 RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[i]->Slices()));
 }
 RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB()));
-return OkStatus();
+return absl::OkStatus();
 }

 int3 ConcatZ::GetGridSize() const {
@@ -207,12 +208,12 @@ int3 ConcatZ::GetGridSize() const {
 return int3(grid_x, grid_y, grid_z);
 }

-Status ConcatZ::Tune(const TuningParameters& params) {
+absl::Status ConcatZ::Tune(const TuningParameters& params) {
 RETURN_IF_ERROR(BindArguments());
 return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }

-Status ConcatZ::AddToQueue(CLCommandQueue* queue) {
+absl::Status ConcatZ::AddToQueue(CLCommandQueue* queue) {
 RETURN_IF_ERROR(BindArguments());
 return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -32,10 +32,10 @@ class ConcatZ : public GPUOperation {
 public:
 ConcatZ(const OperationDef& definition, const std::vector<int>& channels)
 : GPUOperation(definition), channels_(channels) {}
-Status AddToQueue(CLCommandQueue* queue) override;
-Status Tune(const TuningParameters& params) override;
+absl::Status AddToQueue(CLCommandQueue* queue) override;
+absl::Status Tune(const TuningParameters& params) override;

-Status Compile(const CreationContext& creation_context) override;
+absl::Status Compile(const CreationContext& creation_context) override;

 // Move only
 ConcatZ(ConcatZ&& kernel);
@@ -44,7 +44,7 @@ class ConcatZ : public GPUOperation {
 ConcatZ& operator=(const ConcatZ&) = delete;

 private:
-Status BindArguments();
+absl::Status BindArguments();
 int3 GetGridSize() const;

 std::vector<int> channels_;
@@ -76,7 +76,7 @@ Conv3D& Conv3D::operator=(Conv3D&& operation) {
 return *this;
 }

-Status Conv3D::Compile(const CreationContext& creation_context) {
+absl::Status Conv3D::Compile(const CreationContext& creation_context) {
 const bool stride_correction =
 definition_.IsBatchSupported() && stride_.x != 1;
 const std::string code =
@@ -92,7 +92,7 @@ Status Conv3D::Compile(const CreationContext& creation_context) {
 *creation_context.device, &kernel_);
 }

-Status Conv3D::BindArguments() {
+absl::Status Conv3D::BindArguments() {
 kernel_.ResetBindingCounter();
 RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
 if (conv_params_.AreWeightsBuffer()) {
@@ -131,7 +131,7 @@ Status Conv3D::BindArguments() {
 IntegralDivideRoundUp(dst_[0]->Slices(), conv_params_.block_size.w)));
 RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDS()));
 RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDS()));
-return OkStatus();
+return absl::OkStatus();
 }

 int3 Conv3D::GetGridSize() const {
@@ -154,12 +154,12 @@ int3 Conv3D::GetGridSize() const {
 conv_params_.work_group_size.z);
 }

-Status Conv3D::Tune(const TuningParameters& params) {
+absl::Status Conv3D::Tune(const TuningParameters& params) {
 if (conv_params_.weights_upload_type ==
 WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP ||
 conv_params_.weights_upload_type ==
 WeightsUploadType::LOCAL_MEM_BY_THREADS) {
-return OkStatus();
+return absl::OkStatus();
 }
 if (conv_params_.work_group_launch_order[0] == 0 &&
 conv_params_.work_group_launch_order[1] == 1 &&
@@ -168,10 +168,10 @@ Status Conv3D::Tune(const TuningParameters& params) {
 return GetBestWorkGroupConv(params, kernel_, GetGridSize(),
 &conv_params_.work_group_size);
 }
-return OkStatus();
+return absl::OkStatus();
 }

-Status Conv3D::AddToQueue(CLCommandQueue* queue) {
+absl::Status Conv3D::AddToQueue(CLCommandQueue* queue) {
 RETURN_IF_ERROR(BindArguments());
 return queue->DispatchImplicit(kernel_, GetGridSize(),
 conv_params_.work_group_size);
@@ -903,9 +903,9 @@ Conv3D::ConvParams Conv3D::GuessBestParams(
 x_kernel_is_1, y_kernel_is_1, z_kernel_is_1);
 }

-Status CreateConv3D(const CreationContext& creation_context,
+absl::Status CreateConv3D(const CreationContext& creation_context,
 const OperationDef& definition,
 const Convolution3DAttributes& attr, Conv3D* result) {
 *result = Conv3D(definition, attr, *creation_context.device);
 return result->UploadData(attr.weights, attr.bias, creation_context.context);
 }
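
CreateConv3D above follows the factory convention used throughout these kernels: default-construct the result, assign a configured instance into the out-parameter, and return the absl::Status of the weight/bias upload. A sketch of a caller that also compiles the kernel, as InferenceContext::Compile does for each node (the wrapper itself is illustrative, not part of the tree):

// Sketch: creating and compiling a Conv3D kernel via the factory above.
absl::Status MakeCompiledConv3D(const CreationContext& creation_context,
                                const OperationDef& definition,
                                const Convolution3DAttributes& attr,
                                Conv3D* result) {
  RETURN_IF_ERROR(CreateConv3D(creation_context, definition, attr, result));
  // Compilation is a separate step from creation; see Conv3D::Compile above.
  return result->Compile(creation_context);
}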
@@ -39,9 +39,9 @@ namespace cl {
 class Conv3D : public GPUOperation {
 public:
 Conv3D() = default;
-Status AddToQueue(CLCommandQueue* queue) override;
-Status Tune(const TuningParameters& params) override;
-Status Compile(const CreationContext& creation_context) override;
+absl::Status AddToQueue(CLCommandQueue* queue) override;
+absl::Status Tune(const TuningParameters& params) override;
+absl::Status Compile(const CreationContext& creation_context) override;

 // Move only
 Conv3D(Conv3D&& operation);
@@ -75,21 +75,21 @@ class Conv3D : public GPUOperation {
 const CLDevice& device);

 template <DataType T>
-Status UploadData(const ::tflite::gpu::Tensor<OHWDI, T>& weights,
+absl::Status UploadData(const ::tflite::gpu::Tensor<OHWDI, T>& weights,
 const ::tflite::gpu::Tensor<Linear, T>& biases,
 CLContext* context);
 template <DataType T>
-Status UploadWeights(const ::tflite::gpu::Tensor<OHWDI, T>& weights,
+absl::Status UploadWeights(const ::tflite::gpu::Tensor<OHWDI, T>& weights,
 CLContext* context);

 template <DataType S, typename T>
 void RearrangeWeightsData(const ::tflite::gpu::Tensor<OHWDI, S>& weights,
 absl::Span<T> dst);

-friend Status CreateConv3D(const CreationContext& creation_context,
+friend absl::Status CreateConv3D(const CreationContext& creation_context,
 const OperationDef& definition,
 const Convolution3DAttributes& attr,
 Conv3D* result);

 friend std::string GenerateConv3D(
 const OperationDef& op_def, const LinearStorage& biases,
@@ -105,7 +105,7 @@ class Conv3D : public GPUOperation {
 int dst_slices, bool x_kernel_is_1,
 bool y_kernel_is_1, bool z_kernel_is_1) const;

-Status BindArguments();
+absl::Status BindArguments();
 int3 GetGridSize() const;

 Texture2D weights_0_;
@@ -125,9 +125,9 @@ class Conv3D : public GPUOperation {
 };

 template <DataType T>
-Status Conv3D::UploadData(const ::tflite::gpu::Tensor<OHWDI, T>& weights,
+absl::Status Conv3D::UploadData(const ::tflite::gpu::Tensor<OHWDI, T>& weights,
 const ::tflite::gpu::Tensor<Linear, T>& biases,
 CLContext* context) {
 RETURN_IF_ERROR(UploadWeights(weights, context));
 LinearStorageCreateInfo create_info;
 create_info.storage_type = conv_params_.AreWeightsBuffer()
@@ -139,12 +139,12 @@ Status Conv3D::UploadData(const ::tflite::gpu::Tensor<OHWDI, T>& weights,
 create_info.name = "biases";
 create_info.aligned_size = weights.shape.o;
 RETURN_IF_ERROR(CreateLinearStorage(create_info, biases, context, &biases_));
-return OkStatus();
+return absl::OkStatus();
 }

 template <DataType T>
-Status Conv3D::UploadWeights(const ::tflite::gpu::Tensor<OHWDI, T>& weights,
-CLContext* context) {
+absl::Status Conv3D::UploadWeights(
+const ::tflite::gpu::Tensor<OHWDI, T>& weights, CLContext* context) {
 const int block_size = conv_params_.block_size.w;
 const int dst_slices =
 AlignByN(IntegralDivideRoundUp(weights.shape.o, 4), block_size);
@@ -211,7 +211,7 @@ Status Conv3D::UploadWeights(const ::tflite::gpu::Tensor<OHWDI, T>& weights,
 }
 }

-return OkStatus();
+return absl::OkStatus();
 }

 template <DataType S, typename T>
@@ -271,9 +271,9 @@ void Conv3D::RearrangeWeightsData(
 }
 }

-Status CreateConv3D(const CreationContext& creation_context,
+absl::Status CreateConv3D(const CreationContext& creation_context,
 const OperationDef& definition,
 const Convolution3DAttributes& attr, Conv3D* result);

 }  // namespace cl
 }  // namespace gpu
@@ -291,16 +291,16 @@ ConvBuffer1x1& ConvBuffer1x1::operator=(ConvBuffer1x1&& operation) {
 return *this;
 }

-Status ConvBuffer1x1::Compile(const CreationContext& creation_context) {
+absl::Status ConvBuffer1x1::Compile(const CreationContext& creation_context) {
 std::string code =
 GenerateConvBuffer1x1(definition_, conv_params_, linked_operations_);
 RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
 code, "main_function", *creation_context.context,
 *creation_context.device, &kernel_));
-return OkStatus();
+return absl::OkStatus();
 }

-Status ConvBuffer1x1::BindArguments() {
+absl::Status ConvBuffer1x1::BindArguments() {
 kernel_.ResetBindingCounter();
 RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
 RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
@@ -313,7 +313,7 @@ Status ConvBuffer1x1::BindArguments() {
 src_width_elements * src_[0]->Height());
 RETURN_IF_ERROR(kernel_.SetBytesAuto(src_size));
 RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB()));
-return OkStatus();
+return absl::OkStatus();
 }

 int3 ConvBuffer1x1::GetGridSize() const {
@@ -328,13 +328,13 @@ int3 ConvBuffer1x1::GetGridSize() const {
 return int3(grid_x, grid_y, grid_z);
 }

-Status ConvBuffer1x1::Tune(const TuningParameters& params) {
+absl::Status ConvBuffer1x1::Tune(const TuningParameters& params) {
 RETURN_IF_ERROR(BindArguments());
 return GetBestWorkGroupConv(params, kernel_, GetGridSize(),
 &conv_params_.work_group_size);
 }

-Status ConvBuffer1x1::AddToQueue(CLCommandQueue* queue) {
+absl::Status ConvBuffer1x1::AddToQueue(CLCommandQueue* queue) {
 RETURN_IF_ERROR(BindArguments());
 return queue->DispatchImplicit(kernel_, GetGridSize(),
 conv_params_.work_group_size);
@@ -351,12 +351,12 @@ bool IsConvBuffer1x1Supported(const OperationDef& definition,
 attr.padding.appended.w == 0 && attr.padding.appended.h == 0;
 }

-Status CreateConvBuffer1x1(const CreationContext& creation_context,
+absl::Status CreateConvBuffer1x1(const CreationContext& creation_context,
 const OperationDef& definition,
 const Convolution2DAttributes& attr,
 ConvBuffer1x1* result, const BHWC* shape) {
 if (!IsConvBuffer1x1Supported(definition, attr)) {
-return InvalidArgumentError("ConvBuffer1x1 doesn't supported");
+return absl::InvalidArgumentError("ConvBuffer1x1 doesn't supported");
 }
 const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
 const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
@@ -372,10 +372,10 @@ Status CreateConvBuffer1x1(const CreationContext& creation_context,
 return result->UploadData(attr.weights, attr.bias, creation_context.context);
 }

-Status CreateConvBuffer1x1(const CreationContext& creation_context,
+absl::Status CreateConvBuffer1x1(const CreationContext& creation_context,
 const OperationDef& definition,
 const FullyConnectedAttributes& attr,
 ConvBuffer1x1* result, const BHWC* shape) {
 const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
 const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
 ConvBuffer1x1::ConvParams conv_params;
@@ -392,11 +392,10 @@ Status CreateConvBuffer1x1(const CreationContext& creation_context,
 return result->UploadData(attr.weights, attr.bias, creation_context.context);
 }

-Status CreateConvBuffer1x1Wino4x4To6x6(const CreationContext& creation_context,
-const OperationDef& definition,
-const Convolution2DAttributes& attr,
-ConvBuffer1x1* result,
-const BHWC* shape) {
+absl::Status CreateConvBuffer1x1Wino4x4To6x6(
+const CreationContext& creation_context, const OperationDef& definition,
+const Convolution2DAttributes& attr, ConvBuffer1x1* result,
+const BHWC* shape) {
 const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
 const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
 ConvBuffer1x1::ConvParams conv_params;
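
Unlike the other factories here, CreateConvBuffer1x1 can reject a configuration, and after this change it does so with a typed canonical error rather than a bespoke Status. Callers can therefore branch on the code instead of parsing the message; a sketch (the fallback path is a hypothetical selection strategy, not in-tree logic):

#include "absl/status/status.h"

// Sketch: dispatching on the canonical error code from the factory above.
// absl::IsInvalidArgument(s) is equivalent to
// s.code() == absl::StatusCode::kInvalidArgument.
absl::Status PickConvKernel(const CreationContext& creation_context,
                            const OperationDef& definition,
                            const Convolution2DAttributes& attr,
                            ConvBuffer1x1* result) {
  const absl::Status status =
      CreateConvBuffer1x1(creation_context, definition, attr, result);
  if (absl::IsInvalidArgument(status)) {
    // Shape/padding unsupported by ConvBuffer1x1: a real selector would
    // fall back to another convolution kernel here.
  }
  return status;
}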
@ -45,10 +45,10 @@ class ConvBuffer1x1 : public GPUOperation {
|
|||||||
ConvBuffer1x1(const ConvBuffer1x1&) = delete;
|
ConvBuffer1x1(const ConvBuffer1x1&) = delete;
|
||||||
ConvBuffer1x1& operator=(const ConvBuffer1x1&) = delete;
|
ConvBuffer1x1& operator=(const ConvBuffer1x1&) = delete;
|
||||||
|
|
||||||
Status AddToQueue(CLCommandQueue* queue) override;
|
absl::Status AddToQueue(CLCommandQueue* queue) override;
|
||||||
Status Tune(const TuningParameters& params) override;
|
absl::Status Tune(const TuningParameters& params) override;
|
||||||
|
|
||||||
Status Compile(const CreationContext& creation_context) override;
|
absl::Status Compile(const CreationContext& creation_context) override;
|
||||||
|
|
||||||
struct ConvParams {
|
struct ConvParams {
|
||||||
int3 block_size = int3(1, 1, 1);
|
int3 block_size = int3(1, 1, 1);
|
||||||
@ -64,33 +64,33 @@ class ConvBuffer1x1 : public GPUOperation {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
ConvBuffer1x1(const OperationDef& definition, const ConvParams& conv_params);
|
ConvBuffer1x1(const OperationDef& definition, const ConvParams& conv_params);
|
||||||
friend Status CreateConvBuffer1x1(const CreationContext& creation_context,
|
friend absl::Status CreateConvBuffer1x1(
|
||||||
const OperationDef& definition,
|
const CreationContext& creation_context, const OperationDef& definition,
|
||||||
const Convolution2DAttributes& attr,
|
const Convolution2DAttributes& attr, ConvBuffer1x1* result,
|
||||||
ConvBuffer1x1* result, const BHWC* shape);
|
const BHWC* shape);
|
||||||
friend Status CreateConvBuffer1x1(const CreationContext& creation_context,
|
friend absl::Status CreateConvBuffer1x1(
|
||||||
const OperationDef& definition,
|
const CreationContext& creation_context, const OperationDef& definition,
|
||||||
const FullyConnectedAttributes& attr,
|
const FullyConnectedAttributes& attr, ConvBuffer1x1* result,
|
||||||
ConvBuffer1x1* result, const BHWC* shape);
|
const BHWC* shape);
|
||||||
friend Status CreateConvBuffer1x1Wino4x4To6x6(
|
friend absl::Status CreateConvBuffer1x1Wino4x4To6x6(
|
||||||
const CreationContext& creation_context, const OperationDef& definition,
|
const CreationContext& creation_context, const OperationDef& definition,
|
||||||
const Convolution2DAttributes& attr, ConvBuffer1x1* result,
|
const Convolution2DAttributes& attr, ConvBuffer1x1* result,
|
||||||
const BHWC* shape);
|
const BHWC* shape);
|
||||||
|
|
||||||
template <DataType T>
|
template <DataType T>
|
||||||
Status UploadData(const ::tflite::gpu::Tensor<OHWI, T>& weights,
|
absl::Status UploadData(const ::tflite::gpu::Tensor<OHWI, T>& weights,
|
||||||
const ::tflite::gpu::Tensor<Linear, T>& biases,
|
const ::tflite::gpu::Tensor<Linear, T>& biases,
|
||||||
CLContext* context);
|
CLContext* context);
|
||||||
template <DataType T>
|
template <DataType T>
|
||||||
Status UploadDataForWinograd4x4To6x6(
|
absl::Status UploadDataForWinograd4x4To6x6(
|
||||||
const ::tflite::gpu::Tensor<OHWI, T>& weights, const CLDevice& device,
|
const ::tflite::gpu::Tensor<OHWI, T>& weights, const CLDevice& device,
|
||||||
CLContext* context);
|
CLContext* context);
|
||||||
|
|
||||||
template <DataType T>
|
template <DataType T>
|
||||||
Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
|
absl::Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
|
||||||
CLContext* context);
|
CLContext* context);
|
||||||
|
|
||||||
Status BindArguments();
|
absl::Status BindArguments();
|
||||||
int3 GetGridSize() const;
|
int3 GetGridSize() const;
|
||||||
|
|
||||||
Buffer weights_;
|
Buffer weights_;
|
||||||
@ -101,20 +101,20 @@ class ConvBuffer1x1 : public GPUOperation {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <DataType T>
|
template <DataType T>
|
||||||
Status ConvBuffer1x1::UploadData(const ::tflite::gpu::Tensor<OHWI, T>& weights,
|
absl::Status ConvBuffer1x1::UploadData(
|
||||||
const ::tflite::gpu::Tensor<Linear, T>& biases,
|
const ::tflite::gpu::Tensor<OHWI, T>& weights,
|
||||||
CLContext* context) {
|
const ::tflite::gpu::Tensor<Linear, T>& biases, CLContext* context) {
|
||||||
RETURN_IF_ERROR(UploadWeights(weights, context));
|
RETURN_IF_ERROR(UploadWeights(weights, context));
|
||||||
LinearStorageCreateInfo create_info;
|
LinearStorageCreateInfo create_info;
|
||||||
create_info.storage_type = LinearStorageType::BUFFER;
|
create_info.storage_type = LinearStorageType::BUFFER;
|
||||||
create_info.data_type = definition_.GetDataType();
|
create_info.data_type = definition_.GetDataType();
|
||||||
create_info.aligned_size = weights.shape.o;
|
create_info.aligned_size = weights.shape.o;
|
||||||
RETURN_IF_ERROR(CreateLinearStorage(create_info, biases, context, &biases_));
|
RETURN_IF_ERROR(CreateLinearStorage(create_info, biases, context, &biases_));
|
||||||
return OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <DataType T>
|
template <DataType T>
|
||||||
Status ConvBuffer1x1::UploadDataForWinograd4x4To6x6(
|
absl::Status ConvBuffer1x1::UploadDataForWinograd4x4To6x6(
|
||||||
const ::tflite::gpu::Tensor<OHWI, T>& weights, const CLDevice& device,
|
const ::tflite::gpu::Tensor<OHWI, T>& weights, const CLDevice& device,
|
||||||
CLContext* context) {
|
CLContext* context) {
|
||||||
::tflite::gpu::Tensor<OHWI, T> wino_weights;
|
::tflite::gpu::Tensor<OHWI, T> wino_weights;
|
||||||
@ -132,7 +132,7 @@ Status ConvBuffer1x1::UploadDataForWinograd4x4To6x6(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <DataType T>
|
template <DataType T>
|
||||||
Status ConvBuffer1x1::UploadWeights(
|
absl::Status ConvBuffer1x1::UploadWeights(
|
||||||
const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
|
const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
|
||||||
const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4);
|
const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4);
|
||||||
const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
|
const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
|
||||||
@ -162,21 +162,22 @@ Status ConvBuffer1x1::UploadWeights(
|
|||||||
bool IsConvBuffer1x1Supported(const OperationDef& definition,
|
bool IsConvBuffer1x1Supported(const OperationDef& definition,
|
||||||
const Convolution2DAttributes& attr);
|
const Convolution2DAttributes& attr);
|
||||||
|
|
||||||
Status CreateConvBuffer1x1(const CreationContext& creation_context,
|
absl::Status CreateConvBuffer1x1(const CreationContext& creation_context,
|
||||||
const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
const Convolution2DAttributes& attr,
|
const Convolution2DAttributes& attr,
|
||||||
ConvBuffer1x1* result, const BHWC* shape = nullptr);
|
ConvBuffer1x1* result,
|
||||||
|
const BHWC* shape = nullptr);
|
||||||
|
|
||||||
Status CreateConvBuffer1x1(const CreationContext& creation_context,
|
absl::Status CreateConvBuffer1x1(const CreationContext& creation_context,
|
||||||
const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
const FullyConnectedAttributes& attr,
|
const FullyConnectedAttributes& attr,
|
||||||
ConvBuffer1x1* result, const BHWC* shape = nullptr);
|
ConvBuffer1x1* result,
|
||||||
|
const BHWC* shape = nullptr);
|
||||||
|
|
||||||
Status CreateConvBuffer1x1Wino4x4To6x6(const CreationContext& creation_context,
|
absl::Status CreateConvBuffer1x1Wino4x4To6x6(
|
||||||
const OperationDef& definition,
|
const CreationContext& creation_context, const OperationDef& definition,
|
||||||
const Convolution2DAttributes& attr,
|
const Convolution2DAttributes& attr, ConvBuffer1x1* result,
|
||||||
ConvBuffer1x1* result,
|
const BHWC* shape = nullptr);
|
||||||
const BHWC* shape = nullptr);
|
|
||||||
|
|
||||||
} // namespace cl
|
} // namespace cl
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
@@ -219,7 +219,7 @@ ConvConstants& ConvConstants::operator=(ConvConstants&& kernel) {
   return *this;
 }
 
-Status ConvConstants::Compile(const CreationContext& creation_context) {
+absl::Status ConvConstants::Compile(const CreationContext& creation_context) {
   const bool stride_correction =
       definition_.IsBatchSupported() && stride_.x != 1;
   const auto code = GenerateConvolutionConstantCode(
@@ -240,7 +240,7 @@ Status ConvConstants::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }
 
-Status ConvConstants::BindArguments() {
+absl::Status ConvConstants::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
@@ -254,7 +254,7 @@ Status ConvConstants::BindArguments() {
       kernel_.SetBytesAuto(int2(dilation_.x * src_[0]->Batch(), dilation_.y)));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB()));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 int3 ConvConstants::GetGridSize() const {
@@ -263,12 +263,12 @@ int3 ConvConstants::GetGridSize() const {
   return int3(grid_x, grid_y, 1);
 }
 
-Status ConvConstants::Tune(const TuningParameters& params) {
+absl::Status ConvConstants::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }
 
-Status ConvConstants::AddToQueue(CLCommandQueue* queue) {
+absl::Status ConvConstants::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -294,12 +294,12 @@ bool IsConvConstantsSupported(const CLDevice& device,
   return filters_buffer_size <= kConstantMaxSize && flt4_registers <= 8;
 }
 
-Status CreateConvConstants(const CreationContext& creation_context,
+absl::Status CreateConvConstants(const CreationContext& creation_context,
                            const OperationDef& definition,
                            const Convolution2DAttributes& attr,
                            ConvConstants* result) {
   if (!IsConvConstantsSupported(*creation_context.device, definition, attr)) {
-    return InvalidArgumentError("ConvConstants doesn't supported");
+    return absl::InvalidArgumentError("ConvConstants doesn't supported");
   }
   *result = ConvConstants(definition, attr);
   RETURN_IF_ERROR(
@@ -310,8 +310,7 @@ Status CreateConvConstants(const CreationContext& creation_context,
   create_info.aligned_size = attr.weights.shape.o;
   RETURN_IF_ERROR(CreateLinearStorage(
       create_info, attr.bias, creation_context.context, &result->biases_));
-
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace cl
@@ -35,10 +35,10 @@ namespace cl {
 class ConvConstants : public GPUOperation {
  public:
   ConvConstants() = default;
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;
 
-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;
 
   // Move only
   ConvConstants(ConvConstants&& kernel);
@@ -47,10 +47,9 @@ class ConvConstants : public GPUOperation {
   ConvConstants& operator=(const ConvConstants&) = delete;
 
  private:
-  friend Status CreateConvConstants(const CreationContext& creation_context,
-                                    const OperationDef& definition,
-                                    const Convolution2DAttributes& attr,
-                                    ConvConstants* result);
+  friend absl::Status CreateConvConstants(
+      const CreationContext& creation_context, const OperationDef& definition,
+      const Convolution2DAttributes& attr, ConvConstants* result);
   explicit ConvConstants(const OperationDef& definition,
                          const Convolution2DAttributes& attr)
       : GPUOperation(definition),
@@ -62,14 +61,14 @@ class ConvConstants : public GPUOperation {
         dst_channels_(attr.weights.shape.o) {}
 
   template <DataType T>
-  Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
+  absl::Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
                        CLContext* context);
 
   template <DataType S, typename T>
   void RearrangeWeightsData(const ::tflite::gpu::Tensor<OHWI, S>& weights,
                             absl::Span<T> dst);
 
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;
 
   Buffer weights_;
@@ -87,7 +86,7 @@ class ConvConstants : public GPUOperation {
 };
 
 template <DataType T>
-Status ConvConstants::UploadWeights(
+absl::Status ConvConstants::UploadWeights(
     const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
   const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4);
   const int kernel_x = weights.shape.w;
@@ -157,10 +156,10 @@ bool IsConvConstantsSupported(const CLDevice& device,
                               const OperationDef& definition,
                               const Convolution2DAttributes& attr);
 
-Status CreateConvConstants(const CreationContext& creation_context,
+absl::Status CreateConvConstants(const CreationContext& creation_context,
                            const OperationDef& definition,
                            const Convolution2DAttributes& attr,
                            ConvConstants* result);
 
 }  // namespace cl
 }  // namespace gpu
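The error paths change the same way: the bare canonical-code constructors (`InvalidArgumentError`, `InternalError`, `UnimplementedError`) become their `absl::`-qualified counterparts. A small sketch of how such a status is produced and consumed; the function name here is illustrative only, not from the commit:

#include <iostream>

#include "absl/status/status.h"

// Illustrative only: returns an error status carrying a canonical code and a
// message, mirroring the CreateConvConstants error path above.
absl::Status CheckSupported(bool supported) {
  if (!supported) {
    return absl::InvalidArgumentError("operation is not supported");
  }
  return absl::OkStatus();
}

int main() {
  const absl::Status status = CheckSupported(false);
  if (!status.ok()) {
    // code() yields absl::StatusCode::kInvalidArgument; message() the text.
    std::cout << status.ToString() << "\n";
  }
}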
|
@@ -173,7 +173,7 @@ ConvPowerVR& ConvPowerVR::operator=(ConvPowerVR&& operation) {
   return *this;
 }
 
-Status ConvPowerVR::Compile(const CreationContext& creation_context) {
+absl::Status ConvPowerVR::Compile(const CreationContext& creation_context) {
   const bool stride_correction =
       definition_.IsBatchSupported() && stride_padding_.x != 1;
   const std::string code =
@@ -189,7 +189,7 @@ Status ConvPowerVR::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }
 
-Status ConvPowerVR::BindArguments() {
+absl::Status ConvPowerVR::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
@@ -211,7 +211,7 @@ Status ConvPowerVR::BindArguments() {
   }
   RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB()));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 int3 ConvPowerVR::GetGridSize() const {
@@ -245,13 +245,13 @@ int3 ConvPowerVR::GetGridSize() const {
   }
 }
 
-Status ConvPowerVR::Tune(const TuningParameters& params) {
+absl::Status ConvPowerVR::Tune(const TuningParameters& params) {
   if (conv_params_.weights_upload_type ==
           WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP ||
       conv_params_.weights_upload_type ==
          WeightsUploadType::LOCAL_MEM_BY_THREADS ||
      conv_params_.fixed_work_group_size) {
-    return OkStatus();
+    return absl::OkStatus();
   }
   if (conv_params_.work_group_launch_order[0] == 0 &&
       conv_params_.work_group_launch_order[1] == 1 &&
@@ -260,10 +260,10 @@ Status ConvPowerVR::Tune(const TuningParameters& params) {
     return GetBestWorkGroupConv(params, kernel_, GetGridSize(),
                                 &conv_params_.work_group_size);
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status ConvPowerVR::AddToQueue(CLCommandQueue* queue) {
+absl::Status ConvPowerVR::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(),
                                  conv_params_.work_group_size);
@@ -848,27 +848,26 @@ ConvPowerVR::ConvParams ConvPowerVR::GuessBestParamsWinograd(
   return params;
 }
 
-Status CreateConvPowerVR(const CreationContext& creation_context,
+absl::Status CreateConvPowerVR(const CreationContext& creation_context,
                          const OperationDef& definition,
                          const Convolution2DAttributes& attr,
                          ConvPowerVR* result, const BHWC* dst_shape) {
   *result = ConvPowerVR(definition, attr, *creation_context.device, dst_shape);
   return result->UploadData(attr.weights, attr.bias, creation_context.context);
 }
 
-Status CreateConvPowerVR(const CreationContext& creation_context,
+absl::Status CreateConvPowerVR(const CreationContext& creation_context,
                          const OperationDef& definition,
                          const FullyConnectedAttributes& attr,
                          ConvPowerVR* result, const BHWC* dst_shape) {
   *result = ConvPowerVR(definition, attr, *creation_context.device, dst_shape);
   return result->UploadData(attr.weights, attr.bias, creation_context.context);
 }
 
-Status CreateConvPowerVRWino4x4To6x6(const CreationContext& creation_context,
-                                     const OperationDef& definition,
-                                     const Convolution2DAttributes& attr,
-                                     ConvPowerVR* result,
-                                     const BHWC* dst_shape) {
+absl::Status CreateConvPowerVRWino4x4To6x6(
+    const CreationContext& creation_context, const OperationDef& definition,
+    const Convolution2DAttributes& attr, ConvPowerVR* result,
+    const BHWC* dst_shape) {
   *result = ConvPowerVR(definition);
   result->conv_params_ = result->GuessBestParamsWinograd(
       *creation_context.device, definition, attr, dst_shape);
@@ -39,9 +39,9 @@ namespace cl {
 class ConvPowerVR : public GPUOperation {
  public:
   ConvPowerVR() = default;
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;
+  absl::Status Compile(const CreationContext& creation_context) override;
 
   // Move only
   ConvPowerVR(ConvPowerVR&& operation);
@@ -87,29 +87,31 @@ class ConvPowerVR : public GPUOperation {
   explicit ConvPowerVR(const OperationDef& definition);
 
   template <DataType T>
-  Status UploadData(const ::tflite::gpu::Tensor<OHWI, T>& weights,
+  absl::Status UploadData(const ::tflite::gpu::Tensor<OHWI, T>& weights,
                     const ::tflite::gpu::Tensor<Linear, T>& biases,
                     CLContext* context);
   template <DataType T>
-  Status UploadDataForWinograd4x4To6x6(
+  absl::Status UploadDataForWinograd4x4To6x6(
       const ::tflite::gpu::Tensor<OHWI, T>& weights, const CLDevice& device,
       CLContext* context);
 
   template <DataType T>
-  Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
+  absl::Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
                        CLContext* context);
 
-  friend Status CreateConvPowerVR(const CreationContext& creation_context,
+  friend absl::Status CreateConvPowerVR(const CreationContext& creation_context,
                                   const OperationDef& definition,
                                   const Convolution2DAttributes& attr,
-                                  ConvPowerVR* result, const BHWC* dst_shape);
+                                  ConvPowerVR* result,
+                                  const BHWC* dst_shape);
 
-  friend Status CreateConvPowerVR(const CreationContext& creation_context,
+  friend absl::Status CreateConvPowerVR(const CreationContext& creation_context,
                                   const OperationDef& definition,
                                   const FullyConnectedAttributes& attr,
-                                  ConvPowerVR* result, const BHWC* dst_shape);
+                                  ConvPowerVR* result,
+                                  const BHWC* dst_shape);
 
-  friend Status CreateConvPowerVRWino4x4To6x6(
+  friend absl::Status CreateConvPowerVRWino4x4To6x6(
       const CreationContext& creation_context, const OperationDef& definition,
       const Convolution2DAttributes& attr, ConvPowerVR* result,
       const BHWC* dst_shape);
@@ -138,7 +140,7 @@ class ConvPowerVR : public GPUOperation {
                                 bool different_weights_for_height,
                                 const BHWC* dst_shape = nullptr) const;
 
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;
 
   Buffer weights_;
@@ -152,20 +154,20 @@ class ConvPowerVR : public GPUOperation {
 };
 
 template <DataType T>
-Status ConvPowerVR::UploadData(const ::tflite::gpu::Tensor<OHWI, T>& weights,
-                               const ::tflite::gpu::Tensor<Linear, T>& biases,
-                               CLContext* context) {
+absl::Status ConvPowerVR::UploadData(
+    const ::tflite::gpu::Tensor<OHWI, T>& weights,
+    const ::tflite::gpu::Tensor<Linear, T>& biases, CLContext* context) {
   RETURN_IF_ERROR(UploadWeights(weights, context));
   LinearStorageCreateInfo create_info;
   create_info.storage_type = LinearStorageType::BUFFER;
   create_info.data_type = conv_params_.weights_data_type;
   create_info.aligned_size = weights.shape.o;
   RETURN_IF_ERROR(CreateLinearStorage(create_info, biases, context, &biases_));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 template <DataType T>
-Status ConvPowerVR::UploadDataForWinograd4x4To6x6(
+absl::Status ConvPowerVR::UploadDataForWinograd4x4To6x6(
     const ::tflite::gpu::Tensor<OHWI, T>& weights, const CLDevice& device,
     CLContext* context) {
   ::tflite::gpu::Tensor<OHWI, T> wino_weights;
@@ -179,12 +181,12 @@ Status ConvPowerVR::UploadDataForWinograd4x4To6x6(
   bias.shape = Linear(weights.shape.o);
   bias.data.resize(weights.shape.o, 0.0f);
   RETURN_IF_ERROR(CreateLinearStorage(create_info, bias, context, &biases_));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 template <DataType T>
-Status ConvPowerVR::UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
-                                  CLContext* context) {
+absl::Status ConvPowerVR::UploadWeights(
+    const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
   const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4);
   const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
 
@@ -210,21 +212,22 @@ Status ConvPowerVR::UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
   }
 }
 
-Status CreateConvPowerVR(const CreationContext& creation_context,
+absl::Status CreateConvPowerVR(const CreationContext& creation_context,
                          const OperationDef& definition,
                          const Convolution2DAttributes& attr,
-                         ConvPowerVR* result, const BHWC* dst_shape = nullptr);
+                         ConvPowerVR* result,
+                         const BHWC* dst_shape = nullptr);
 
-Status CreateConvPowerVR(const CreationContext& creation_context,
+absl::Status CreateConvPowerVR(const CreationContext& creation_context,
                          const OperationDef& definition,
                          const FullyConnectedAttributes& attr,
-                         ConvPowerVR* result, const BHWC* dst_shape = nullptr);
+                         ConvPowerVR* result,
+                         const BHWC* dst_shape = nullptr);
 
-Status CreateConvPowerVRWino4x4To6x6(const CreationContext& creation_context,
-                                     const OperationDef& definition,
-                                     const Convolution2DAttributes& attr,
-                                     ConvPowerVR* result,
-                                     const BHWC* dst_shape = nullptr);
+absl::Status CreateConvPowerVRWino4x4To6x6(
+    const CreationContext& creation_context, const OperationDef& definition,
+    const Convolution2DAttributes& attr, ConvPowerVR* result,
+    const BHWC* dst_shape = nullptr);
 
 }  // namespace cl
 }  // namespace gpu
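Note that `UploadData`, `UploadDataForWinograd4x4To6x6`, and `UploadWeights` are templated on `DataType`, so both their in-class declarations and their out-of-class definitions live in the header, and both must move to `absl::Status` in the same edit. A reduced sketch of that shape, with illustrative names:

#include "absl/status/status.h"

// Reduced sketch: a templated member is declared and defined in the header,
// so the return type has to change in both places at once.
template <typename T>
class Uploader {
 public:
  absl::Status Upload(const T& value);  // declaration
};

template <typename T>
absl::Status Uploader<T>::Upload(const T& value) {  // definition
  (void)value;  // the actual upload is elided in this sketch
  return absl::OkStatus();
}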
|
@@ -30,6 +30,7 @@ namespace tflite {
 namespace gpu {
 namespace cl {
 namespace {
 
 std::string GenerateConvCode(
     const OperationDef& op_def, const int3& block_size, bool is1x1,
     bool adreno4xx_optimization, bool stride_correction,
@@ -384,7 +385,7 @@ ConvTexture& ConvTexture::operator=(ConvTexture&& operation) {
   return *this;
 }
 
-Status ConvTexture::Compile(const CreationContext& creation_context) {
+absl::Status ConvTexture::Compile(const CreationContext& creation_context) {
   auto storage_type = definition_.GetPrimaryStorageType();
   bool is1x1 = kernel_size_.x == 1 && kernel_size_.y == 1;
   bool adreno4xx_optimization =
@@ -407,7 +408,7 @@ Status ConvTexture::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }
 
-Status ConvTexture::BindArguments() {
+absl::Status ConvTexture::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_0_.GetMemoryPtr()));
@@ -427,7 +428,7 @@ Status ConvTexture::BindArguments() {
   RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_));
   RETURN_IF_ERROR(
       kernel_.SetBytesAuto(int2(padding_.x * src_[0]->Batch(), padding_.y)));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 int3 ConvTexture::GetGridSize() const {
@@ -438,37 +439,36 @@ int3 ConvTexture::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }
 
-Status ConvTexture::Tune(const TuningParameters& params) {
+absl::Status ConvTexture::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroupConv(params, kernel_, GetGridSize(),
                               &work_group_size_);
 }
 
-Status ConvTexture::AddToQueue(CLCommandQueue* queue) {
+absl::Status ConvTexture::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
 
-Status CreateConvTexture(const CreationContext& creation_context,
+absl::Status CreateConvTexture(const CreationContext& creation_context,
                          const OperationDef& definition,
                          const Convolution2DAttributes& attr,
                          ConvTexture* result) {
   *result = ConvTexture(definition, attr);
   return result->UploadData(attr.weights, attr.bias, creation_context.context);
 }
 
-Status CreateConvTexture(const CreationContext& creation_context,
+absl::Status CreateConvTexture(const CreationContext& creation_context,
                          const OperationDef& definition,
                          const FullyConnectedAttributes& attr,
                          ConvTexture* result) {
   *result = ConvTexture(definition);
   return result->UploadData(attr.weights, attr.bias, creation_context.context);
 }
 
-Status CreateConvTextureWino4x4To6x6(const CreationContext& creation_context,
-                                     const OperationDef& definition,
-                                     const Convolution2DAttributes& attr,
-                                     ConvTexture* result) {
+absl::Status CreateConvTextureWino4x4To6x6(
+    const CreationContext& creation_context, const OperationDef& definition,
+    const Convolution2DAttributes& attr, ConvTexture* result) {
   *result = ConvTexture(definition);
   result->different_weights_for_height_ = true;
   result->block_size_ = {4, 1, 2};
@@ -41,10 +41,10 @@ namespace cl {
 class ConvTexture : public GPUOperation {
  public:
   ConvTexture() = default;
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;
 
-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;
 
   // Move only
   ConvTexture(ConvTexture&& operation);
@@ -53,16 +53,16 @@ class ConvTexture : public GPUOperation {
   ConvTexture& operator=(const ConvTexture&) = delete;
 
  private:
-  friend Status CreateConvTexture(const CreationContext& creation_context,
+  friend absl::Status CreateConvTexture(const CreationContext& creation_context,
                                   const OperationDef& definition,
                                   const Convolution2DAttributes& attr,
                                   ConvTexture* result);
-  friend Status CreateConvTexture(const CreationContext& creation_context,
+  friend absl::Status CreateConvTexture(const CreationContext& creation_context,
                                   const OperationDef& definition,
                                   const FullyConnectedAttributes& attr,
                                   ConvTexture* result);
 
-  friend Status CreateConvTextureWino4x4To6x6(
+  friend absl::Status CreateConvTextureWino4x4To6x6(
      const CreationContext& creation_context, const OperationDef& definition,
      const Convolution2DAttributes& attr, ConvTexture* result);
 
@@ -70,25 +70,25 @@ class ConvTexture : public GPUOperation {
                const Convolution2DAttributes& attr);
   explicit ConvTexture(const OperationDef& definition);
   template <DataType T>
-  Status UploadData(const ::tflite::gpu::Tensor<OHWI, T>& weights,
+  absl::Status UploadData(const ::tflite::gpu::Tensor<OHWI, T>& weights,
                     const ::tflite::gpu::Tensor<Linear, T>& biases,
                     CLContext* context);
 
   template <DataType T>
-  Status UploadDataForWinograd4x4To6x6(
+  absl::Status UploadDataForWinograd4x4To6x6(
      const ::tflite::gpu::Tensor<OHWI, T>& weights, const CLDevice& device,
      CLContext* context);
 
   template <DataType T>
-  Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
+  absl::Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
                        CLContext* context);
 
   template <DataType S, typename T>
   void RearrangeWeightsData(const ::tflite::gpu::Tensor<OHWI, S>& weights,
                             absl::Span<T> dst_0, absl::Span<T> dst_1,
                             absl::Span<T> dst_2, absl::Span<T> dst_3);
 
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;
 
   Texture2D weights_0_;
@@ -114,20 +114,20 @@ class ConvTexture : public GPUOperation {
 };
 
 template <DataType T>
-Status ConvTexture::UploadData(const ::tflite::gpu::Tensor<OHWI, T>& weights,
-                               const ::tflite::gpu::Tensor<Linear, T>& biases,
-                               CLContext* context) {
+absl::Status ConvTexture::UploadData(
+    const ::tflite::gpu::Tensor<OHWI, T>& weights,
+    const ::tflite::gpu::Tensor<Linear, T>& biases, CLContext* context) {
   RETURN_IF_ERROR(UploadWeights(weights, context));
   LinearStorageCreateInfo create_info;
   create_info.storage_type = LinearStorageType::TEXTURE_2D;
   create_info.data_type = definition_.GetDataType();
   create_info.aligned_size = weights.shape.o;
   RETURN_IF_ERROR(CreateLinearStorage(create_info, biases, context, &biases_));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 template <DataType T>
-Status ConvTexture::UploadDataForWinograd4x4To6x6(
+absl::Status ConvTexture::UploadDataForWinograd4x4To6x6(
     const ::tflite::gpu::Tensor<OHWI, T>& weights, const CLDevice& device,
     CLContext* context) {
   ::tflite::gpu::Tensor<OHWI, T> wino_weights;
@@ -145,8 +145,8 @@ Status ConvTexture::UploadDataForWinograd4x4To6x6(
 }
 
 template <DataType T>
-Status ConvTexture::UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
-                                  CLContext* context) {
+absl::Status ConvTexture::UploadWeights(
+    const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
   int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4);
   dst_depth = AlignByN(dst_depth, block_size_.z);
   const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
@@ -246,20 +246,19 @@ void ConvTexture::RearrangeWeightsData(
   }
 }
 
-Status CreateConvTexture(const CreationContext& creation_context,
+absl::Status CreateConvTexture(const CreationContext& creation_context,
                          const OperationDef& definition,
                          const Convolution2DAttributes& attr,
                          ConvTexture* result);
 
-Status CreateConvTexture(const CreationContext& creation_context,
+absl::Status CreateConvTexture(const CreationContext& creation_context,
                          const OperationDef& definition,
                          const FullyConnectedAttributes& attr,
                          ConvTexture* result);
 
-Status CreateConvTextureWino4x4To6x6(const CreationContext& creation_context,
-                                     const OperationDef& definition,
-                                     const Convolution2DAttributes& attr,
-                                     ConvTexture* result);
+absl::Status CreateConvTextureWino4x4To6x6(
+    const CreationContext& creation_context, const OperationDef& definition,
+    const Convolution2DAttributes& attr, ConvTexture* result);
 
 }  // namespace cl
 }  // namespace gpu
|
@@ -35,12 +35,12 @@ namespace {
 
 class OpenClConverterImpl : public TensorObjectConverter {
  public:
-  virtual Status Init(const TensorObjectDef& input_def,
+  virtual absl::Status Init(const TensorObjectDef& input_def,
                       const TensorObjectDef& output_def,
                       Environment* environment) = 0;
 
  protected:
-  Status DispatchKernel(cl_mem input, cl_mem output) {
+  absl::Status DispatchKernel(cl_mem input, cl_mem output) {
     kernel_.ResetBindingCounter();
     RETURN_IF_ERROR(kernel_.SetMemoryAuto(input));
     RETURN_IF_ERROR(kernel_.SetMemoryAuto(output));
@@ -119,9 +119,9 @@ class FromTensorConverter : public OpenClConverterImpl {
     })");
   }
 
-  Status Init(const TensorObjectDef& input_def,
+  absl::Status Init(const TensorObjectDef& input_def,
               const TensorObjectDef& output_def,
               Environment* environment) final {
     auto params_kernel = output_def.object_def.data_layout == DataLayout::BHWC
                              ? GetToBhwcKernel(input_def, output_def)
                              : GetToDhwc4Kernel(input_def, output_def);
@@ -157,11 +157,12 @@ __kernel void from_tensor()" +
                            environment->device(), &kernel_);
   }
 
-  Status Convert(const TensorObject& input_obj,
+  absl::Status Convert(const TensorObject& input_obj,
                  const TensorObject& output_obj) override {
     auto output = absl::get_if<OpenClBuffer>(&output_obj);
     if (!output || !output->memobj) {
-      return InvalidArgumentError("Missing output in from_tensor converter");
+      return absl::InvalidArgumentError(
+          "Missing output in from_tensor converter");
     }
     auto input_texture = absl::get_if<OpenClTexture>(&input_obj);
     if (input_texture && input_texture->memobj) {
@@ -171,7 +172,7 @@ __kernel void from_tensor()" +
     if (input_buffer && input_buffer->memobj) {
       return DispatchKernel(input_buffer->memobj, output->memobj);
     }
-    return InvalidArgumentError("Missing input in from_tensor converter");
+    return absl::InvalidArgumentError("Missing input in from_tensor converter");
   }
 };
 
@@ -225,9 +226,9 @@ class ToTensorConverter : public OpenClConverterImpl {
   )");
   }
 
-  Status Init(const TensorObjectDef& input_def,
+  absl::Status Init(const TensorObjectDef& input_def,
               const TensorObjectDef& output_def,
               Environment* environment) final {
     auto params_kernel = input_def.object_def.data_layout == DataLayout::BHWC
                              ? GetFromBhwcKernel(input_def, output_def)
                              : GetFromDhwc4Kernel(input_def, output_def);
@@ -261,11 +262,11 @@ __kernel void to_tensor()" +
                  &kernel_);
   }
 
-  Status Convert(const TensorObject& input_obj,
+  absl::Status Convert(const TensorObject& input_obj,
                  const TensorObject& output_obj) override {
     auto input = absl::get_if<OpenClBuffer>(&input_obj);
     if (!input || !input->memobj) {
-      return InvalidArgumentError("Missing input in to_tensor converter");
+      return absl::InvalidArgumentError("Missing input in to_tensor converter");
     }
     auto output_texture = absl::get_if<OpenClTexture>(&output_obj);
     if (output_texture && output_texture->memobj) {
@@ -275,7 +276,7 @@ __kernel void to_tensor()" +
     if (output_buffer && output_buffer->memobj) {
       return DispatchKernel(input->memobj, output_buffer->memobj);
     }
-    return InvalidArgumentError("Missing input in to_tensor converter");
+    return absl::InvalidArgumentError("Missing input in to_tensor converter");
   }
 };
 
@@ -318,18 +319,18 @@ class TrivialCopier : public OpenClConverterImpl {
            input.data_layout == output.data_layout;
   }
 
-  Status Init(const TensorObjectDef& input_def,
+  absl::Status Init(const TensorObjectDef& input_def,
               const TensorObjectDef& output_def,
               Environment* environment) final {
     dims_ = input_def.dimensions;
     data_type_ = input_def.object_def.data_type;
     queue_ = environment->queue();
     region_ = CalculateTextureRegion(output_def);
-    return OkStatus();
+    return absl::OkStatus();
   }
 
-  Status Convert(const TensorObject& input_obj,
+  absl::Status Convert(const TensorObject& input_obj,
                  const TensorObject& output_obj) override {
     auto texture_input = absl::get_if<OpenClTexture>(&input_obj);
     auto texture_output = absl::get_if<OpenClTexture>(&output_obj);
     if (texture_input && texture_output) {
@@ -340,12 +341,12 @@ class TrivialCopier : public OpenClConverterImpl {
     if (buffer_input && buffer_output) {
       return Copy(*buffer_input, *buffer_output);
     }
-    return InternalError("Unexpected object");
+    return absl::InternalError("Unexpected object");
   }
 
-  Status Copy(const OpenClBuffer& input, const OpenClBuffer& output) {
+  absl::Status Copy(const OpenClBuffer& input, const OpenClBuffer& output) {
     if (input.memobj == output.memobj) {
-      return OkStatus();
+      return absl::OkStatus();
     }
     return GetOpenCLError(clEnqueueCopyBuffer(
         queue_->queue(), input.memobj, output.memobj, 0, 0,
@@ -353,9 +354,9 @@ class TrivialCopier : public OpenClConverterImpl {
         nullptr));
   }
 
-  Status Copy(const OpenClTexture& input, const OpenClTexture& output) {
+  absl::Status Copy(const OpenClTexture& input, const OpenClTexture& output) {
     if (input.memobj == output.memobj) {
-      return OkStatus();
+      return absl::OkStatus();
     }
     size_t origin[3] = {0, 0, 0};
     return GetOpenCLError(
@@ -380,18 +381,18 @@ class CpuCopier : public OpenClConverterImpl {
            IsOpenClTextureOrBuffer(input.object_type)));
   }
 
-  Status Init(const TensorObjectDef& input_def,
+  absl::Status Init(const TensorObjectDef& input_def,
               const TensorObjectDef& output_def,
               Environment* environment) final {
     region_ = CalculateTextureRegion(
         input_def.object_def.object_type == ObjectType::CPU_MEMORY ? output_def
                                                                    : input_def);
     queue_ = environment->queue();
-    return OkStatus();
+    return absl::OkStatus();
   }
 
-  Status Convert(const TensorObject& input_obj,
+  absl::Status Convert(const TensorObject& input_obj,
                  const TensorObject& output_obj) override {
     auto cpu_input = absl::get_if<CpuMemory>(&input_obj);
     auto cpu_output = absl::get_if<CpuMemory>(&output_obj);
     if (cpu_input) {
@@ -419,7 +420,7 @@ class CpuCopier : public OpenClConverterImpl {
           buffer_input->memobj, cpu_output->size_bytes, cpu_output->data);
       }
     }
-    return InternalError("Unexpected object");
+    return absl::InternalError("Unexpected object");
  }
 
  private:
@@ -442,7 +443,7 @@ class OpenClTensorConverterBuilder : public TensorObjectConverterBuilder {
            ToTensorConverter::IsSupported(input_def, output_def));
   }
 
-  Status MakeConverter(
+  absl::Status MakeConverter(
       const TensorObjectDef& input, const TensorObjectDef& output,
       std::unique_ptr<TensorObjectConverter>* converter) final {
     std::unique_ptr<OpenClConverterImpl> impl;
@@ -457,11 +458,11 @@ class OpenClTensorConverterBuilder : public TensorObjectConverterBuilder {
     } else if (ToTensorConverter::IsSupported(input_def, output_def)) {
       impl = absl::make_unique<ToTensorConverter>();
     } else {
-      return UnimplementedError("Unsupported conversion");
+      return absl::UnimplementedError("Unsupported conversion");
    }
     RETURN_IF_ERROR(impl->Init(input, output, environment_));
     *converter = std::move(impl);
-    return OkStatus();
+    return absl::OkStatus();
   }
 
   Environment* environment_;
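The converter classes show why this migration has to land atomically: `Init` and `Convert` are virtual, so the moment the base class returns `absl::Status`, every override must match exactly or `override` fails to compile. A minimal sketch of that constraint, with hypothetical class names:

#include "absl/status/status.h"

// Hypothetical names; only the override rule is the point here.
class ConverterBase {
 public:
  virtual ~ConverterBase() = default;
  virtual absl::Status Convert() = 0;
};

class BufferConverter : public ConverterBase {
 public:
  // Must return absl::Status too; returning the old Status alias would make
  // this an invalid override and a compile error.
  absl::Status Convert() override { return absl::OkStatus(); }
};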
|
@@ -368,7 +368,8 @@ ConvolutionTransposed& ConvolutionTransposed::operator=(
   return *this;
 }
 
-Status ConvolutionTransposed::Compile(const CreationContext& creation_context) {
+absl::Status ConvolutionTransposed::Compile(
+    const CreationContext& creation_context) {
   const auto code = GenerateConvolutionTransposedCode(
       definition_, biases_, *creation_context.device, weights_are_buffer_,
       block_size_, linked_operations_);
@@ -380,7 +381,7 @@ Status ConvolutionTransposed::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }
 
-Status ConvolutionTransposed::BindArguments() {
+absl::Status ConvolutionTransposed::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   if (weights_are_buffer_) {
@@ -399,7 +400,7 @@ Status ConvolutionTransposed::BindArguments() {
   RETURN_IF_ERROR(kernel_.SetBytesAuto(padding_));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 int3 ConvolutionTransposed::GetGridSize() const {
@@ -412,21 +413,21 @@ int3 ConvolutionTransposed::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }
 
-Status ConvolutionTransposed::Tune(const TuningParameters& params) {
+absl::Status ConvolutionTransposed::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroupConv(params, kernel_, GetGridSize(),
                               &work_group_size_);
 }
 
-Status ConvolutionTransposed::AddToQueue(CLCommandQueue* queue) {
+absl::Status ConvolutionTransposed::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
 
-Status CreateConvolutionTransposed(const CreationContext& creation_context,
-                                   const OperationDef& definition,
+absl::Status CreateConvolutionTransposed(
+    const CreationContext& creation_context, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr,
     ConvolutionTransposed* result) {
   *result = ConvolutionTransposed(definition, attr, *creation_context.device);
   RETURN_IF_ERROR(
       result->UploadWeights(attr.weights, creation_context.context));
@@ -438,8 +439,7 @@ Status CreateConvolutionTransposed(const CreationContext& creation_context,
   create_info.aligned_size = attr.weights.shape.o;
   RETURN_IF_ERROR(CreateLinearStorage(
       create_info, attr.bias, creation_context.context, &result->biases_));
-
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace cl
@@ -38,10 +38,10 @@ namespace cl {
 class ConvolutionTransposed : public GPUOperation {
  public:
   ConvolutionTransposed() = default;
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;
 
-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;
 
   // Move only
   ConvolutionTransposed(ConvolutionTransposed&& operation);
@@ -50,7 +50,7 @@ class ConvolutionTransposed : public GPUOperation {
   ConvolutionTransposed& operator=(const ConvolutionTransposed&) = delete;
 
  private:
-  friend Status CreateConvolutionTransposed(
+  friend absl::Status CreateConvolutionTransposed(
      const CreationContext& creation_context, const OperationDef& definition,
      const ConvolutionTransposedAttributes& attr,
      ConvolutionTransposed* result);
@@ -58,14 +58,14 @@ class ConvolutionTransposed : public GPUOperation {
                         const ConvolutionTransposedAttributes& attr,
                         const CLDevice& device);
   template <DataType T>
-  Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
+  absl::Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
                        CLContext* context);
 
   template <DataType S, typename T>
   void RearrangeWeightsData(const ::tflite::gpu::Tensor<OHWI, S>& weights,
                             absl::Span<T> dst);
 
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;
 
   LinearStorage biases_;
@@ -88,7 +88,7 @@ class ConvolutionTransposed : public GPUOperation {
 };
 
 template <DataType T>
-Status ConvolutionTransposed::UploadWeights(
+absl::Status ConvolutionTransposed::UploadWeights(
     const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
   const int dst_depth =
       AlignByN(IntegralDivideRoundUp(weights.shape.o, 4), block_size_.z);
@@ -153,7 +153,7 @@ Status ConvolutionTransposed::UploadWeights(
     }
   }
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 template <DataType S, typename T>
@@ -208,10 +208,9 @@ void ConvolutionTransposed::RearrangeWeightsData(
   }
 }
 
-Status CreateConvolutionTransposed(const CreationContext& creation_context,
-                                   const OperationDef& definition,
-                                   const ConvolutionTransposedAttributes& attr,
-                                   ConvolutionTransposed* result);
+absl::Status CreateConvolutionTransposed(
+    const CreationContext& creation_context, const OperationDef& definition,
+    const ConvolutionTransposedAttributes& attr, ConvolutionTransposed* result);
 
 }  // namespace cl
 }  // namespace gpu
|
@@ -396,7 +396,7 @@ ConvolutionTransposed3D& ConvolutionTransposed3D::operator=(
   return *this;
 }

-Status ConvolutionTransposed3D::Compile(
+absl::Status ConvolutionTransposed3D::Compile(
     const CreationContext& creation_context) {
   const auto code = GenerateConvolutionTransposed3DCode(
       definition_, biases_, *creation_context.device, weights_are_buffer_,
@@ -417,7 +417,7 @@ Status ConvolutionTransposed3D::Compile(
       *creation_context.device, &kernel_);
 }

-Status ConvolutionTransposed3D::BindArguments() {
+absl::Status ConvolutionTransposed3D::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   if (weights_are_buffer_) {
@@ -444,7 +444,7 @@ Status ConvolutionTransposed3D::BindArguments() {
       IntegralDivideRoundUp(dst_[0]->Slices(), block_size_.w)));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHDS()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHDS()));
-  return OkStatus();
+  return absl::OkStatus();
 }

 int3 ConvolutionTransposed3D::GetGridSize() const {
@@ -459,18 +459,18 @@ int3 ConvolutionTransposed3D::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status ConvolutionTransposed3D::Tune(const TuningParameters& params) {
+absl::Status ConvolutionTransposed3D::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroupConv(params, kernel_, GetGridSize(),
                               &work_group_size_);
 }

-Status ConvolutionTransposed3D::AddToQueue(CLCommandQueue* queue) {
+absl::Status ConvolutionTransposed3D::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }

-Status CreateConvolutionTransposed3D(
+absl::Status CreateConvolutionTransposed3D(
     const CreationContext& creation_context, const OperationDef& definition,
     const ConvolutionTransposed3DAttributes& attr,
     ConvolutionTransposed3D* result) {
@@ -485,8 +485,7 @@ Status CreateConvolutionTransposed3D(
   create_info.aligned_size = attr.weights.shape.o;
   RETURN_IF_ERROR(CreateLinearStorage(
       create_info, attr.bias, creation_context.context, &result->biases_));
-
-  return OkStatus();
+  return absl::OkStatus();
 }

 }  // namespace cl
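Note: the RETURN_IF_ERROR calls above survive the migration untouched because the macro only relies on the status type having ok() and being returnable. A common shape for such a macro (an illustrative sketch, not necessarily TFLite's exact definition):

    #include "absl/status/status.h"

    // Evaluate an expression once; early-return its status on failure.
    #define RETURN_IF_ERROR(expr)            \
      do {                                   \
        const absl::Status _status = (expr); \
        if (!_status.ok()) return _status;   \
      } while (0)

    // Hypothetical usage mirroring the BindArguments() bodies in this diff.
    absl::Status BindAll(absl::Status (*bind_one)(int), int n) {
      for (int i = 0; i < n; ++i) {
        RETURN_IF_ERROR(bind_one(i));
      }
      return absl::OkStatus();
    }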
@@ -38,10 +38,10 @@ namespace cl {
 class ConvolutionTransposed3D : public GPUOperation {
  public:
   ConvolutionTransposed3D() = default;
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;

-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   // Move only
   ConvolutionTransposed3D(ConvolutionTransposed3D&& operation);
@@ -50,7 +50,7 @@ class ConvolutionTransposed3D : public GPUOperation {
   ConvolutionTransposed3D& operator=(const ConvolutionTransposed3D&) = delete;

  private:
-  friend Status CreateConvolutionTransposed3D(
+  friend absl::Status CreateConvolutionTransposed3D(
      const CreationContext& creation_context, const OperationDef& definition,
      const ConvolutionTransposed3DAttributes& attr,
      ConvolutionTransposed3D* result);
@@ -58,14 +58,14 @@ class ConvolutionTransposed3D : public GPUOperation {
                          const ConvolutionTransposed3DAttributes& attr,
                          const CLDevice& device);
  template <DataType T>
-  Status UploadWeights(const ::tflite::gpu::Tensor<OHWDI, T>& weights,
+  absl::Status UploadWeights(const ::tflite::gpu::Tensor<OHWDI, T>& weights,
                             CLContext* context);

  template <DataType S, typename T>
  void RearrangeWeightsData(const ::tflite::gpu::Tensor<OHWDI, S>& weights,
                            absl::Span<T> dst);

-  Status BindArguments();
+  absl::Status BindArguments();
  int3 GetGridSize() const;

  LinearStorage biases_;
@@ -88,7 +88,7 @@ class ConvolutionTransposed3D : public GPUOperation {
 };

 template <DataType T>
-Status ConvolutionTransposed3D::UploadWeights(
+absl::Status ConvolutionTransposed3D::UploadWeights(
     const ::tflite::gpu::Tensor<OHWDI, T>& weights, CLContext* context) {
   const int dst_depth =
       AlignByN(IntegralDivideRoundUp(weights.shape.o, 4), block_size_.z);
@@ -155,7 +155,7 @@ Status ConvolutionTransposed3D::UploadWeights(
     }
   }

-  return OkStatus();
+  return absl::OkStatus();
 }

 template <DataType S, typename T>
@@ -214,7 +214,7 @@ void ConvolutionTransposed3D::RearrangeWeightsData(
   }
 }

-Status CreateConvolutionTransposed3D(
+absl::Status CreateConvolutionTransposed3D(
     const CreationContext& creation_context, const OperationDef& definition,
     const ConvolutionTransposed3DAttributes& attr,
     ConvolutionTransposed3D* result);
@@ -304,12 +304,11 @@ ConvolutionTransposed3x3& ConvolutionTransposed3x3::operator=(
   return *this;
 }

-Status ConvolutionTransposed3x3::Compile(
+absl::Status ConvolutionTransposed3x3::Compile(
     const CreationContext& creation_context) {
   const auto code = GenerateConvolutionTransposedCode(
       definition_, biases_, linked_operations_, weights_upload_type_, padding_,
       work_group_launch_order_);

   std::vector<CompilerOptions> options;
   if (definition_.precision == CalculationsPrecision::F16 &&
       creation_context.device->IsPowerVR()) {
@@ -318,11 +317,10 @@ Status ConvolutionTransposed3x3::Compile(
   RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
       code, "main_function", options, *creation_context.context,
       *creation_context.device, &kernel_));
-
-  return OkStatus();
+  return absl::OkStatus();
 }

-Status ConvolutionTransposed3x3::BindArguments() {
+absl::Status ConvolutionTransposed3x3::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
@@ -337,10 +335,7 @@ Status ConvolutionTransposed3x3::BindArguments() {
       padding_.x >= 1 ? (padding_.x - 1) / 2 : (padding_.x - 2) / 2;
   const int padding_y =
       padding_.y >= 1 ? (padding_.y - 1) / 2 : (padding_.y - 2) / 2;
-  RETURN_IF_ERROR(
-      kernel_.SetBytesAuto(int2(padding_x * src_[0]->Batch(), padding_y)));
-
-  return OkStatus();
+  return kernel_.SetBytesAuto(int2(padding_x * src_[0]->Batch(), padding_y));
 }

 int3 ConvolutionTransposed3x3::GetGridSize() const {
@@ -358,7 +353,7 @@ int3 ConvolutionTransposed3x3::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status ConvolutionTransposed3x3::AddToQueue(CLCommandQueue* queue) {
+absl::Status ConvolutionTransposed3x3::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -370,13 +365,13 @@ bool IsConvolutionTransposed3x3Supported(
          attr.stride.w == 2 && attr.stride.h == 2;
 }

-Status CreateConvolutionTransposed3x3(
+absl::Status CreateConvolutionTransposed3x3(
     const CreationContext& creation_context, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr,
     ConvolutionTransposed3x3* result) {
   if (!IsConvolutionTransposed3x3Supported(*creation_context.device, definition,
                                            attr)) {
-    return InvalidArgumentError(
+    return absl::InvalidArgumentError(
         "ConvolutionTransposed3x3 doesn't support this attributes");
   }
   const int2 padding = int2(attr.padding.prepended.w, attr.padding.prepended.h);
@@ -391,7 +386,7 @@ Status CreateConvolutionTransposed3x3(
   create_info.aligned_size = attr.weights.shape.o;
   RETURN_IF_ERROR(CreateLinearStorage(
       create_info, attr.bias, creation_context.context, &result->biases_));
-  return OkStatus();
+  return absl::OkStatus();
 }

 }  // namespace cl
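Note: ConvolutionTransposed3x3::BindArguments() above does slightly more than the rename: when the last statement's status is also the function's result, the trailing RETURN_IF_ERROR plus return OkStatus() pair collapses into a direct return. A sketch with a hypothetical setter:

    #include "absl/status/status.h"

    // Hypothetical stand-in for kernel_.SetBytesAuto(...).
    absl::Status SetBytes(int /*value*/) { return absl::OkStatus(); }

    absl::Status BindLast(int value) {
      // Before:
      //   RETURN_IF_ERROR(SetBytes(value));
      //   return OkStatus();
      // After: one status copy and one branch fewer.
      return SetBytes(value);
    }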
@@ -37,8 +37,8 @@ namespace cl {
 class ConvolutionTransposed3x3 : public GPUOperation {
  public:
   ConvolutionTransposed3x3() = default;
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   // Move only
   ConvolutionTransposed3x3(ConvolutionTransposed3x3&& operation);
@@ -56,19 +56,19 @@ class ConvolutionTransposed3x3 : public GPUOperation {
  private:
   ConvolutionTransposed3x3(const OperationDef& definition,
                            const CLDevice& device, int2 padding);
-  friend Status CreateConvolutionTransposed3x3(
+  friend absl::Status CreateConvolutionTransposed3x3(
      const CreationContext& creation_context, const OperationDef& definition,
      const ConvolutionTransposedAttributes& attr,
      ConvolutionTransposed3x3* result);
  template <DataType T>
-  Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
+  absl::Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
                             CLContext* context);

  template <DataType S, typename T>
  void RearrangeWeightsData(const ::tflite::gpu::Tensor<OHWI, S>& weights,
                            absl::Span<T> dst);

-  Status BindArguments();
+  absl::Status BindArguments();
  int3 GetGridSize() const;

  int2 padding_;
@@ -82,7 +82,7 @@ class ConvolutionTransposed3x3 : public GPUOperation {
 };

 template <DataType T>
-Status ConvolutionTransposed3x3::UploadWeights(
+absl::Status ConvolutionTransposed3x3::UploadWeights(
     const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
   const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
   const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4);
@@ -165,7 +165,7 @@ bool IsConvolutionTransposed3x3Supported(
     const CLDevice& device, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr);

-Status CreateConvolutionTransposed3x3(
+absl::Status CreateConvolutionTransposed3x3(
     const CreationContext& creation_context, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr,
     ConvolutionTransposed3x3* result);
@@ -221,19 +221,18 @@ ConvolutionTransposed3x3Thin& ConvolutionTransposed3x3Thin::operator=(
   return *this;
 }

-Status ConvolutionTransposed3x3Thin::Compile(
+absl::Status ConvolutionTransposed3x3Thin::Compile(
     const CreationContext& creation_context) {
   const auto code = GenerateConvolutionTransposedCode(
       definition_, biases_, IntegralDivideRoundUp(src_channels_, 4),
       IntegralDivideRoundUp(dst_channels_, 4), *creation_context.device,
       linked_operations_);

   return creation_context.cache->GetOrCreateCLKernel(
       code, "main_function", *creation_context.context,
       *creation_context.device, &kernel_);
 }

-Status ConvolutionTransposed3x3Thin::BindArguments() {
+absl::Status ConvolutionTransposed3x3Thin::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
@@ -242,7 +241,7 @@ Status ConvolutionTransposed3x3Thin::BindArguments() {
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
-  return OkStatus();
+  return absl::OkStatus();
 }

 int3 ConvolutionTransposed3x3Thin::GetGridSize() const {
@@ -252,12 +251,13 @@ int3 ConvolutionTransposed3x3Thin::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status ConvolutionTransposed3x3Thin::Tune(const TuningParameters& params) {
+absl::Status ConvolutionTransposed3x3Thin::Tune(
+    const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }

-Status ConvolutionTransposed3x3Thin::AddToQueue(CLCommandQueue* queue) {
+absl::Status ConvolutionTransposed3x3Thin::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -271,13 +271,13 @@ bool IsConvolutionTransposed3x3ThinSupported(
          attr.padding.appended.h == 1;
 }

-Status CreateConvolutionTransposed3x3Thin(
+absl::Status CreateConvolutionTransposed3x3Thin(
     const CreationContext& creation_context, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr,
     ConvolutionTransposed3x3Thin* result) {
   if (!IsConvolutionTransposed3x3ThinSupported(*creation_context.device,
                                                attr)) {
-    return InvalidArgumentError(
+    return absl::InvalidArgumentError(
         "ConvolutionTransposed3x3Thin doesn't support this attributes");
   }
   *result = ConvolutionTransposed3x3Thin(definition, attr);
@@ -291,8 +291,7 @@ Status CreateConvolutionTransposed3x3Thin(
   create_info.aligned_size = attr.weights.shape.o;
   RETURN_IF_ERROR(CreateLinearStorage(
       create_info, attr.bias, creation_context.context, &result->biases_));
-
-  return OkStatus();
+  return absl::OkStatus();
 }

 }  // namespace cl
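Note: the Create* factories validate attributes up front and now report failures with absl::InvalidArgumentError instead of the old bare InvalidArgumentError. A self-contained sketch of that guard, with hypothetical attribute types:

    #include "absl/status/status.h"

    struct Attrs {  // hypothetical stand-in for ConvolutionTransposedAttributes
      int stride_w = 2;
      int stride_h = 2;
    };

    bool IsSupported(const Attrs& attr) {
      return attr.stride_w == 2 && attr.stride_h == 2;
    }

    absl::Status Create(const Attrs& attr) {
      if (!IsSupported(attr)) {
        return absl::InvalidArgumentError("unsupported attributes");
      }
      return absl::OkStatus();
    }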
@@ -37,10 +37,10 @@ namespace cl {
 class ConvolutionTransposed3x3Thin : public GPUOperation {
  public:
   ConvolutionTransposed3x3Thin() = default;
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;

-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   // Move only
   ConvolutionTransposed3x3Thin(ConvolutionTransposed3x3Thin&& operation);
@@ -51,7 +51,7 @@ class ConvolutionTransposed3x3Thin : public GPUOperation {
       delete;

  private:
-  friend Status CreateConvolutionTransposed3x3Thin(
+  friend absl::Status CreateConvolutionTransposed3x3Thin(
      const CreationContext& creation_context, const OperationDef& definition,
      const ConvolutionTransposedAttributes& attr,
      ConvolutionTransposed3x3Thin* result);
@@ -59,14 +59,14 @@ class ConvolutionTransposed3x3Thin : public GPUOperation {
                               const OperationDef& definition,
                               const ConvolutionTransposedAttributes& attr);
  template <DataType T>
-  Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
+  absl::Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
                             CLContext* context);

  template <DataType S, typename T>
  void RearrangeWeightsData(const ::tflite::gpu::Tensor<OHWI, S>& weights,
                            absl::Span<T> dst);

-  Status BindArguments();
+  absl::Status BindArguments();
  int3 GetGridSize() const;

  Buffer weights_;
@@ -80,7 +80,7 @@ class ConvolutionTransposed3x3Thin : public GPUOperation {
 };

 template <DataType T>
-Status ConvolutionTransposed3x3Thin::UploadWeights(
+absl::Status ConvolutionTransposed3x3Thin::UploadWeights(
     const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
   const int src_depth = IntegralDivideRoundUp(src_channels_, 4);
   const int dst_depth = IntegralDivideRoundUp(dst_channels_, 4);
@@ -150,7 +150,7 @@ void ConvolutionTransposed3x3Thin::RearrangeWeightsData(
 bool IsConvolutionTransposed3x3ThinSupported(
     const CLDevice& device, const ConvolutionTransposedAttributes& attr);

-Status CreateConvolutionTransposed3x3Thin(
+absl::Status CreateConvolutionTransposed3x3Thin(
     const CreationContext& creation_context, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr,
     ConvolutionTransposed3x3Thin* result);
@@ -301,7 +301,7 @@ ConvolutionTransposed4x4& ConvolutionTransposed4x4::operator=(
   return *this;
 }

-Status ConvolutionTransposed4x4::Compile(
+absl::Status ConvolutionTransposed4x4::Compile(
     const CreationContext& creation_context) {
   const auto code = GenerateConvolutionTransposedCode(
       definition_, biases_, linked_operations_, weights_upload_type_);
@@ -314,11 +314,10 @@ Status ConvolutionTransposed4x4::Compile(
   RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
       code, "main_function", options, *creation_context.context,
       *creation_context.device, &kernel_));
-
-  return OkStatus();
+  return absl::OkStatus();
 }

-Status ConvolutionTransposed4x4::BindArguments() {
+absl::Status ConvolutionTransposed4x4::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
@@ -329,8 +328,7 @@ Status ConvolutionTransposed4x4::BindArguments() {
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB()));
   const int32_t filters_offset = 4 * 16 * src_[0]->Slices();
   RETURN_IF_ERROR(kernel_.SetBytesAuto(filters_offset));
-
-  return OkStatus();
+  return absl::OkStatus();
 }

 int3 ConvolutionTransposed4x4::GetGridSize() const {
@@ -341,7 +339,7 @@ int3 ConvolutionTransposed4x4::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status ConvolutionTransposed4x4::AddToQueue(CLCommandQueue* queue) {
+absl::Status ConvolutionTransposed4x4::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -354,13 +352,13 @@ bool IsConvolutionTransposed4x4Supported(
          attr.padding.prepended.w == 1 && attr.padding.prepended.h == 1;
 }

-Status CreateConvolutionTransposed4x4(
+absl::Status CreateConvolutionTransposed4x4(
     const CreationContext& creation_context, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr,
     ConvolutionTransposed4x4* result) {
   if (!IsConvolutionTransposed4x4Supported(*creation_context.device, definition,
                                            attr)) {
-    return InvalidArgumentError(
+    return absl::InvalidArgumentError(
         "ConvolutionTransposed4x4 doesn't support this attributes");
   }
   *result = ConvolutionTransposed4x4(definition, *creation_context.device);
@@ -373,7 +371,7 @@ Status CreateConvolutionTransposed4x4(
   create_info.aligned_size = attr.weights.shape.o;
   RETURN_IF_ERROR(CreateLinearStorage(
       create_info, attr.bias, creation_context.context, &result->biases_));
-  return OkStatus();
+  return absl::OkStatus();
 }

 }  // namespace cl
@@ -37,8 +37,8 @@ namespace cl {
 class ConvolutionTransposed4x4 : public GPUOperation {
  public:
   ConvolutionTransposed4x4() = default;
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   // Move only
   ConvolutionTransposed4x4(ConvolutionTransposed4x4&& operation);
@@ -56,19 +56,19 @@ class ConvolutionTransposed4x4 : public GPUOperation {
  private:
   ConvolutionTransposed4x4(const OperationDef& definition,
                            const CLDevice& device);
-  friend Status CreateConvolutionTransposed4x4(
+  friend absl::Status CreateConvolutionTransposed4x4(
      const CreationContext& creation_context, const OperationDef& definition,
      const ConvolutionTransposedAttributes& attr,
      ConvolutionTransposed4x4* result);
  template <DataType T>
-  Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
+  absl::Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
                             CLContext* context);

  template <DataType S, typename T>
  void RearrangeWeightsData(const ::tflite::gpu::Tensor<OHWI, S>& weights,
                            absl::Span<T> dst);

-  Status BindArguments();
+  absl::Status BindArguments();
  int3 GetGridSize() const;

  Buffer weights_;
@@ -80,7 +80,7 @@ class ConvolutionTransposed4x4 : public GPUOperation {
 };

 template <DataType T>
-Status ConvolutionTransposed4x4::UploadWeights(
+absl::Status ConvolutionTransposed4x4::UploadWeights(
     const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
   const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
   const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4);
@@ -150,7 +150,7 @@ bool IsConvolutionTransposed4x4Supported(
     const CLDevice& device, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr);

-Status CreateConvolutionTransposed4x4(
+absl::Status CreateConvolutionTransposed4x4(
     const CreationContext& creation_context, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr,
     ConvolutionTransposed4x4* result);
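Note: in the headers, the friend factory declarations change in the same commit as the definitions out of necessity: redeclaring a function with the same parameters but a different return type is ill-formed, so a friend still spelled with the old Status would no longer compile against the absl::Status definition. A hypothetical sketch of the pairing:

    #include "absl/status/status.h"

    class Op;
    absl::Status CreateOp(Op* result);

    class Op {
     private:
      // Must match the definition's signature, including the return type.
      friend absl::Status CreateOp(Op* result);
      int secret_ = 0;
    };

    absl::Status CreateOp(Op* result) {
      result->secret_ = 42;  // allowed through friendship
      return absl::OkStatus();
    }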
@@ -184,7 +184,7 @@ ConvolutionTransposedThin& ConvolutionTransposedThin::operator=(
   return *this;
 }

-Status ConvolutionTransposedThin::Compile(
+absl::Status ConvolutionTransposedThin::Compile(
     const CreationContext& creation_context) {
   const auto code = GenerateConvolutionTransposedCode(
       definition_, IntegralDivideRoundUp(src_channels_, 4), dst_channels_,
@@ -201,7 +201,7 @@ Status ConvolutionTransposedThin::Compile(
       *creation_context.device, &kernel_);
 }

-Status ConvolutionTransposedThin::BindArguments() {
+absl::Status ConvolutionTransposedThin::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_buf_.GetMemoryPtr()));
@@ -210,7 +210,7 @@ Status ConvolutionTransposedThin::BindArguments() {
   RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(bias_value_));
-  return OkStatus();
+  return absl::OkStatus();
 }

 int3 ConvolutionTransposedThin::GetGridSize() const {
@@ -220,12 +220,12 @@ int3 ConvolutionTransposedThin::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status ConvolutionTransposedThin::Tune(const TuningParameters& params) {
+absl::Status ConvolutionTransposedThin::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }

-Status ConvolutionTransposedThin::AddToQueue(CLCommandQueue* queue) {
+absl::Status ConvolutionTransposedThin::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -238,18 +238,18 @@ bool IsConvolutionTransposedThinSupported(
          attr.padding.appended.w == 0 && attr.padding.appended.h == 0;
 }

-Status CreateConvolutionTransposedThin(
+absl::Status CreateConvolutionTransposedThin(
     const CreationContext& creation_context, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr,
     ConvolutionTransposedThin* result) {
   if (!IsConvolutionTransposedThinSupported(*creation_context.device, attr)) {
-    return InvalidArgumentError(
+    return absl::InvalidArgumentError(
         "ConvolutionTransposedThin doesn't support this attributes");
   }
   *result = ConvolutionTransposedThin(definition, attr);
   RETURN_IF_ERROR(
       result->UploadWeights(attr.weights, creation_context.context));
-  return OkStatus();
+  return absl::OkStatus();
 }

 }  // namespace cl
@@ -38,10 +38,10 @@ namespace cl {
 class ConvolutionTransposedThin : public GPUOperation {
  public:
   ConvolutionTransposedThin() = default;
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;

-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   // Move only
   ConvolutionTransposedThin(ConvolutionTransposedThin&& operation);
@@ -51,21 +51,21 @@ class ConvolutionTransposedThin : public GPUOperation {
       delete;

  private:
-  friend Status CreateConvolutionTransposedThin(
+  friend absl::Status CreateConvolutionTransposedThin(
      const CreationContext& creation_context, const OperationDef& definition,
      const ConvolutionTransposedAttributes& attr,
      ConvolutionTransposedThin* result);
  ConvolutionTransposedThin(const OperationDef& definition,
                            const ConvolutionTransposedAttributes& attr);
  template <DataType T>
-  Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
+  absl::Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
                             CLContext* context);

  template <DataType S, typename T>
  void RearrangeWeightsData(const ::tflite::gpu::Tensor<OHWI, S>& weights,
                            absl::Span<T> dst);

-  Status BindArguments();
+  absl::Status BindArguments();
  int3 GetGridSize() const;

  Buffer weights_buf_;
@@ -80,7 +80,7 @@ class ConvolutionTransposedThin : public GPUOperation {
 };

 template <DataType T>
-Status ConvolutionTransposedThin::UploadWeights(
+absl::Status ConvolutionTransposedThin::UploadWeights(
     const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
   const int src_depth = IntegralDivideRoundUp(src_channels_, 4);
   const int elements_count =
@@ -136,7 +136,7 @@ void ConvolutionTransposedThin::RearrangeWeightsData(
 bool IsConvolutionTransposedThinSupported(
     const CLDevice& device, const ConvolutionTransposedAttributes& attr);

-Status CreateConvolutionTransposedThin(
+absl::Status CreateConvolutionTransposedThin(
     const CreationContext& creation_context, const OperationDef& definition,
     const ConvolutionTransposedAttributes& attr,
     ConvolutionTransposedThin* result);
@@ -226,7 +226,8 @@ DepthWiseConvolution& DepthWiseConvolution::operator=(
   return *this;
 }

-Status DepthWiseConvolution::Compile(const CreationContext& creation_context) {
+absl::Status DepthWiseConvolution::Compile(
+    const CreationContext& creation_context) {
   const bool stride_correction =
       definition_.IsBatchSupported() && stride_.x != 1;
   const auto code = GenerateDepthWiseConvolutionCode(
@@ -237,7 +238,7 @@ Status DepthWiseConvolution::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }

-Status DepthWiseConvolution::BindArguments() {
+absl::Status DepthWiseConvolution::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_));
@@ -255,7 +256,7 @@ Status DepthWiseConvolution::BindArguments() {
   }
   RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB()));
-  return OkStatus();
+  return absl::OkStatus();
 }

 int3 DepthWiseConvolution::GetGridSize() const {
@@ -265,20 +266,20 @@ int3 DepthWiseConvolution::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status DepthWiseConvolution::Tune(const TuningParameters& params) {
+absl::Status DepthWiseConvolution::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }

-Status DepthWiseConvolution::AddToQueue(CLCommandQueue* queue) {
+absl::Status DepthWiseConvolution::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }

-Status CreateDepthWiseConvolution(const CreationContext& creation_context,
-                                  const OperationDef& definition,
+absl::Status CreateDepthWiseConvolution(
+    const CreationContext& creation_context, const OperationDef& definition,
     const DepthwiseConvolution2DAttributes& attr,
     DepthWiseConvolution* result) {
   bool weights_are_buffer = creation_context.device->IsMali();
   *result = DepthWiseConvolution(definition, attr, weights_are_buffer);
   RETURN_IF_ERROR(
@@ -291,7 +292,7 @@ Status CreateDepthWiseConvolution(const CreationContext& creation_context,
   create_info.aligned_size = attr.weights.shape.o * attr.weights.shape.i;
   RETURN_IF_ERROR(CreateLinearStorage(
       create_info, attr.bias, creation_context.context, &result->biases_));
-  return OkStatus();
+  return absl::OkStatus();
 }

 }  // namespace cl
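Note: call sites need no changes beyond the type name, since they typically only inspect the returned status. A hypothetical caller in the style this migration enables (ok(), message(), and code() are the real absl::Status accessors):

    #include <iostream>

    #include "absl/status/status.h"

    // Hypothetical factory standing in for the Create* functions above.
    absl::Status CreateSomething() { return absl::OkStatus(); }

    int main() {
      const absl::Status status = CreateSomething();
      if (!status.ok()) {
        std::cerr << "creation failed: " << status.message() << "\n";
        return 1;
      }
      return 0;
    }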
@@ -38,10 +38,10 @@ namespace cl {
 class DepthWiseConvolution : public GPUOperation {
  public:
   DepthWiseConvolution() = default;
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;

-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   // Move only
   DepthWiseConvolution(DepthWiseConvolution&& operation);
@@ -50,7 +50,7 @@ class DepthWiseConvolution : public GPUOperation {
   DepthWiseConvolution& operator=(const DepthWiseConvolution&) = delete;

  private:
-  friend Status CreateDepthWiseConvolution(
+  friend absl::Status CreateDepthWiseConvolution(
      const CreationContext& creation_context, const OperationDef& definition,
      const DepthwiseConvolution2DAttributes& attr,
      DepthWiseConvolution* result);
@@ -58,14 +58,14 @@ class DepthWiseConvolution : public GPUOperation {
                       const DepthwiseConvolution2DAttributes& attr,
                       bool weights_are_buffer);
  template <DataType T>
-  Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
+  absl::Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
                             CLContext* context);

  template <DataType S, typename T>
  void RearrangeWeightsData(const ::tflite::gpu::Tensor<OHWI, S>& weights,
                            absl::Span<T> dst);

-  Status BindArguments();
+  absl::Status BindArguments();
  int3 GetGridSize() const;

  bool weights_are_buffer_;
@@ -86,7 +86,7 @@ class DepthWiseConvolution : public GPUOperation {
 };

 template <DataType T>
-Status DepthWiseConvolution::UploadWeights(
+absl::Status DepthWiseConvolution::UploadWeights(
     const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
   const int dst_channels = weights.shape.i * weights.shape.o;
   const int dst_depth = IntegralDivideRoundUp(dst_channels, 4);
@@ -130,7 +130,7 @@ Status DepthWiseConvolution::UploadWeights(
     weights_ = weights_tex2d_.GetMemoryPtr();
   }

-  return OkStatus();
+  return absl::OkStatus();
 }

 template <DataType S, typename T>
@@ -162,10 +162,9 @@ void DepthWiseConvolution::RearrangeWeightsData(
   }
 }

-Status CreateDepthWiseConvolution(const CreationContext& creation_context,
-                                  const OperationDef& definition,
-                                  const DepthwiseConvolution2DAttributes& attr,
-                                  DepthWiseConvolution* result);
+absl::Status CreateDepthWiseConvolution(
+    const CreationContext& creation_context, const OperationDef& definition,
+    const DepthwiseConvolution2DAttributes& attr, DepthWiseConvolution* result);

 }  // namespace cl
 }  // namespace gpu
@@ -256,7 +256,7 @@ DepthWiseConvolution3D& DepthWiseConvolution3D::operator=(
   return *this;
 }

-Status DepthWiseConvolution3D::Compile(
+absl::Status DepthWiseConvolution3D::Compile(
     const CreationContext& creation_context) {
   const bool stride_correction =
       definition_.IsBatchSupported() && stride_.x != 1;
@@ -268,7 +268,7 @@ Status DepthWiseConvolution3D::Compile(
       *creation_context.device, &kernel_);
 }

-Status DepthWiseConvolution3D::BindArguments() {
+absl::Status DepthWiseConvolution3D::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   if (weights_are_buffer_) {
@@ -295,7 +295,7 @@ Status DepthWiseConvolution3D::BindArguments() {
   }
   RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHDS()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHDS()));
-  return OkStatus();
+  return absl::OkStatus();
 }

 int3 DepthWiseConvolution3D::GetGridSize() const {
@@ -305,17 +305,17 @@ int3 DepthWiseConvolution3D::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status DepthWiseConvolution3D::Tune(const TuningParameters& params) {
+absl::Status DepthWiseConvolution3D::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }

-Status DepthWiseConvolution3D::AddToQueue(CLCommandQueue* queue) {
+absl::Status DepthWiseConvolution3D::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }

-Status CreateDepthWiseConvolution3D(
+absl::Status CreateDepthWiseConvolution3D(
     const CreationContext& creation_context, const OperationDef& definition,
     const DepthwiseConvolution3DAttributes& attr,
     DepthWiseConvolution3D* result) {
@@ -330,7 +330,7 @@ Status CreateDepthWiseConvolution3D(
   create_info.aligned_size = attr.weights.shape.o * attr.weights.shape.i;
   RETURN_IF_ERROR(CreateLinearStorage(
       create_info, attr.bias, creation_context.context, &result->biases_));
-  return OkStatus();
+  return absl::OkStatus();
 }

 }  // namespace cl
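Note: Tune() and AddToQueue() share one shape throughout this diff: propagate BindArguments() failures, then return the final call's status directly. Expanded without the macro, that shape looks like this (Bind and Dispatch are hypothetical stand-ins):

    #include "absl/status/status.h"

    absl::Status Bind() { return absl::OkStatus(); }
    absl::Status Dispatch() { return absl::OkStatus(); }

    absl::Status AddToQueue() {
      const absl::Status bound = Bind();
      if (!bound.ok()) return bound;  // what RETURN_IF_ERROR expands to
      return Dispatch();
    }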
@@ -38,10 +38,10 @@ namespace cl {
 class DepthWiseConvolution3D : public GPUOperation {
  public:
   DepthWiseConvolution3D() = default;
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;

-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   // Move only
   DepthWiseConvolution3D(DepthWiseConvolution3D&& operation);
@@ -50,7 +50,7 @@ class DepthWiseConvolution3D : public GPUOperation {
   DepthWiseConvolution3D& operator=(const DepthWiseConvolution3D&) = delete;

  private:
-  friend Status CreateDepthWiseConvolution3D(
+  friend absl::Status CreateDepthWiseConvolution3D(
      const CreationContext& creation_context, const OperationDef& definition,
      const DepthwiseConvolution3DAttributes& attr,
      DepthWiseConvolution3D* result);
@@ -58,14 +58,14 @@ class DepthWiseConvolution3D : public GPUOperation {
                         const DepthwiseConvolution3DAttributes& attr,
                         const CLDevice& device);
  template <DataType T>
-  Status UploadWeights(const ::tflite::gpu::Tensor<OHWDI, T>& weights,
+  absl::Status UploadWeights(const ::tflite::gpu::Tensor<OHWDI, T>& weights,
                             CLContext* context);

  template <DataType S, typename T>
  void RearrangeWeightsData(const ::tflite::gpu::Tensor<OHWDI, S>& weights,
                            absl::Span<T> dst);

-  Status BindArguments();
+  absl::Status BindArguments();
  int3 GetGridSize() const;

  Texture2D weights_tex2d_;
@@ -85,7 +85,7 @@ class DepthWiseConvolution3D : public GPUOperation {
 };

 template <DataType T>
-Status DepthWiseConvolution3D::UploadWeights(
+absl::Status DepthWiseConvolution3D::UploadWeights(
     const ::tflite::gpu::Tensor<OHWDI, T>& weights, CLContext* context) {
   const int dst_channels = weights.shape.i * weights.shape.o;
   const int dst_slices = IntegralDivideRoundUp(dst_channels, 4);
@@ -123,7 +123,7 @@ Status DepthWiseConvolution3D::UploadWeights(
           gpu_data.data(), context, &weights_tex2d_));
     }
   }
-  return OkStatus();
+  return absl::OkStatus();
 }

 template <DataType S, typename T>
@@ -158,7 +158,7 @@ void DepthWiseConvolution3D::RearrangeWeightsData(
   }
 }

-Status CreateDepthWiseConvolution3D(
+absl::Status CreateDepthWiseConvolution3D(
     const CreationContext& creation_context, const OperationDef& definition,
     const DepthwiseConvolution3DAttributes& attr,
     DepthWiseConvolution3D* result);
@@ -297,7 +297,8 @@ DepthWiseConv3x3& DepthWiseConv3x3::operator=(DepthWiseConv3x3&& operation) {
   return *this;
 }
 
-Status DepthWiseConv3x3::Compile(const CreationContext& creation_context) {
+absl::Status DepthWiseConv3x3::Compile(
+    const CreationContext& creation_context) {
   std::string code = GenerateDepthWiseConvCode(
       definition_, linked_operations_, *creation_context.device,
       weights_are_buffer_, local_mem_uploads_);
@@ -311,15 +312,14 @@ Status DepthWiseConv3x3::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }
 
-Status DepthWiseConv3x3::BindArguments() {
+absl::Status DepthWiseConv3x3::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
-
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 int3 DepthWiseConv3x3::GetGridSize() const {
@@ -329,15 +329,15 @@ int3 DepthWiseConv3x3::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }
 
-Status DepthWiseConv3x3::Tune(const TuningParameters& params) {
+absl::Status DepthWiseConv3x3::Tune(const TuningParameters& params) {
   if (local_mem_uploads_) {
-    return OkStatus();
+    return absl::OkStatus();
   }
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }
 
-Status DepthWiseConv3x3::AddToQueue(CLCommandQueue* queue) {
+absl::Status DepthWiseConv3x3::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -351,12 +351,11 @@ bool IsDepthWiseConv3x3Supported(const DepthwiseConvolution2DAttributes& attr) {
          attr.padding.appended.h == 1;
 }
 
-Status CreateDepthWiseConv3x3(const CreationContext& creation_context,
-                              const OperationDef& definition,
-                              const DepthwiseConvolution2DAttributes& attr,
-                              DepthWiseConv3x3* result) {
+absl::Status CreateDepthWiseConv3x3(
+    const CreationContext& creation_context, const OperationDef& definition,
+    const DepthwiseConvolution2DAttributes& attr, DepthWiseConv3x3* result) {
   if (!IsDepthWiseConv3x3Supported(attr)) {
-    return InvalidArgumentError(
+    return absl::InvalidArgumentError(
         "DepthWiseConv3x3 doesn't support this attributes");
   }
   bool weights_are_buffer =
@@ -364,9 +363,8 @@ Status CreateDepthWiseConv3x3(const CreationContext& creation_context,
   bool local_mem_uploads =
       weights_are_buffer && creation_context.device->IsPowerVR();
   *result = DepthWiseConv3x3(definition, weights_are_buffer, local_mem_uploads);
-  RETURN_IF_ERROR(result->UploadWeightsAndBiases(attr.weights, attr.bias,
-                                                 creation_context.context));
-  return OkStatus();
+  return result->UploadWeightsAndBiases(attr.weights, attr.bias,
+                                        creation_context.context);
 }
 
 }  // namespace cl
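
Note the simplification at the end of CreateDepthWiseConv3x3 above: where a function previously ended with RETURN_IF_ERROR(expr); followed by return OkStatus();, the commit collapses the pair into return expr;, since the status of the final call is the status of the whole function. A sketch of the equivalence, with UploadEverything() as a hypothetical stand-in for UploadWeightsAndBiases:

#include "absl/status/status.h"

absl::Status UploadEverything() { return absl::OkStatus(); }  // stand-in

// Before:
//   RETURN_IF_ERROR(UploadEverything());
//   return OkStatus();
// After: forward the final call's status directly.
absl::Status Create() { return UploadEverything(); }
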
@@ -38,10 +38,10 @@ namespace cl {
 class DepthWiseConv3x3 : public GPUOperation {
  public:
   DepthWiseConv3x3() = default;
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;
 
-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;
 
   // Move only
   DepthWiseConv3x3(DepthWiseConv3x3&& operation);
@@ -53,11 +53,11 @@ class DepthWiseConv3x3 : public GPUOperation {
   explicit DepthWiseConv3x3(const OperationDef& definition,
                             bool weights_are_buffer, bool local_mem_uploads);
   template <DataType T>
-  Status UploadWeightsAndBiases(const ::tflite::gpu::Tensor<OHWI, T>& weights,
-                                const ::tflite::gpu::Tensor<Linear, T>& biases,
-                                CLContext* context);
+  absl::Status UploadWeightsAndBiases(
+      const ::tflite::gpu::Tensor<OHWI, T>& weights,
+      const ::tflite::gpu::Tensor<Linear, T>& biases, CLContext* context);
 
-  friend Status CreateDepthWiseConv3x3(
+  friend absl::Status CreateDepthWiseConv3x3(
       const CreationContext& creation_context, const OperationDef& definition,
       const DepthwiseConvolution2DAttributes& attr, DepthWiseConv3x3* result);
 
@@ -66,7 +66,7 @@ class DepthWiseConv3x3 : public GPUOperation {
       const ::tflite::gpu::Tensor<OHWI, S>& weights,
       const ::tflite::gpu::Tensor<Linear, S>& biases, absl::Span<T> dst);
 
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;
 
   bool weights_are_buffer_;
@@ -80,7 +80,7 @@ class DepthWiseConv3x3 : public GPUOperation {
 };
 
 template <DataType T>
-Status DepthWiseConv3x3::UploadWeightsAndBiases(
+absl::Status DepthWiseConv3x3::UploadWeightsAndBiases(
     const ::tflite::gpu::Tensor<OHWI, T>& weights,
     const ::tflite::gpu::Tensor<Linear, T>& biases, CLContext* context) {
   const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
@@ -122,7 +122,7 @@ Status DepthWiseConv3x3::UploadWeightsAndBiases(
     weights_ = weights_tex2d_.GetMemoryPtr();
   }
 
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 template <DataType S, typename T>
@@ -160,10 +160,9 @@ void DepthWiseConv3x3::RearrangeWeightsAndBiasesData(
 
 bool IsDepthWiseConv3x3Supported(const DepthwiseConvolution2DAttributes& attr);
 
-Status CreateDepthWiseConv3x3(const CreationContext& creation_context,
-                              const OperationDef& definition,
-                              const DepthwiseConvolution2DAttributes& attr,
-                              DepthWiseConv3x3* result);
+absl::Status CreateDepthWiseConv3x3(
+    const CreationContext& creation_context, const OperationDef& definition,
+    const DepthwiseConvolution2DAttributes& attr, DepthWiseConv3x3* result);
 
 }  // namespace cl
 }  // namespace gpu
@@ -203,14 +203,14 @@ std::string ElementwiseTwoInput::GetArgsDeclaration() const {
   return args;
 }
 
-Status ElementwiseTwoInput::BindArguments(CLKernel* kernel) {
+absl::Status ElementwiseTwoInput::BindArguments(CLKernel* kernel) {
   if (use_scalar_para_) {
     RETURN_IF_ERROR(kernel->SetBytesAuto(scalar_para_));
   } else {
     RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr()));
     RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB()));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 ElementwiseTwoInput CreateElementwiseTwoInput(
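
The RETURN_IF_ERROR macro used throughout these hunks evaluates an expression once and early-returns on failure. With absl::Status the usual shape is roughly the following; the exact macro in the TFLite GPU sources may differ in naming and expansion details, so treat this as an illustration only:

#include "absl/status/status.h"

#define RETURN_IF_ERROR(expr)            \
  do {                                   \
    const absl::Status _status = (expr); \
    if (!_status.ok()) return _status;   \
  } while (0)

absl::Status Step1() { return absl::OkStatus(); }
absl::Status Step2() { return absl::OkStatus(); }

absl::Status Pipeline() {
  RETURN_IF_ERROR(Step1());  // returns Step1's error if it failed
  RETURN_IF_ERROR(Step2());
  return absl::OkStatus();
}
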
@@ -75,7 +75,7 @@ class ElementwiseTwoInput : public ElementwiseOperation {
   void SetLinkIndex(int index) override;
   std::string GetCoreCode(const LinkingContext& context) const override;
   std::string GetArgsDeclaration() const override;
-  Status BindArguments(CLKernel* kernel) override;
+  absl::Status BindArguments(CLKernel* kernel) override;
   inline void SetScalarPara(FLT scalar) {
     scalar_para_ = scalar;
     use_scalar_para_ = true;
@@ -113,7 +113,7 @@ FullyConnected& FullyConnected::operator=(FullyConnected&& kernel) {
   return *this;
 }
 
-Status FullyConnected::Compile(const CreationContext& creation_context) {
+absl::Status FullyConnected::Compile(const CreationContext& creation_context) {
   int wg_width = 32;
   int wg_height = 4;
   int work_items;
@@ -134,10 +134,10 @@ Status FullyConnected::Compile(const CreationContext& creation_context) {
     }
     work_items = work_group_size_.x * work_group_size_.y * work_group_size_.z;
   } while (work_items > kernel_.GetMaxWorkGroupSize());
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status FullyConnected::AddToQueue(CLCommandQueue* queue) {
+absl::Status FullyConnected::AddToQueue(CLCommandQueue* queue) {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
@@ -146,15 +146,14 @@ Status FullyConnected::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
   RETURN_IF_ERROR(
       kernel_.SetBytesAuto(int2(src_[0]->Slices(), dst_[0]->Slices())));
 
   return queue->DispatchImplicit(kernel_, {dst_[0]->Slices(), 1, 1},
                                  work_group_size_);
 }
 
-Status CreateFullyConnected(const CreationContext& creation_context,
-                            const OperationDef& definition,
-                            const FullyConnectedAttributes& attr,
-                            FullyConnected* result) {
+absl::Status CreateFullyConnected(const CreationContext& creation_context,
+                                  const OperationDef& definition,
+                                  const FullyConnectedAttributes& attr,
+                                  FullyConnected* result) {
   *result = FullyConnected(definition);
   RETURN_IF_ERROR(
       result->UploadWeights(attr.weights, creation_context.context));
@@ -165,7 +164,7 @@ Status CreateFullyConnected(const CreationContext& creation_context,
   create_info.aligned_size = attr.weights.shape.o;
   RETURN_IF_ERROR(CreateLinearStorage(
       create_info, attr.bias, creation_context.context, &result->biases_));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace cl
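
CreateFullyConnected above follows the factory idiom used across these kernels: construct into a caller-provided out-parameter, then report upload or storage failures through the returned absl::Status. A reduced sketch with stub types (Op and UploadWeights here are hypothetical, not the real classes):

#include "absl/status/status.h"

struct Op {};  // stand-in for FullyConnected and friends

absl::Status UploadWeights(Op* op) { return absl::OkStatus(); }  // stub

absl::Status CreateOp(Op* result) {
  *result = Op();                // assign a freshly configured instance
  return UploadWeights(result);  // propagate any upload failure to the caller
}
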
@@ -37,9 +37,9 @@ namespace cl {
 class FullyConnected : public GPUOperation {
  public:
   FullyConnected() = default;
-  Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
 
-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;
 
   // Move only
   FullyConnected(FullyConnected&& kernel);
@@ -49,14 +49,13 @@ class FullyConnected : public GPUOperation {
 
  private:
   explicit FullyConnected(const OperationDef& definition);
-  friend Status CreateFullyConnected(const CreationContext& creation_context,
-                                     const OperationDef& definition,
-                                     const FullyConnectedAttributes& attr,
-                                     FullyConnected* result);
+  friend absl::Status CreateFullyConnected(
+      const CreationContext& creation_context, const OperationDef& definition,
+      const FullyConnectedAttributes& attr, FullyConnected* result);
 
   template <DataType T>
-  Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
-                       CLContext* context);
+  absl::Status UploadWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
+                             CLContext* context);
 
   template <DataType T, typename S>
   void RearrangeWeights(const ::tflite::gpu::Tensor<OHWI, T>& weights,
@@ -69,7 +68,7 @@ class FullyConnected : public GPUOperation {
 };
 
 template <DataType T>
-Status FullyConnected::UploadWeights(
+absl::Status FullyConnected::UploadWeights(
     const ::tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
   const int src_depth = IntegralDivideRoundUp(weights.shape.i, 4);
   const int dst_depth = IntegralDivideRoundUp(weights.shape.o, 4);
@@ -123,10 +122,10 @@ void FullyConnected::RearrangeWeights(
   }
 }
 
-Status CreateFullyConnected(const CreationContext& creation_context,
-                            const OperationDef& definition,
-                            const FullyConnectedAttributes& attr,
-                            FullyConnected* result);
+absl::Status CreateFullyConnected(const CreationContext& creation_context,
+                                  const OperationDef& definition,
+                                  const FullyConnectedAttributes& attr,
+                                  FullyConnected* result);
 
 }  // namespace cl
 }  // namespace gpu
@@ -154,7 +154,7 @@ ElementwiseOperation& ElementwiseOperation::operator=(
   return *this;
 }
 
-Status ElementwiseOperation::BindArguments() {
+absl::Status ElementwiseOperation::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(BindArguments(&kernel_));
@@ -162,7 +162,7 @@ Status ElementwiseOperation::BindArguments() {
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWBatchedHSB()));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 int3 ElementwiseOperation::GetGridSize() const {
@@ -172,19 +172,20 @@ int3 ElementwiseOperation::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }
 
-Status ElementwiseOperation::Compile(const CreationContext& creation_context) {
+absl::Status ElementwiseOperation::Compile(
+    const CreationContext& creation_context) {
   const auto code = GetElementWiseCode(definition_, *this, linked_operations_);
   return creation_context.cache->GetOrCreateCLKernel(
       code, "main_function", *creation_context.context,
       *creation_context.device, &kernel_);
 }
 
-Status ElementwiseOperation::AddToQueue(CLCommandQueue* queue) {
+absl::Status ElementwiseOperation::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
 
-Status ElementwiseOperation::Tune(const TuningParameters& params) {
+absl::Status ElementwiseOperation::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }
@@ -209,12 +210,12 @@ std::string PostProcess(const std::vector<ElementwiseOperation*>& linked_ops,
   return code;
 }
 
-Status BindArgs(CLKernel* kernel,
-                const std::vector<ElementwiseOperation*>& linked_ops) {
+absl::Status BindArgs(CLKernel* kernel,
+                      const std::vector<ElementwiseOperation*>& linked_ops) {
   for (auto linked_op : linked_ops) {
     RETURN_IF_ERROR(linked_op->BindArguments(kernel));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace cl
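
BindArgs above is the one place where linked elementwise operations get their kernel arguments bound; the loop stops at the first failure. Expanded without the macro, the control flow looks like this (the types are stubs for illustration, not the real CL classes):

#include <vector>
#include "absl/status/status.h"

struct Kernel {};
struct LinkedOp {
  absl::Status BindArguments(Kernel* kernel) { return absl::OkStatus(); }
};

absl::Status BindAll(Kernel* kernel, const std::vector<LinkedOp*>& ops) {
  for (auto* op : ops) {
    absl::Status s = op->BindArguments(kernel);  // what RETURN_IF_ERROR wraps
    if (!s.ok()) return s;                       // first error wins
  }
  return absl::OkStatus();
}
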
@@ -96,11 +96,15 @@ class GPUOperation {
   void SetSrc(Tensor* ptr, int index = 0);
   void SetDst(Tensor* ptr, int index = 0);
 
-  virtual Status AddToQueue(CLCommandQueue* queue) { return OkStatus(); }
-  virtual Status Tune(const TuningParameters& params) { return OkStatus(); }
+  virtual absl::Status AddToQueue(CLCommandQueue* queue) {
+    return absl::OkStatus();
+  }
+  virtual absl::Status Tune(const TuningParameters& params) {
+    return absl::OkStatus();
+  }
 
-  virtual Status Compile(const CreationContext& creation_context) {
-    return OkStatus();
+  virtual absl::Status Compile(const CreationContext& creation_context) {
+    return absl::OkStatus();
   }
 
   const OperationDef& GetDefinition() const { return definition_; }
@@ -127,10 +131,10 @@ class ElementwiseOperation : public GPUOperation {
       : GPUOperation(definition) {}
 
   virtual ~ElementwiseOperation() {}
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;
 
-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;
 
   // Move only
   ElementwiseOperation(ElementwiseOperation&& operation);
@@ -150,10 +154,12 @@ class ElementwiseOperation : public GPUOperation {
 
   virtual std::string GetCoreCode(const LinkingContext& context) const = 0;
   virtual std::string GetArgsDeclaration() const { return ""; }
-  virtual Status BindArguments(CLKernel* kernel) { return OkStatus(); }
+  virtual absl::Status BindArguments(CLKernel* kernel) {
+    return absl::OkStatus();
+  }
 
  protected:
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;
   CLKernel kernel_;
   int3 work_group_size_ = int3(8, 4, 1);
@@ -171,8 +177,8 @@ std::string PostProcess(const std::vector<ElementwiseOperation*>& linked_ops,
 // Binds arguments to given kernel for elementwise operations in
 // linked_ops.
 // Every ElementwiseOperation can bind her arguments.
-Status BindArgs(CLKernel* kernel,
-                const std::vector<ElementwiseOperation*>& linked_ops);
+absl::Status BindArgs(CLKernel* kernel,
+                      const std::vector<ElementwiseOperation*>& linked_ops);
 
 }  // namespace cl
 }  // namespace gpu
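
The GPUOperation base class keeps do-nothing defaults so that subclasses override only the stages they need; after the rename the one-line bodies no longer fit the line-length limit, which is why the hunk above reflows them onto multiple lines without changing behavior. The resulting shape, reduced to a sketch (GpuOpSketch is a made-up name, not the real class):

#include "absl/status/status.h"

struct CLCommandQueue;  // opaque here

class GpuOpSketch {
 public:
  virtual ~GpuOpSketch() = default;
  // Default succeeds and does nothing; real operations override this.
  virtual absl::Status AddToQueue(CLCommandQueue* queue) {
    return absl::OkStatus();
  }
};
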
@@ -121,14 +121,14 @@ LSTM& LSTM::operator=(LSTM&& kernel) {
   return *this;
 }
 
-Status LSTM::Compile(const CreationContext& creation_context) {
+absl::Status LSTM::Compile(const CreationContext& creation_context) {
   const auto code = GetLSTMCode(definition_, *creation_context.device);
   return creation_context.cache->GetOrCreateCLKernel(
       code, "main_function", *creation_context.context,
       *creation_context.device, &kernel_);
 }
 
-Status LSTM::BindArguments() {
+absl::Status LSTM::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr()));
@@ -137,8 +137,7 @@ Status LSTM::BindArguments() {
   RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Batch()));
-
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 int3 LSTM::GetGridSize() const {
@@ -148,12 +147,12 @@ int3 LSTM::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }
 
-Status LSTM::Tune(const TuningParameters& params) {
+absl::Status LSTM::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }
 
-Status LSTM::AddToQueue(CLCommandQueue* queue) {
+absl::Status LSTM::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -28,9 +28,9 @@ namespace cl {
 class LSTM : public GPUOperation {
  public:
   explicit LSTM(const OperationDef& definition);
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;
+  absl::Status Compile(const CreationContext& creation_context) override;
 
   // Move only
   LSTM(LSTM&& kernel);
@@ -39,7 +39,7 @@ class LSTM : public GPUOperation {
   LSTM& operator=(const LSTM&) = delete;
 
  private:
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;
 
   CLKernel kernel_;
@@ -218,7 +218,7 @@ MaxUnpooling& MaxUnpooling::operator=(MaxUnpooling&& kernel) {
   return *this;
 }
 
-Status MaxUnpooling::Compile(const CreationContext& creation_context) {
+absl::Status MaxUnpooling::Compile(const CreationContext& creation_context) {
   const auto code = GetMaxUnpoolingKernelCode(
       definition_, *creation_context.device, linked_operations_);
   return creation_context.cache->GetOrCreateCLKernel(
@@ -226,7 +226,7 @@ Status MaxUnpooling::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }
 
-Status MaxUnpooling::BindArguments() {
+absl::Status MaxUnpooling::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr()));
@@ -237,8 +237,7 @@ Status MaxUnpooling::BindArguments() {
   RETURN_IF_ERROR(kernel_.SetBytesAuto(kernel_size_));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(padding_));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_));
-
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 int3 MaxUnpooling::GetGridSize() const {
@@ -248,12 +247,12 @@ int3 MaxUnpooling::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }
 
-Status MaxUnpooling::Tune(const TuningParameters& params) {
+absl::Status MaxUnpooling::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }
 
-Status MaxUnpooling::AddToQueue(CLCommandQueue* queue) {
+absl::Status MaxUnpooling::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -291,7 +290,7 @@ MaxUnpooling3D& MaxUnpooling3D::operator=(MaxUnpooling3D&& kernel) {
   return *this;
 }
 
-Status MaxUnpooling3D::Compile(const CreationContext& creation_context) {
+absl::Status MaxUnpooling3D::Compile(const CreationContext& creation_context) {
   const auto code = GetMaxUnpooling3DKernelCode(
       definition_, *creation_context.device, linked_operations_);
   return creation_context.cache->GetOrCreateCLKernel(
@@ -299,7 +298,7 @@ Status MaxUnpooling3D::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }
 
-Status MaxUnpooling3D::BindArguments() {
+absl::Status MaxUnpooling3D::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[1]->GetMemoryPtr()));
@@ -316,8 +315,7 @@ Status MaxUnpooling3D::BindArguments() {
       kernel_.SetBytesAuto(int4(padding_.x, padding_.y, padding_.z, 1)));
   RETURN_IF_ERROR(
       kernel_.SetBytesAuto(int4(stride_.x, stride_.y, stride_.z, 1)));
-
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 int3 MaxUnpooling3D::GetGridSize() const {
@@ -327,12 +325,12 @@ int3 MaxUnpooling3D::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }
 
-Status MaxUnpooling3D::Tune(const TuningParameters& params) {
+absl::Status MaxUnpooling3D::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }
 
-Status MaxUnpooling3D::AddToQueue(CLCommandQueue* queue) {
+absl::Status MaxUnpooling3D::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -29,10 +29,10 @@ class MaxUnpooling : public GPUOperation {
  public:
   MaxUnpooling(const OperationDef& definition,
               const MaxUnpooling2DAttributes& attr);
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;
 
-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;
 
   // Move only
   MaxUnpooling(MaxUnpooling&& kernel);
@@ -41,7 +41,7 @@ class MaxUnpooling : public GPUOperation {
   MaxUnpooling& operator=(const MaxUnpooling&) = delete;
 
  private:
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;
 
   int2 stride_;
@@ -59,10 +59,10 @@ class MaxUnpooling3D : public GPUOperation {
  public:
   MaxUnpooling3D(const OperationDef& definition,
                  const MaxUnpooling3DAttributes& attr);
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;
 
-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;
 
   // Move only
   MaxUnpooling3D(MaxUnpooling3D&& kernel);
@@ -71,7 +71,7 @@ class MaxUnpooling3D : public GPUOperation {
   MaxUnpooling3D& operator=(const MaxUnpooling3D&) = delete;
 
  private:
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;
 
   int3 stride_;
@@ -103,7 +103,7 @@ Mean& Mean::operator=(Mean&& operation) {
   return *this;
 }
 
-Status Mean::Compile(const CreationContext& creation_context) {
+absl::Status Mean::Compile(const CreationContext& creation_context) {
   if (creation_context.device->IsAdreno3xx()) {
     work_group_size_ = int3(16, 8, 1);
   }
@@ -114,7 +114,7 @@ Status Mean::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }
 
-Status Mean::BindArguments() {
+absl::Status Mean::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
@@ -124,7 +124,7 @@ Status Mean::BindArguments() {
   const double size_0 = work_group_size_.x * work_group_size_.y;
   const double size_1 = total_size / size_0;
   RETURN_IF_ERROR(kernel_.SetBytesAuto(float2(1.0 / size_1, 1.0 / size_0)));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 int3 Mean::GetGridSize() const {
@@ -134,7 +134,7 @@ int3 Mean::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }
 
-Status Mean::AddToQueue(CLCommandQueue* queue) {
+absl::Status Mean::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -30,9 +30,9 @@ class Mean : public GPUOperation {
  public:
   Mean() = default;
   explicit Mean(const OperationDef& definition) : GPUOperation(definition) {}
-  Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
 
-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;
 
   // Move only
   Mean(Mean&& operation);
@@ -41,7 +41,7 @@ class Mean : public GPUOperation {
   Mean& operator=(const Mean&) = delete;
 
  private:
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;
   CLKernel kernel_;
 
@@ -89,7 +89,7 @@ std::string MultiplyAdd::GetArgsDeclaration() const {
   return args;
 }
 
-Status MultiplyAdd::BindArguments(CLKernel* kernel) {
+absl::Status MultiplyAdd::BindArguments(CLKernel* kernel) {
   if (use_mul_vec_) {
     RETURN_IF_ERROR(kernel->SetMemoryAuto(mul_vec_.GetMemoryPtr()));
   }
@@ -102,12 +102,12 @@ Status MultiplyAdd::BindArguments(CLKernel* kernel) {
   if (scalar_add_.Active()) {
     RETURN_IF_ERROR(kernel->SetBytesAuto(scalar_add_));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status MultiplyAdd::UploadMul(const MultiplyAttributes& attr,
-                              CalculationsPrecision scalar_precision,
-                              CLContext* context) {
+absl::Status MultiplyAdd::UploadMul(const MultiplyAttributes& attr,
+                                    CalculationsPrecision scalar_precision,
+                                    CLContext* context) {
   auto mul = absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
       &attr.param);
   auto mul_scalar = absl::get_if<float>(&attr.param);
@@ -116,12 +116,12 @@ Status MultiplyAdd::UploadMul(const MultiplyAttributes& attr,
   } else {
     scalar_mul_ = FLT(scalar_precision, *mul_scalar);
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status MultiplyAdd::UploadAdd(const AddAttributes& attr,
-                              CalculationsPrecision scalar_precision,
-                              CLContext* context) {
+absl::Status MultiplyAdd::UploadAdd(const AddAttributes& attr,
+                                    CalculationsPrecision scalar_precision,
+                                    CLContext* context) {
   auto add = absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
       &attr.param);
   auto add_scalar = absl::get_if<float>(&attr.param);
@@ -130,12 +130,13 @@ Status MultiplyAdd::UploadAdd(const AddAttributes& attr,
   } else {
     scalar_add_ = FLT(scalar_precision, *add_scalar);
   }
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CreateMultiplyAdd(const CreationContext& creation_context,
-                         const OperationDef& definition,
-                         const MultiplyAttributes& attr, MultiplyAdd* result) {
+absl::Status CreateMultiplyAdd(const CreationContext& creation_context,
+                               const OperationDef& definition,
+                               const MultiplyAttributes& attr,
+                               MultiplyAdd* result) {
   const auto scalar_precision = creation_context.device->IsPowerVR()
                                     ? CalculationsPrecision::F32
                                     : definition.precision;
@@ -143,12 +144,12 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
   RETURN_IF_ERROR(
       result->UploadMul(attr, scalar_precision, creation_context.context));
   result->SetLinkIndex(0);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CreateMultiplyAdd(const CreationContext& creation_context,
-                         const OperationDef& definition,
-                         const AddAttributes& attr, MultiplyAdd* result) {
+absl::Status CreateMultiplyAdd(const CreationContext& creation_context,
+                               const OperationDef& definition,
+                               const AddAttributes& attr, MultiplyAdd* result) {
   const auto scalar_precision = creation_context.device->IsPowerVR()
                                     ? CalculationsPrecision::F32
                                     : definition.precision;
@@ -156,13 +157,14 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
   RETURN_IF_ERROR(
       result->UploadAdd(attr, scalar_precision, creation_context.context));
   result->SetLinkIndex(0);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
-Status CreateMultiplyAdd(const CreationContext& creation_context,
-                         const OperationDef& definition,
-                         const MultiplyAttributes& mul_attr,
-                         const AddAttributes& add_attr, MultiplyAdd* result) {
+absl::Status CreateMultiplyAdd(const CreationContext& creation_context,
+                               const OperationDef& definition,
+                               const MultiplyAttributes& mul_attr,
+                               const AddAttributes& add_attr,
+                               MultiplyAdd* result) {
   const auto scalar_precision = creation_context.device->IsPowerVR()
                                     ? CalculationsPrecision::F32
                                     : definition.precision;
@@ -172,7 +174,7 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
   RETURN_IF_ERROR(
       result->UploadAdd(add_attr, scalar_precision, creation_context.context));
   result->SetLinkIndex(0);
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace cl
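
UploadMul and UploadAdd above use absl::get_if to decide whether the attribute carries a per-channel tensor or a single scalar. A reduced sketch of that dispatch, with std::variant and std::get_if standing in for their Abseil counterparts and a simplified Attr type (both are assumptions for illustration):

#include <variant>
#include <vector>
#include "absl/status/status.h"

struct Attr {
  std::variant<std::vector<float>, float> param;  // tensor-or-scalar
};

absl::Status Upload(const Attr& attr) {
  if (const auto* vec = std::get_if<std::vector<float>>(&attr.param)) {
    // per-channel path: would upload *vec into linear GPU storage
    (void)vec;
  } else if (const float* scalar = std::get_if<float>(&attr.param)) {
    // scalar path: would pass *scalar to the kernel as a plain argument
    (void)scalar;
  }
  return absl::OkStatus();
}
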
@@ -40,40 +40,42 @@ class MultiplyAdd : public ElementwiseOperation {
   MultiplyAdd(const MultiplyAdd&) = delete;
   MultiplyAdd& operator=(const MultiplyAdd&) = delete;
 
-  Status UploadMul(const MultiplyAttributes& attr,
-                   CalculationsPrecision scalar_precision, CLContext* context);
-  Status UploadAdd(const AddAttributes& attr,
-                   CalculationsPrecision scalar_precision, CLContext* context);
+  absl::Status UploadMul(const MultiplyAttributes& attr,
+                         CalculationsPrecision scalar_precision,
+                         CLContext* context);
+  absl::Status UploadAdd(const AddAttributes& attr,
+                         CalculationsPrecision scalar_precision,
+                         CLContext* context);
 
   template <DataType T>
-  Status UploadMul(const ::tflite::gpu::Tensor<Linear, T>& mul,
-                   CLContext* context);
+  absl::Status UploadMul(const ::tflite::gpu::Tensor<Linear, T>& mul,
+                         CLContext* context);
 
   template <DataType T>
-  Status UploadAdd(const ::tflite::gpu::Tensor<Linear, T>& add,
-                   CLContext* context);
+  absl::Status UploadAdd(const ::tflite::gpu::Tensor<Linear, T>& add,
+                         CLContext* context);
 
   void SetLinkIndex(int index) override;
   std::string GetCoreCode(const LinkingContext& context) const override;
 
   std::string GetArgsDeclaration() const override;
-  Status BindArguments(CLKernel* kernel) override;
+  absl::Status BindArguments(CLKernel* kernel) override;
 
-  friend Status CreateMultiplyAdd(const CreationContext& creation_context,
-                                  const OperationDef& definition,
-                                  const MultiplyAttributes& attr,
-                                  MultiplyAdd* result);
+  friend absl::Status CreateMultiplyAdd(const CreationContext& creation_context,
+                                        const OperationDef& definition,
+                                        const MultiplyAttributes& attr,
+                                        MultiplyAdd* result);
 
-  friend Status CreateMultiplyAdd(const CreationContext& creation_context,
-                                  const OperationDef& definition,
-                                  const AddAttributes& attr,
-                                  MultiplyAdd* result);
+  friend absl::Status CreateMultiplyAdd(const CreationContext& creation_context,
+                                        const OperationDef& definition,
+                                        const AddAttributes& attr,
+                                        MultiplyAdd* result);
 
-  friend Status CreateMultiplyAdd(const CreationContext& creation_context,
-                                  const OperationDef& definition,
-                                  const MultiplyAttributes& mul_attr,
-                                  const AddAttributes& add_attr,
-                                  MultiplyAdd* result);
+  friend absl::Status CreateMultiplyAdd(const CreationContext& creation_context,
+                                        const OperationDef& definition,
+                                        const MultiplyAttributes& mul_attr,
+                                        const AddAttributes& add_attr,
+                                        MultiplyAdd* result);
 
  private:
   explicit MultiplyAdd(const OperationDef& definition)
@@ -89,41 +91,43 @@ class MultiplyAdd : public ElementwiseOperation {
   FLT scalar_add_;
 };
 
-Status CreateMultiplyAdd(const CreationContext& creation_context,
-                         const OperationDef& definition,
-                         const MultiplyAttributes& attr, MultiplyAdd* result);
+absl::Status CreateMultiplyAdd(const CreationContext& creation_context,
+                               const OperationDef& definition,
+                               const MultiplyAttributes& attr,
+                               MultiplyAdd* result);
 
-Status CreateMultiplyAdd(const CreationContext& creation_context,
-                         const OperationDef& definition,
-                         const AddAttributes& attr, MultiplyAdd* result);
+absl::Status CreateMultiplyAdd(const CreationContext& creation_context,
+                               const OperationDef& definition,
+                               const AddAttributes& attr, MultiplyAdd* result);
 
-Status CreateMultiplyAdd(const CreationContext& creation_context,
-                         const OperationDef& definition,
-                         const MultiplyAttributes& mul_attr,
-                         const AddAttributes& add_attr, MultiplyAdd* result);
+absl::Status CreateMultiplyAdd(const CreationContext& creation_context,
+                               const OperationDef& definition,
+                               const MultiplyAttributes& mul_attr,
+                               const AddAttributes& add_attr,
+                               MultiplyAdd* result);
 
 template <DataType T>
-Status MultiplyAdd::UploadMul(const ::tflite::gpu::Tensor<Linear, T>& mul,
-                              CLContext* context) {
+absl::Status MultiplyAdd::UploadMul(const ::tflite::gpu::Tensor<Linear, T>& mul,
+                                    CLContext* context) {
   LinearStorageCreateInfo create_info;
   create_info.storage_type =
       DeduceLinearStorageType(definition_.GetPrimaryStorageType());
   create_info.data_type = definition_.GetDataType();
   RETURN_IF_ERROR(CreateLinearStorage(create_info, mul, context, &mul_vec_));
   use_mul_vec_ = true;
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 template <DataType T>
-Status MultiplyAdd::UploadAdd(const ::tflite::gpu::Tensor<Linear, T>& add,
-                              CLContext* context) {
+absl::Status MultiplyAdd::UploadAdd(const ::tflite::gpu::Tensor<Linear, T>& add,
+                                    CLContext* context) {
   LinearStorageCreateInfo create_info;
   create_info.storage_type =
      DeduceLinearStorageType(definition_.GetPrimaryStorageType());
   create_info.data_type = definition_.GetDataType();
   RETURN_IF_ERROR(CreateLinearStorage(create_info, add, context, &add_vec_));
   use_add_vec_ = true;
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 }  // namespace cl
@@ -169,7 +169,7 @@ Padding& Padding::operator=(Padding&& kernel) {
   return *this;
 }
 
-Status Padding::Compile(const CreationContext& creation_context) {
+absl::Status Padding::Compile(const CreationContext& creation_context) {
   const auto code =
       GetPaddingCode(definition_, linked_operations_, attributes_);
   return creation_context.cache->GetOrCreateCLKernel(
@@ -177,7 +177,7 @@ Status Padding::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }
 
-Status Padding::BindArguments() {
+absl::Status Padding::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
@@ -187,7 +187,7 @@ Status Padding::BindArguments() {
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
   const auto& prep = attributes_.prepended;
   RETURN_IF_ERROR(kernel_.SetBytesAuto(int4(prep.w, prep.h, prep.c, prep.b)));
-  return OkStatus();
+  return absl::OkStatus();
 }
 
 int3 Padding::GetGridSize() const {
@@ -197,12 +197,12 @@ int3 Padding::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }
 
-Status Padding::Tune(const TuningParameters& params) {
+absl::Status Padding::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }
 
-Status Padding::AddToQueue(CLCommandQueue* queue) {
+absl::Status Padding::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -28,10 +28,10 @@ namespace cl {
 class Padding : public GPUOperation {
  public:
   Padding(const OperationDef& definition, const PadAttributes& attr);
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;
 
-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;
 
   // Move only
   Padding(Padding&& kernel);
@@ -40,7 +40,7 @@ class Padding : public GPUOperation {
   Padding& operator=(const Padding&) = delete;
 
  private:
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;
 
   PadAttributes attributes_;
@@ -408,7 +408,7 @@ Pooling& Pooling::operator=(Pooling&& kernel) {
   return *this;
 }

-Status Pooling::Compile(const CreationContext& creation_context) {
+absl::Status Pooling::Compile(const CreationContext& creation_context) {
   std::string code;
   const bool stride_correction =
       definition_.IsBatchSupported() && stride_.x != 1;
@@ -423,7 +423,7 @@ Status Pooling::Compile(const CreationContext& creation_context) {
                                    linked_operations_, output_indices_);
       break;
     default:
-      return InvalidArgumentError(
+      return absl::InvalidArgumentError(
           "You should create another kernel with this params");
       break;
   }
@@ -432,7 +432,7 @@ Status Pooling::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }

-Status Pooling::BindArguments() {
+absl::Status Pooling::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
@@ -447,7 +447,7 @@ Status Pooling::BindArguments() {
       kernel_.SetBytesAuto(int2(padding_.x * src_[0]->Batch(), padding_.y)));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(stride_));

-  return OkStatus();
+  return absl::OkStatus();
 }

 int3 Pooling::GetGridSize() const {
@@ -457,12 +457,12 @@ int3 Pooling::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status Pooling::Tune(const TuningParameters& params) {
+absl::Status Pooling::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }

-Status Pooling::AddToQueue(CLCommandQueue* queue) {
+absl::Status Pooling::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -506,7 +506,7 @@ Pooling3D& Pooling3D::operator=(Pooling3D&& kernel) {
   return *this;
 }

-Status Pooling3D::Compile(const CreationContext& creation_context) {
+absl::Status Pooling3D::Compile(const CreationContext& creation_context) {
   std::string code;
   const bool stride_correction =
       definition_.IsBatchSupported() && stride_.x != 1;
@@ -521,7 +521,7 @@ Status Pooling3D::Compile(const CreationContext& creation_context) {
                                    linked_operations_, output_indices_);
       break;
     default:
-      return InvalidArgumentError(
+      return absl::InvalidArgumentError(
           "You should create another kernel with this params");
       break;
   }
@@ -530,7 +530,7 @@ Status Pooling3D::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }

-Status Pooling3D::BindArguments() {
+absl::Status Pooling3D::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
@@ -550,7 +550,7 @@ Status Pooling3D::BindArguments() {
   RETURN_IF_ERROR(
       kernel_.SetBytesAuto(int4(stride_.x, stride_.y, stride_.z, 1)));

-  return OkStatus();
+  return absl::OkStatus();
 }

 int3 Pooling3D::GetGridSize() const {
@@ -560,12 +560,12 @@ int3 Pooling3D::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status Pooling3D::Tune(const TuningParameters& params) {
+absl::Status Pooling3D::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }

-Status Pooling3D::AddToQueue(CLCommandQueue* queue) {
+absl::Status Pooling3D::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
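The default branches above are representative of how error construction changes across the codebase: the free functions OkStatus() and InvalidArgumentError() from the old status header are replaced by their absl:: counterparts, which pair a canonical status code with the message. A small standalone illustration of the new error objects (not taken from this diff):

    #include <cassert>
    #include "absl/status/status.h"

    int main() {
      // absl errors carry a canonical code alongside a human-readable message.
      absl::Status s = absl::InvalidArgumentError("unsupported pooling type");
      assert(!s.ok());
      assert(s.code() == absl::StatusCode::kInvalidArgument);
      assert(absl::OkStatus().ok());
      return 0;
    }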
@@ -30,10 +30,10 @@ namespace cl {
 class Pooling : public GPUOperation {
  public:
   Pooling(const OperationDef& definition, const Pooling2DAttributes& attr);
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;

-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   // Move only
   Pooling(Pooling&& kernel);
@@ -42,7 +42,7 @@ class Pooling : public GPUOperation {
   Pooling& operator=(const Pooling&) = delete;

  private:
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;

   int2 stride_;
@@ -62,10 +62,10 @@ Pooling CreatePooling(const OperationDef& definition,
 class Pooling3D : public GPUOperation {
  public:
   Pooling3D(const OperationDef& definition, const Pooling3DAttributes& attr);
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;

-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   // Move only
   Pooling3D(Pooling3D&& kernel);
@@ -74,7 +74,7 @@ class Pooling3D : public GPUOperation {
   Pooling3D& operator=(const Pooling3D&) = delete;

  private:
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;

   int3 stride_;
@@ -73,21 +73,21 @@ std::string PReLU::GetArgsDeclaration() const {
   return args;
 }

-Status PReLU::BindArguments(CLKernel* kernel) {
+absl::Status PReLU::BindArguments(CLKernel* kernel) {
   RETURN_IF_ERROR(kernel->SetMemoryAuto(alpha_.GetMemoryPtr()));
   if (clip_.Active()) {
     RETURN_IF_ERROR(kernel->SetBytesAuto(clip_));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }

-Status CreatePReLU(const CreationContext& creation_context,
-                   const OperationDef& definition, const PReLUAttributes& attr,
-                   PReLU* result) {
+absl::Status CreatePReLU(const CreationContext& creation_context,
+                         const OperationDef& definition,
+                         const PReLUAttributes& attr, PReLU* result) {
   auto alpha = absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
       &attr.alpha);
   if (!alpha) {
-    return InvalidArgumentError("Alpha is missing");
+    return absl::InvalidArgumentError("Alpha is missing");
   }
   const auto scalar_precision = creation_context.device->IsPowerVR()
                                     ? CalculationsPrecision::F32
@@ -95,7 +95,7 @@ Status CreatePReLU(const CreationContext& creation_context,
   *result = PReLU(definition, attr, scalar_precision);
   RETURN_IF_ERROR(result->UploadParameters(*alpha, creation_context.context));
   result->SetLinkIndex(0);
-  return OkStatus();
+  return absl::OkStatus();
 }

 }  // namespace cl
@@ -44,30 +44,30 @@ class PReLU : public ElementwiseOperation {
   void SetLinkIndex(int index) override;
   std::string GetCoreCode(const LinkingContext& context) const override;
   std::string GetArgsDeclaration() const override;
-  Status BindArguments(CLKernel* kernel) override;
+  absl::Status BindArguments(CLKernel* kernel) override;

-  friend Status CreatePReLU(const CreationContext& creation_context,
+  friend absl::Status CreatePReLU(const CreationContext& creation_context,
                                   const OperationDef& definition,
                                   const PReLUAttributes& attr, PReLU* result);

  private:
   PReLU(const OperationDef& definition, const PReLUAttributes& attr,
         CalculationsPrecision scalar_precision);

   template <DataType T>
-  Status UploadParameters(const ::tflite::gpu::Tensor<Linear, T>& parameters,
-                          CLContext* context);
+  absl::Status UploadParameters(
+      const ::tflite::gpu::Tensor<Linear, T>& parameters, CLContext* context);

   FLT clip_;
   LinearStorage alpha_;
 };

-Status CreatePReLU(const CreationContext& creation_context,
-                   const OperationDef& definition, const PReLUAttributes& attr,
-                   PReLU* result);
+absl::Status CreatePReLU(const CreationContext& creation_context,
+                         const OperationDef& definition,
+                         const PReLUAttributes& attr, PReLU* result);

 template <DataType T>
-Status PReLU::UploadParameters(
+absl::Status PReLU::UploadParameters(
     const ::tflite::gpu::Tensor<Linear, T>& parameters, CLContext* context) {
   LinearStorageCreateInfo create_info;
   create_info.storage_type =
@@ -75,7 +75,7 @@ Status PReLU::UploadParameters(
   create_info.data_type = definition_.GetPrimaryDataType();
   RETURN_IF_ERROR(
       CreateLinearStorage(create_info, parameters, context, &alpha_));
-  return OkStatus();
+  return absl::OkStatus();
 }

 }  // namespace cl
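The CreatePReLU declarations above keep the project's factory convention: the status is the return value and the constructed operation comes back through an out-parameter. A minimal usage sketch, assuming creation_context, definition, and attr are built the way the surrounding code builds them:

    // Sketch only; the surrounding objects are stand-ins, not part of this diff.
    PReLU operation;
    absl::Status status =
        CreatePReLU(creation_context, definition, attr, &operation);
    if (!status.ok()) {
      return status;  // propagate instead of using a half-initialized op
    }

Returning the status rather than the object keeps the operation type move-only and lets callers detect failures such as the missing alpha tensor before dispatching anything.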
@@ -92,17 +92,17 @@ std::string QuantizeAndDequantize::GetArgsDeclaration() const {
                          scale_.GetDeclaration());
 }

-Status QuantizeAndDequantize::BindArguments(CLKernel* kernel) {
+absl::Status QuantizeAndDequantize::BindArguments(CLKernel* kernel) {
   RETURN_IF_ERROR(kernel->SetBytesAuto(min_));
   RETURN_IF_ERROR(kernel->SetBytesAuto(max_));
   RETURN_IF_ERROR(kernel->SetBytesAuto(scale_));
-  return OkStatus();
+  return absl::OkStatus();
 }

-Status CreateQuantizeAndDequantize(const CreationContext& creation_context,
-                                   const OperationDef& definition,
+absl::Status CreateQuantizeAndDequantize(
+    const CreationContext& creation_context, const OperationDef& definition,
     const QuantizeAndDequantizeAttributes& attr,
     QuantizeAndDequantize* result) {
   const auto scalar_precision = creation_context.device->IsPowerVR()
                                     ? CalculationsPrecision::F32
                                     : definition.precision;
@@ -120,7 +120,7 @@ Status CreateQuantizeAndDequantize(const CreationContext& creation_context,
     *result = QuantizeAndDequantize(definition, attr, scalar_precision);
   }
   result->SetLinkIndex(0);
-  return OkStatus();
+  return absl::OkStatus();
 }

 }  // namespace cl
@@ -57,9 +57,9 @@ class QuantizeAndDequantize : public ElementwiseOperation {
   void SetLinkIndex(int index) override;
   std::string GetCoreCode(const LinkingContext& context) const override;
   std::string GetArgsDeclaration() const override;
-  Status BindArguments(CLKernel* kernel) override;
+  absl::Status BindArguments(CLKernel* kernel) override;

-  friend Status CreateQuantizeAndDequantize(
+  friend absl::Status CreateQuantizeAndDequantize(
       const CreationContext& creation_context, const OperationDef& definition,
       const QuantizeAndDequantizeAttributes& attr,
       QuantizeAndDequantize* result);
@@ -70,27 +70,26 @@ class QuantizeAndDequantize : public ElementwiseOperation {
       CalculationsPrecision scalar_precision);

   template <DataType T>
-  Status UploadParameters(const ::tflite::gpu::Tensor<Linear, T>& parameters,
-                          CLContext* context);
+  absl::Status UploadParameters(
+      const ::tflite::gpu::Tensor<Linear, T>& parameters, CLContext* context);

   FLT min_;
   FLT max_;
   FLT scale_;
 };

-Status CreateQuantizeAndDequantize(const CreationContext& creation_context,
-                                   const OperationDef& definition,
-                                   const QuantizeAndDequantizeAttributes& attr,
-                                   QuantizeAndDequantize* result);
+absl::Status CreateQuantizeAndDequantize(
+    const CreationContext& creation_context, const OperationDef& definition,
+    const QuantizeAndDequantizeAttributes& attr, QuantizeAndDequantize* result);

 template <DataType T>
-Status QuantizeAndDequantize::UploadParameters(
+absl::Status QuantizeAndDequantize::UploadParameters(
     const ::tflite::gpu::Tensor<Linear, T>& parameters, CLContext* context) {
   LinearStorageCreateInfo create_info;
   create_info.storage_type =
       DeduceLinearStorageType(definition_.GetPrimaryStorageType());
   create_info.data_type = definition_.GetPrimaryDataType();
-  return OkStatus();
+  return absl::OkStatus();
 }

 }  // namespace cl
@@ -80,14 +80,14 @@ std::string ReLU::GetArgsDeclaration() const {
   return args;
 }

-Status ReLU::BindArguments(CLKernel* kernel) {
+absl::Status ReLU::BindArguments(CLKernel* kernel) {
   if (alpha_.Active()) {
     RETURN_IF_ERROR(kernel->SetBytesAuto(alpha_));
   }
   if (clip_.Active()) {
     RETURN_IF_ERROR(kernel->SetBytesAuto(clip_));
   }
-  return OkStatus();
+  return absl::OkStatus();
 }

 ReLU CreateReLU(const CreationContext& creation_context,
@@ -37,7 +37,7 @@ class ReLU : public ElementwiseOperation {
   void SetLinkIndex(int index) override;
   std::string GetCoreCode(const LinkingContext& context) const override;
   std::string GetArgsDeclaration() const override;
-  Status BindArguments(CLKernel* kernel) override;
+  absl::Status BindArguments(CLKernel* kernel) override;

   friend ReLU CreateReLU(const CreationContext& creation_context,
                          const OperationDef& definition,
@@ -156,7 +156,7 @@ Reshape& Reshape::operator=(Reshape&& operation) {
   return *this;
 }

-Status Reshape::Compile(const CreationContext& creation_context) {
+absl::Status Reshape::Compile(const CreationContext& creation_context) {
   const auto code = definition_.IsBatchSupported()
                         ? GetReshapeBatchedCode(definition_, linked_operations_)
                         : GetReshapeCode(definition_, linked_operations_);
@@ -165,7 +165,7 @@ Status Reshape::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }

-Status Reshape::BindArguments() {
+absl::Status Reshape::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
@@ -174,8 +174,7 @@ Status Reshape::BindArguments() {
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->Channels()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->Channels()));
-
-  return OkStatus();
+  return absl::OkStatus();
 }

 int3 Reshape::GetGridSize() const {
@@ -185,12 +184,12 @@ int3 Reshape::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status Reshape::Tune(const TuningParameters& params) {
+absl::Status Reshape::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }

-Status Reshape::AddToQueue(CLCommandQueue* queue) {
+absl::Status Reshape::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -29,10 +29,10 @@ class Reshape : public GPUOperation {
  public:
   explicit Reshape(const OperationDef& definition)
       : GPUOperation(definition), work_group_size_(8, 4, 1) {}
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;

-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   // Move only
   Reshape(Reshape&& operation);
@@ -41,7 +41,7 @@ class Reshape : public GPUOperation {
   Reshape& operator=(const Reshape&) = delete;

  private:
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;

   CLKernel kernel_;
@@ -120,7 +120,7 @@ Reshapex4& Reshapex4::operator=(Reshapex4&& operation) {
   return *this;
 }

-Status Reshapex4::Compile(const CreationContext& creation_context) {
+absl::Status Reshapex4::Compile(const CreationContext& creation_context) {
   const auto code = definition_.IsBatchSupported()
                         ? GetReshapeBatchedCode(definition_, linked_operations_)
                         : GetReshapeCode(definition_, linked_operations_);
@@ -129,15 +129,14 @@ Status Reshapex4::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }

-Status Reshapex4::BindArguments() {
+absl::Status Reshapex4::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
-
-  return OkStatus();
+  return absl::OkStatus();
 }

 int3 Reshapex4::GetGridSize() const {
@@ -147,12 +146,12 @@ int3 Reshapex4::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status Reshapex4::Tune(const TuningParameters& params) {
+absl::Status Reshapex4::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }

-Status Reshapex4::AddToQueue(CLCommandQueue* queue) {
+absl::Status Reshapex4::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -30,10 +30,10 @@ class Reshapex4 : public GPUOperation {
  public:
   explicit Reshapex4(const OperationDef& definition)
       : GPUOperation(definition), work_group_size_(8, 4, 1) {}
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;

-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   // Move only
   Reshapex4(Reshapex4&& operation);
@@ -42,7 +42,7 @@ class Reshapex4 : public GPUOperation {
   Reshapex4& operator=(const Reshapex4&) = delete;

  private:
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;

   CLKernel kernel_;
@@ -209,7 +209,7 @@ Resize& Resize::operator=(Resize&& operation) {
   return *this;
 }

-Status Resize::Compile(const CreationContext& creation_context) {
+absl::Status Resize::Compile(const CreationContext& creation_context) {
   const auto code = GetResizeCode(definition_, attr_.type,
                                   attr_.half_pixel_centers, linked_operations_);
   return creation_context.cache->GetOrCreateCLKernel(
@@ -217,7 +217,7 @@ Status Resize::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }

-Status Resize::BindArguments() {
+absl::Status Resize::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
@@ -230,7 +230,7 @@ Status Resize::BindArguments() {
       float2(CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_),
              CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(scale_factor));
-  return OkStatus();
+  return absl::OkStatus();
 }

 int3 Resize::GetGridSize() const {
@@ -240,12 +240,12 @@ int3 Resize::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status Resize::AddToQueue(CLCommandQueue* queue) {
+absl::Status Resize::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }

-Status Resize::Tune(const TuningParameters& params) {
+absl::Status Resize::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }
@@ -271,7 +271,7 @@ Resize3D& Resize3D::operator=(Resize3D&& operation) {
   return *this;
 }

-Status Resize3D::Compile(const CreationContext& creation_context) {
+absl::Status Resize3D::Compile(const CreationContext& creation_context) {
   const auto code =
       GetResize3DCode(definition_, attr_.type, linked_operations_);
   return creation_context.cache->GetOrCreateCLKernel(
@@ -279,7 +279,7 @@ Status Resize3D::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }

-Status Resize3D::BindArguments() {
+absl::Status Resize3D::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
@@ -296,7 +296,7 @@ Status Resize3D::BindArguments() {
       CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_),
       CalculateResizeScale(src_[0]->Depth(), dst_[0]->Depth(), attr_), 1.0f);
   RETURN_IF_ERROR(kernel_.SetBytesAuto(scale_factor));
-  return OkStatus();
+  return absl::OkStatus();
 }

 int3 Resize3D::GetGridSize() const {
@@ -306,12 +306,12 @@ int3 Resize3D::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status Resize3D::AddToQueue(CLCommandQueue* queue) {
+absl::Status Resize3D::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }

-Status Resize3D::Tune(const TuningParameters& params) {
+absl::Status Resize3D::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }
@@ -27,10 +27,10 @@ namespace cl {

 class Resize : public GPUOperation {
  public:
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;

-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   // Move only
   Resize(Resize&& operation);
@@ -45,7 +45,7 @@ class Resize : public GPUOperation {
   Resize(const OperationDef& definition, const Resize2DAttributes& attr)
       : GPUOperation(definition), attr_(attr) {}

-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;

   Resize2DAttributes attr_;
@@ -58,10 +58,10 @@ Resize CreateResize(const OperationDef& definition,

 class Resize3D : public GPUOperation {
  public:
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;

-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   // Move only
   Resize3D(Resize3D&& operation);
@@ -76,7 +76,7 @@ class Resize3D : public GPUOperation {
   Resize3D(const OperationDef& definition, const Resize3DAttributes& attr)
       : GPUOperation(definition), attr_(attr) {}

-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;

   Resize3DAttributes attr_;
@@ -79,14 +79,14 @@ Softmax& Softmax::operator=(Softmax&& kernel) {
   return *this;
 }

-Status Softmax::Compile(const CreationContext& creation_context) {
+absl::Status Softmax::Compile(const CreationContext& creation_context) {
   const auto code = GetSoftmaxKernelCode(definition_, linked_operations_);
   return creation_context.cache->GetOrCreateCLKernel(
       code, "main_function", *creation_context.context,
       *creation_context.device, &kernel_);
 }

-Status Softmax::BindArguments() {
+absl::Status Softmax::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
@@ -94,7 +94,7 @@ Status Softmax::BindArguments() {
   RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWBatchedHSB()));
   RETURN_IF_ERROR(
       kernel_.SetBytesAuto(GetMaskForLastPlane(src_[0]->Channels())));
-  return OkStatus();
+  return absl::OkStatus();
 }

 int3 Softmax::GetGridSize() const {
@@ -104,12 +104,12 @@ int3 Softmax::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status Softmax::Tune(const TuningParameters& params) {
+absl::Status Softmax::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }

-Status Softmax::AddToQueue(CLCommandQueue* queue) {
+absl::Status Softmax::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -30,10 +30,10 @@ class Softmax : public GPUOperation {
  public:
   Softmax() = default;
   explicit Softmax(const OperationDef& definition) : GPUOperation(definition) {}
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;

-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   // Move only
   Softmax(Softmax&& kernel);
@@ -44,7 +44,7 @@ class Softmax : public GPUOperation {
   friend Softmax CreateSoftmax();

  private:
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;
   CLKernel kernel_;
   int3 work_group_size_ = int3(8, 4, 1);
@@ -115,14 +115,14 @@ Softmax1x1& Softmax1x1::operator=(Softmax1x1&& kernel) {
   return *this;
 }

-Status Softmax1x1::Compile(const CreationContext& creation_context) {
+absl::Status Softmax1x1::Compile(const CreationContext& creation_context) {
   const auto code = GetSoftmaxKernelCode(definition_, linked_operations_);
   return creation_context.cache->GetOrCreateCLKernel(
       code, "main_function", *creation_context.context,
       *creation_context.device, &kernel_);
 }

-Status Softmax1x1::AddToQueue(CLCommandQueue* queue) {
+absl::Status Softmax1x1::AddToQueue(CLCommandQueue* queue) {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
@@ -30,9 +30,9 @@ class Softmax1x1 : public GPUOperation {
   Softmax1x1() = default;
   explicit Softmax1x1(const OperationDef& definition)
       : GPUOperation(definition) {}
-  Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;

-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   // Move only
   Softmax1x1(Softmax1x1&& kernel);
@@ -96,14 +96,14 @@ SpaceToDepth& SpaceToDepth::operator=(SpaceToDepth&& operation) {
   return *this;
 }

-Status SpaceToDepth::Compile(const CreationContext& creation_context) {
+absl::Status SpaceToDepth::Compile(const CreationContext& creation_context) {
   const auto code = GetSpaceToDepthCode(definition_, linked_operations_);
   return creation_context.cache->GetOrCreateCLKernel(
       code, "main_function", *creation_context.context,
       *creation_context.device, &kernel_);
 }

-Status SpaceToDepth::BindArguments() {
+absl::Status SpaceToDepth::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
@@ -121,12 +121,12 @@ int3 SpaceToDepth::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status SpaceToDepth::Tune(const TuningParameters& params) {
+absl::Status SpaceToDepth::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }

-Status SpaceToDepth::AddToQueue(CLCommandQueue* queue) {
+absl::Status SpaceToDepth::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
@@ -30,9 +30,9 @@ class SpaceToDepth : public GPUOperation {
  public:
   SpaceToDepth(const OperationDef& op_def, const SpaceToDepthAttributes& attr)
       : GPUOperation(op_def), attr_(attr), work_group_size_(8, 4, 1) {}
-  Status AddToQueue(CLCommandQueue* queue) override;
-  Status Tune(const TuningParameters& params) override;
-  Status Compile(const CreationContext& creation_context) override;
+  absl::Status AddToQueue(CLCommandQueue* queue) override;
+  absl::Status Tune(const TuningParameters& params) override;
+  absl::Status Compile(const CreationContext& creation_context) override;

   SpaceToDepth(SpaceToDepth&& operation);
   SpaceToDepth& operator=(SpaceToDepth&& operation);
@@ -40,7 +40,7 @@ class SpaceToDepth : public GPUOperation {
   SpaceToDepth& operator=(const SpaceToDepth&) = delete;

  private:
-  Status BindArguments();
+  absl::Status BindArguments();
   int3 GetGridSize() const;

   SpaceToDepthAttributes attr_;
@@ -166,7 +166,7 @@ StridedSlice& StridedSlice::operator=(StridedSlice&& operation) {
   return *this;
 }

-Status StridedSlice::Compile(const CreationContext& creation_context) {
+absl::Status StridedSlice::Compile(const CreationContext& creation_context) {
   const auto code = GetStridedSliceCode(definition_, Is4Aligned(attributes_),
                                         linked_operations_);
   return creation_context.cache->GetOrCreateCLKernel(
@@ -174,7 +174,7 @@ Status StridedSlice::Compile(const CreationContext& creation_context) {
       *creation_context.device, &kernel_);
 }

-Status StridedSlice::BindArguments() {
+absl::Status StridedSlice::BindArguments() {
   kernel_.ResetBindingCounter();
   RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
   RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
@@ -187,7 +187,7 @@ Status StridedSlice::BindArguments() {
                                        attributes_.strides.c, attributes_.strides.b)));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
   RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
-  return OkStatus();
+  return absl::OkStatus();
 }

 int3 StridedSlice::GetGridSize() const {
@@ -197,12 +197,12 @@ int3 StridedSlice::GetGridSize() const {
   return int3(grid_x, grid_y, grid_z);
 }

-Status StridedSlice::Tune(const TuningParameters& params) {
+absl::Status StridedSlice::Tune(const TuningParameters& params) {
   RETURN_IF_ERROR(BindArguments());
   return GetBestWorkGroup(params, kernel_, GetGridSize(), &work_group_size_);
 }

-Status StridedSlice::AddToQueue(CLCommandQueue* queue) {
+absl::Status StridedSlice::AddToQueue(CLCommandQueue* queue) {
   RETURN_IF_ERROR(BindArguments());
   return queue->DispatchImplicit(kernel_, GetGridSize(), work_group_size_);
 }
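Taken together, the per-operation changes compose unchanged at the call site; only the spelling of the status type differs. A minimal end-to-end sketch, where the driver function and its arguments are hypothetical stand-ins for objects constructed elsewhere in the runtime:

    // Hypothetical driver for one migrated operation; error propagation
    // relies only on absl::Status and a RETURN_IF_ERROR-style macro.
    absl::Status RunStridedSlice(StridedSlice* op,
                                 const CreationContext& creation_context,
                                 const TuningParameters& tuning_params,
                                 CLCommandQueue* queue) {
      RETURN_IF_ERROR(op->Compile(creation_context));  // build the CL kernel
      RETURN_IF_ERROR(op->Tune(tuning_params));        // pick a work-group size
      return op->AddToQueue(queue);                    // dispatch the kernel
    }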
Some files were not shown because too many files have changed in this diff.