Added support for HWC alpha in PReLU.

PiperOrigin-RevId: 327541760
Change-Id: I5491196f0990580abe51129a71c799c3fa9bb561
This commit is contained in:
Raman Sarokin 2020-08-19 17:41:05 -07:00 committed by TensorFlower Gardener
parent 313edafd6f
commit efa873b2a1
7 changed files with 101 additions and 44 deletions

View File

@@ -941,6 +941,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/cl:cl_context",
"//tensorflow/lite/delegates/gpu/cl:cl_kernel",
"//tensorflow/lite/delegates/gpu/cl:linear_storage",
"//tensorflow/lite/delegates/gpu/cl:storage_type_util",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",

View File

@@ -18,47 +18,75 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/types/variant.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
#include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
namespace tflite {
namespace gpu {
namespace cl {
absl::Status CreatePReLU(const CreationContext& creation_context,
GPUOperation CreatePReLU(const DeviceInfo& device_info,
const OperationDef& definition,
const PReLUAttributes& attr, GPUOperation* result) {
*result = GPUOperation(definition);
result->elementwise_ = true;
const PReLUAttributes& attr) {
GPUOperation result(definition);
result.elementwise_ = true;
std::string alpha_read;
auto alpha_linear =
absl::get_if<tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(&attr.alpha);
if (alpha_linear) {
TensorLinearDescriptor desc;
desc.storage_type =
DeduceLinearStorageType(definition.GetPrimaryStorageType());
desc.element_type = definition.GetPrimaryDataType();
desc.UploadLinearData(*alpha_linear);
result.args_.AddObject(
"alpha", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
alpha_read = "FLT4 alpha_val = args.alpha.Read(S_COORD);\n";
}
auto alpha_hwc =
absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(&attr.alpha);
if (alpha_hwc) {
const BHWC shape =
BHWC(1, alpha_hwc->shape.h, alpha_hwc->shape.w, alpha_hwc->shape.c);
TensorStorageType storage_type = SelectBestStorageType(
device_info, shape, definition.GetPrimaryStorageType(),
definition.GetDataType(), Layout::HWC);
TensorDescriptor desc{definition.GetDataType(), storage_type, Layout::HWC};
desc.UploadData(*alpha_hwc);
result.args_.AddObject(
"alpha", absl::make_unique<TensorDescriptor>(std::move(desc)));
const std::string x_coord = shape.w == 1 ? "0" : "X_COORD";
const std::string y_coord = shape.h == 1 ? "0" : "Y_COORD";
const std::string s_coord = shape.c == 1 ? "0" : "S_COORD";
alpha_read = absl::StrCat("FLT4 alpha_val = args.alpha.Read(", x_coord,
", ", y_coord, ", ", s_coord, ");\n");
if (shape.c == 1) {
alpha_read += " alpha_val.y = alpha_val.x;\n";
alpha_read += " alpha_val.z = alpha_val.x;\n";
alpha_read += " alpha_val.w = alpha_val.x;\n";
}
}
if (attr.clip != 0) {
if (definition.precision == CalculationsPrecision::F32) {
result->args_.AddFloat("clip", attr.clip);
result.args_.AddFloat("clip", attr.clip);
} else {
result->args_.AddHalf("clip", half(attr.clip));
result.args_.AddHalf("clip", half(attr.clip));
}
result->code_ =
result.code_ =
alpha_read +
"in_out_value = clamp(in_out_value, (FLT4)(0.0f), (FLT4)(args.clip)) + "
"min((FLT4)(0.0f), in_out_value) * args.alpha.Read(S_COORD);";
"min((FLT4)(0.0f), in_out_value) * alpha_val;";
} else {
result->code_ =
result.code_ =
alpha_read +
"in_out_value = max((FLT4)(0.0f), in_out_value) + min((FLT4)(0.0f), "
"in_out_value) * args.alpha.Read(S_COORD);";
"in_out_value) * alpha_val;";
}
auto alpha =
absl::get_if<tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(&attr.alpha);
if (!alpha) {
return absl::InvalidArgumentError("Alpha is missing");
}
TensorLinearDescriptor desc;
desc.storage_type =
DeduceLinearStorageType(definition.GetPrimaryStorageType());
desc.element_type = definition.GetPrimaryDataType();
desc.UploadLinearData(*alpha);
result->args_.AddObject(
"alpha", absl::make_unique<TensorLinearDescriptor>(std::move(desc)));
return absl::OkStatus();
return result;
}
} // namespace cl

View File

@@ -31,9 +31,9 @@ namespace tflite {
namespace gpu {
namespace cl {
absl::Status CreatePReLU(const CreationContext& creation_context,
GPUOperation CreatePReLU(const DeviceInfo& device_info,
const OperationDef& definition,
const PReLUAttributes& attr, GPUOperation* result);
const PReLUAttributes& attr);
} // namespace cl
} // namespace gpu

View File

@@ -52,8 +52,8 @@ TEST_F(OpenCLOperationTest, PReLUAlpha) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation;
ASSERT_OK(CreatePReLU(creation_context_, op_def, attr, &operation));
GPUOperation operation =
CreatePReLU(creation_context_.GetDeviceInfo(), op_def, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
@@ -83,8 +83,8 @@ TEST_F(OpenCLOperationTest, PReLUAlphaClip) {
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation;
ASSERT_OK(CreatePReLU(creation_context_, op_def, attr, &operation));
GPUOperation operation =
CreatePReLU(creation_context_.GetDeviceInfo(), op_def, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
@@ -93,6 +93,37 @@ TEST_F(OpenCLOperationTest, PReLUAlphaClip) {
}
}
TEST_F(OpenCLOperationTest, PReLUHWCAlpha) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 1, 2);
src_tensor.data = {0.0f, -1.0f, -2.0f, 3.0f};
PReLUAttributes attr;
::tflite::gpu::Tensor<HWC, DataType::FLOAT32> hwc_tensor;
hwc_tensor.shape = HWC(2, 1, 2);
hwc_tensor.data = {0.5f, -2.0f, 0.7f, 4.7f};
attr.alpha = hwc_tensor;
attr.clip = 0.0;
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
GPUOperation operation =
CreatePReLU(creation_context_.GetDeviceInfo(), op_def, attr);
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
BHWC(1, 2, 1, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {0.0f, 2.0f, -1.4f, 3.0f}));
}
}
}
} // namespace
} // namespace cl
} // namespace gpu

