Added support of HWC alpha in PReLU.
PiperOrigin-RevId: 327541760 Change-Id: I5491196f0990580abe51129a71c799c3fa9bb561
This commit is contained in:
parent
313edafd6f
commit
efa873b2a1
@ -941,6 +941,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/gpu/cl:cl_context",
|
||||
"//tensorflow/lite/delegates/gpu/cl:cl_kernel",
|
||||
"//tensorflow/lite/delegates/gpu/cl:linear_storage",
|
||||
"//tensorflow/lite/delegates/gpu/cl:storage_type_util",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
|
@ -18,47 +18,75 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
absl::Status CreatePReLU(const CreationContext& creation_context,
|
||||
GPUOperation CreatePReLU(const DeviceInfo& device_info,
|
||||
const OperationDef& definition,
|
||||
const PReLUAttributes& attr, GPUOperation* result) {
|
||||
*result = GPUOperation(definition);
|
||||
result->elementwise_ = true;
|
||||
const PReLUAttributes& attr) {
|
||||
GPUOperation result(definition);
|
||||
result.elementwise_ = true;
|
||||
|
||||
std::string alpha_read;
|
||||
auto alpha_linear =
|
||||
absl::get_if<tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(&attr.alpha);
|
||||
if (alpha_linear) {
|
||||
TensorLinearDescriptor desc;
|
||||
desc.storage_type =
|
||||
DeduceLinearStorageType(definition.GetPrimaryStorageType());
|
||||
desc.element_type = definition.GetPrimaryDataType();
|
||||
desc.UploadLinearData(*alpha_linear);
|
||||
result.args_.AddObject(
|
||||
"alpha", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
|
||||
alpha_read = "FLT4 alpha_val = args.alpha.Read(S_COORD);\n";
|
||||
}
|
||||
|
||||
auto alpha_hwc =
|
||||
absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(&attr.alpha);
|
||||
if (alpha_hwc) {
|
||||
const BHWC shape =
|
||||
BHWC(1, alpha_hwc->shape.h, alpha_hwc->shape.w, alpha_hwc->shape.c);
|
||||
TensorStorageType storage_type = SelectBestStorageType(
|
||||
device_info, shape, definition.GetPrimaryStorageType(),
|
||||
definition.GetDataType(), Layout::HWC);
|
||||
TensorDescriptor desc{definition.GetDataType(), storage_type, Layout::HWC};
|
||||
desc.UploadData(*alpha_hwc);
|
||||
result.args_.AddObject(
|
||||
"alpha", absl::make_unique<TensorDescriptor>(std::move(desc)));
|
||||
const std::string x_coord = shape.w == 1 ? "0" : "X_COORD";
|
||||
const std::string y_coord = shape.h == 1 ? "0" : "Y_COORD";
|
||||
const std::string s_coord = shape.c == 1 ? "0" : "S_COORD";
|
||||
alpha_read = absl::StrCat("FLT4 alpha_val = args.alpha.Read(", x_coord,
|
||||
", ", y_coord, ", ", s_coord, ");\n");
|
||||
if (shape.c == 1) {
|
||||
alpha_read += " alpha_val.y = alpha_val.x;\n";
|
||||
alpha_read += " alpha_val.z = alpha_val.x;\n";
|
||||
alpha_read += " alpha_val.w = alpha_val.x;\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (attr.clip != 0) {
|
||||
if (definition.precision == CalculationsPrecision::F32) {
|
||||
result->args_.AddFloat("clip", attr.clip);
|
||||
result.args_.AddFloat("clip", attr.clip);
|
||||
} else {
|
||||
result->args_.AddHalf("clip", half(attr.clip));
|
||||
result.args_.AddHalf("clip", half(attr.clip));
|
||||
}
|
||||
result->code_ =
|
||||
result.code_ =
|
||||
alpha_read +
|
||||
"in_out_value = clamp(in_out_value, (FLT4)(0.0f), (FLT4)(args.clip)) + "
|
||||
"min((FLT4)(0.0f), in_out_value) * args.alpha.Read(S_COORD);";
|
||||
"min((FLT4)(0.0f), in_out_value) * alpha_val;";
|
||||
} else {
|
||||
result->code_ =
|
||||
result.code_ =
|
||||
alpha_read +
|
||||
"in_out_value = max((FLT4)(0.0f), in_out_value) + min((FLT4)(0.0f), "
|
||||
"in_out_value) * args.alpha.Read(S_COORD);";
|
||||
"in_out_value) * alpha_val;";
|
||||
}
|
||||
|
||||
auto alpha =
|
||||
absl::get_if<tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(&attr.alpha);
|
||||
if (!alpha) {
|
||||
return absl::InvalidArgumentError("Alpha is missing");
|
||||
}
|
||||
TensorLinearDescriptor desc;
|
||||
desc.storage_type =
|
||||
DeduceLinearStorageType(definition.GetPrimaryStorageType());
|
||||
desc.element_type = definition.GetPrimaryDataType();
|
||||
desc.UploadLinearData(*alpha);
|
||||
|
||||
result->args_.AddObject(
|
||||
"alpha", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
|
||||
|
||||
return absl::OkStatus();
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace cl
|
||||
|
@ -31,9 +31,9 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
absl::Status CreatePReLU(const CreationContext& creation_context,
|
||||
GPUOperation CreatePReLU(const DeviceInfo& device_info,
|
||||
const OperationDef& definition,
|
||||
const PReLUAttributes& attr, GPUOperation* result);
|
||||
const PReLUAttributes& attr);
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
|
@ -52,8 +52,8 @@ TEST_F(OpenCLOperationTest, PReLUAlpha) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation;
|
||||
ASSERT_OK(CreatePReLU(creation_context_, op_def, attr, &operation));
|
||||
GPUOperation operation =
|
||||
CreatePReLU(creation_context_.GetDeviceInfo(), op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 2, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -83,8 +83,8 @@ TEST_F(OpenCLOperationTest, PReLUAlphaClip) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation;
|
||||
ASSERT_OK(CreatePReLU(creation_context_, op_def, attr, &operation));
|
||||
GPUOperation operation =
|
||||
CreatePReLU(creation_context_.GetDeviceInfo(), op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 2, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -93,6 +93,37 @@ TEST_F(OpenCLOperationTest, PReLUAlphaClip) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, PReLUHWCAlpha) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 2, 1, 2);
|
||||
src_tensor.data = {0.0f, -1.0f, -2.0f, 3.0f};
|
||||
|
||||
PReLUAttributes attr;
|
||||
::tflite::gpu::Tensor<HWC, DataType::FLOAT32> hwc_tensor;
|
||||
hwc_tensor.shape = HWC(2, 1, 2);
|
||||
hwc_tensor.data = {0.5f, -2.0f, 0.7f, 4.7f};
|
||||
attr.alpha = hwc_tensor;
|
||||
attr.clip = 0.0;
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation =
|
||||
CreatePReLU(creation_context_.GetDeviceInfo(), op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 2, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {0.0f, 2.0f, -1.4f, 3.0f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
|
@ -284,7 +284,8 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
|
||||
}
|
||||
case OperationType::PRELU: {
|
||||
auto attr = absl::any_cast<PReLUAttributes>(node.operation.attributes);
|
||||
return SelectPReLU(attr, creation_context, op_def, gpu_op);
|
||||
*gpu_op = SelectPReLU(attr, creation_context.GetDeviceInfo(), op_def);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
case OperationType::QUANTIZE_AND_DEQUANTIZE: {
|
||||
auto attr = absl::any_cast<QuantizeAndDequantizeAttributes>(
|
||||
|
@ -56,14 +56,11 @@ std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr,
|
||||
return absl::make_unique<GPUOperation>(CreateReLU(op_def, attr));
|
||||
}
|
||||
|
||||
absl::Status SelectPReLU(const PReLUAttributes& attr,
|
||||
const CreationContext& creation_context,
|
||||
const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr) {
|
||||
GPUOperation operation;
|
||||
RETURN_IF_ERROR(CreatePReLU(creation_context, op_def, attr, &operation));
|
||||
*ptr = absl::make_unique<GPUOperation>(std::move(operation));
|
||||
return absl::OkStatus();
|
||||
std::unique_ptr<GPUOperation> SelectPReLU(const PReLUAttributes& attr,
|
||||
const DeviceInfo& device_info,
|
||||
const OperationDef& op_def) {
|
||||
return absl::make_unique<GPUOperation>(
|
||||
CreatePReLU(device_info, op_def, attr));
|
||||
}
|
||||
|
||||
void SelectPooling(const Pooling2DAttributes& attr, const OperationDef& op_def,
|
||||
|
@ -34,10 +34,9 @@ void SelectLSTM(const OperationDef& op_def, const DeviceInfo& device_info,
|
||||
std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr,
|
||||
const OperationDef& op_def);
|
||||
|
||||
absl::Status SelectPReLU(const PReLUAttributes& attr,
|
||||
const CreationContext& creation_context,
|
||||
const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr);
|
||||
std::unique_ptr<GPUOperation> SelectPReLU(const PReLUAttributes& attr,
|
||||
const DeviceInfo& device_info,
|
||||
const OperationDef& op_def);
|
||||
|
||||
void SelectPooling(const Pooling2DAttributes& attr, const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr);
|
||||
|
Loading…
x
Reference in New Issue
Block a user