View File

@@ -284,7 +284,8 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
}
case OperationType::PRELU: {
auto attr = absl::any_cast<PReLUAttributes>(node.operation.attributes);
return SelectPReLU(attr, creation_context, op_def, gpu_op);
*gpu_op = SelectPReLU(attr, creation_context.GetDeviceInfo(), op_def);
return absl::OkStatus();
}
case OperationType::QUANTIZE_AND_DEQUANTIZE: {
auto attr = absl::any_cast<QuantizeAndDequantizeAttributes>(

View File

@@ -56,14 +56,11 @@ std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr,
return absl::make_unique<GPUOperation>(CreateReLU(op_def, attr));
}
absl::Status SelectPReLU(const PReLUAttributes& attr,
const CreationContext& creation_context,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr) {
GPUOperation operation;
RETURN_IF_ERROR(CreatePReLU(creation_context, op_def, attr, &operation));
*ptr = absl::make_unique<GPUOperation>(std::move(operation));
return absl::OkStatus();
std::unique_ptr<GPUOperation> SelectPReLU(const PReLUAttributes& attr,
const DeviceInfo& device_info,
const OperationDef& op_def) {
return absl::make_unique<GPUOperation>(
CreatePReLU(device_info, op_def, attr));
}
void SelectPooling(const Pooling2DAttributes& attr, const OperationDef& op_def,

View File

@@ -34,10 +34,9 @@ void SelectLSTM(const OperationDef& op_def, const DeviceInfo& device_info,
std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr,
const OperationDef& op_def);
absl::Status SelectPReLU(const PReLUAttributes& attr,
const CreationContext& creation_context,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);
std::unique_ptr<GPUOperation> SelectPReLU(const PReLUAttributes& attr,
const DeviceInfo& device_info,
const OperationDef& op_def);
void SelectPooling(const Pooling2DAttributes& attr, const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);