Extended support of SUB (and other elementwise ops).

The OpenCL delegate now supports SUB with the runtime tensor as the second argument.

PiperOrigin-RevId: 324726289
Change-Id: If26a72a5214bffc7b664f1902344ab04038ed3f5
Author: Raman Sarokin (2020-08-03 18:28:29 -07:00), committed by TensorFlower Gardener
parent e2865bb150
commit d353f49989
6 changed files with 143 additions and 151 deletions
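In short: the two-input code templates in elementwise.cc move from the in-place form ("$0 -= $1") to an explicit three-argument form ("$0 = $1 - $2"), so a new swap_inputs flag can reverse the operands when the runtime tensor is the second argument of a non-commutative op such as SUB. A minimal standalone sketch of that substitution idea follows; the Substitute and SubCode helpers here are simplified stand-ins written for this note, not the delegate's actual absl::Substitute-based code.

// Standalone illustration of the swap_inputs idea (hypothetical helpers, not TFLite code).
#include <iostream>
#include <string>

// Simplified stand-in for absl::Substitute, handling only $0..$2.
std::string Substitute(std::string tmpl, const std::string& a0,
                       const std::string& a1, const std::string& a2) {
  const std::string args[] = {a0, a1, a2};
  for (int i = 0; i < 3; ++i) {
    const std::string key = "$" + std::to_string(i);
    for (size_t pos = 0; (pos = tmpl.find(key, pos)) != std::string::npos;
         pos += args[i].size()) {
      tmpl.replace(pos, key.size(), args[i]);
    }
  }
  return tmpl;
}

// Mirrors the new GetTwoInputCode contract for SUB: the result and both
// operands are named explicitly, and swap_inputs reverses the operand order.
std::string SubCode(const std::string& result, const std::string& input0,
                    const std::string& input1, bool swap_inputs) {
  const std::string code = "$0 = $1 - $2;\n";
  return swap_inputs ? Substitute(code, result, input1, input0)
                     : Substitute(code, result, input0, input1);
}

int main() {
  // Runtime tensor first (the previously supported case): x - const.
  std::cout << SubCode("in_out_value", "in_out_value", "second_val", false);
  // Runtime tensor second (enabled by this change): const - x.
  std::cout << SubCode("in_out_value", "in_out_value", "second_val", true);
  return 0;
}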

View File

@@ -98,53 +98,51 @@ std::string GetOneInputCode(const OperationType& op_type,
 }
 
 std::string GetTwoInputCode(const OperationType& op_type,
+                            const std::string& result_var,
                             const std::string& input0,
-                            const std::string& input1) {
+                            const std::string& input1,
+                            bool swap_inputs = false) {
   std::string result;
   switch (op_type) {
     case OperationType::ADD:
-      result += "$0 += $1;\n";
+      result += "$0 = $1 + $2;\n";
      break;
    case OperationType::DIV:
-      result += "$0 /= $1;\n";
+      result += "$0 = $1 / $2;\n";
      break;
    case OperationType::MAXIMUM:
-      result += "$0 = max($0, $1);\n";
+      result += "$0 = max($1, $2);\n";
      break;
    case OperationType::MINIMUM:
-      result += "$0 = min($0, $1);\n";
+      result += "$0 = min($1, $2);\n";
      break;
    case OperationType::MUL:
-      result += "$0 *= $1;\n";
+      result += "$0 = $1 * $2;\n";
      break;
    case OperationType::POW:
-      result += "$0 = pow($0, $1);\n";
+      result += "$0 = pow($1, $2);\n";
      break;
    case OperationType::SQUARED_DIFF:
-      result += "$0 -= $1;\n";
-      result += "$0 *= $0;\n";
+      result += "$0 = ($1 - $2) * ($1 - $2);\n";
      break;
    case OperationType::SUB:
-      result += "$0 -= $1;\n";
+      result += "$0 = $1 - $2;\n";
      break;
    default:
      return "Unknown operation type;\n";
  }
-  return absl::Substitute(result, input0, input1);
-}
-}  // namespace
-
-GPUOperation CreateElementwiseOneInput(const OperationDef& definition,
-                                       const OperationType& op_type) {
-  GPUOperation op(definition);
-  op.elementwise_ = true;
-  op.code_ = GetOneInputCode(op_type, definition.precision, "in_out_value");
-  return op;
+  if (swap_inputs) {
+    return absl::Substitute(result, result_var, input1, input0);
+  } else {
+    return absl::Substitute(result, result_var, input0, input1);
+  }
 }
 
+// Creates simple two input (first input is runtime tensor and second input is
+// scalar argument) operation, for example sub, div, pow, etc.
 GPUOperation CreateElementwiseOneRuntimeOneScalar(
-    const CreationContext& creation_context, const OperationDef& definition,
-    const OperationType& op_type, float scalar_parameter) {
+    const OperationDef& definition, const OperationType& op_type,
+    float scalar_parameter, bool swap_inputs) {
   GPUOperation op(definition);
   op.elementwise_ = true;
   if (definition.precision == CalculationsPrecision::F32) {
@@ -152,15 +150,21 @@ GPUOperation CreateElementwiseOneRuntimeOneScalar(
   } else {
     op.args_.AddHalf("scalar", half(scalar_parameter));
   }
-  op.code_ = GetTwoInputCode(op_type, "in_out_value", "args.scalar");
+  op.code_ =
+      "FLT4 second_val = (FLT4)(args.scalar, args.scalar, args.scalar, "
+      "args.scalar);\n";
+  op.code_ += GetTwoInputCode(op_type, "in_out_value", "in_out_value",
+                              "second_val", swap_inputs);
   return op;
 }
 
+// Creates simple two input(first input is runtime tensor and second input is
+// constant linear tensor) operation, for example sub, div and etc.
 absl::Status CreateElementwiseTwoInput(
     const CreationContext& creation_context, const OperationDef& definition,
     const OperationType& op_type,
     const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& constant_tensor,
-    GPUOperation* result) {
+    bool swap_inputs, GPUOperation* result) {
   const BHWC shape = BHWC(1, 1, 1, constant_tensor.shape.v);
   TensorStorageType storage_type =
       SelectBestStorageType(*creation_context.context, *creation_context.device,
@@ -187,15 +191,18 @@ absl::Status CreateElementwiseTwoInput(
     result->code_ += " second_val.z = second_val.x;\n";
     result->code_ += " second_val.w = second_val.x;\n";
   }
-  result->code_ += GetTwoInputCode(op_type, "in_out_value", "second_val");
+  result->code_ += GetTwoInputCode(op_type, "in_out_value", "in_out_value",
+                                   "second_val", swap_inputs);
   return absl::OkStatus();
 }
 
+// Creates simple two input(first input is runtime tensor and second input is
+// constant HWC tensor) operation, for example sub, div and etc.
 absl::Status CreateElementwiseTwoInput(
     const CreationContext& creation_context, const OperationDef& definition,
     const OperationType& op_type,
     const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& constant_tensor,
-    GPUOperation* result) {
+    bool swap_inputs, GPUOperation* result) {
   const BHWC shape = BHWC(1, constant_tensor.shape.h, constant_tensor.shape.w,
                           constant_tensor.shape.c);
   TensorStorageType storage_type =
@@ -225,11 +232,50 @@ absl::Status CreateElementwiseTwoInput(
     result->code_ += " second_val.z = second_val.x;\n";
     result->code_ += " second_val.w = second_val.x;\n";
   }
-  result->code_ += GetTwoInputCode(op_type, "in_out_value", "second_val");
+  result->code_ += GetTwoInputCode(op_type, "in_out_value", "in_out_value",
+                                   "second_val", swap_inputs);
   return absl::OkStatus();
 }
+}  // namespace
 
+GPUOperation CreateElementwiseOneInput(const OperationDef& definition,
+                                       const OperationType& op_type) {
+  GPUOperation op(definition);
+  op.elementwise_ = true;
+  op.code_ = GetOneInputCode(op_type, definition.precision, "in_out_value");
+  return op;
+}
+
+absl::Status CreateElementwise(const CreationContext& creation_context,
+                               const OperationDef& definition,
+                               const OperationType& op_type,
+                               const ElementwiseAttributes& attr,
+                               GPUOperation* result) {
+  const float* scalar = absl::get_if<float>(&attr.param);
+  const auto* linear_tensor =
+      absl::get_if<tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(&attr.param);
+  const auto* hwc_tensor =
+      absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(&attr.param);
+
+  if (scalar) {
+    *result = CreateElementwiseOneRuntimeOneScalar(
+        definition, op_type, *scalar, attr.runtime_tensor_is_second);
+    return absl::OkStatus();
+  } else if (linear_tensor) {
+    return CreateElementwiseTwoInput(creation_context, definition, op_type,
+                                     *linear_tensor,
+                                     attr.runtime_tensor_is_second, result);
+  } else if (hwc_tensor) {
+    return CreateElementwiseTwoInput(creation_context, definition, op_type,
+                                     *hwc_tensor, attr.runtime_tensor_is_second,
+                                     result);
+  }
+  return absl::UnimplementedError(
+      "No elementwise implementation for this case");
+}
+
 GPUOperation CreateElementwiseTwoInput(const OperationDef& definition,
                                        const OperationType& op_type,
                                        const BHWC& shape) {
@@ -250,7 +296,8 @@ GPUOperation CreateElementwiseTwoInput(const OperationDef& definition,
     op.code_ += " second_val.z = second_val.x;\n";
     op.code_ += " second_val.w = second_val.x;\n";
   }
-  op.code_ += GetTwoInputCode(op_type, "in_out_value", "second_val");
+  op.code_ += GetTwoInputCode(op_type, "in_out_value", "in_out_value",
+                              "second_val", false);
   return op;
 }
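For reference, the new CreateElementwise above folds the three former entry points into one dispatch over the attribute payload. A self-contained sketch of that dispatch pattern follows, using std::variant and simplified stand-in types rather than the TFLite GPU ones; DispatchElementwise is a hypothetical name for this note, not the real API.

// Dispatch sketch only; the structs and function below are stand-ins, not TFLite types.
#include <cstdio>
#include <variant>
#include <vector>

struct LinearTensor { std::vector<float> data; };            // per-channel constants
struct HWCTensor { int h, w, c; std::vector<float> data; };  // full constant map

using TensorOrScalar = std::variant<float, LinearTensor, HWCTensor>;

// Picks the scalar, linear-tensor, or HWC-tensor path, the same way
// CreateElementwise uses absl::get_if on ElementwiseAttributes::param.
void DispatchElementwise(const TensorOrScalar& param, bool runtime_tensor_is_second) {
  if (const float* scalar = std::get_if<float>(&param)) {
    std::printf("scalar path: %.1f, swap=%d\n", *scalar, runtime_tensor_is_second);
  } else if (const auto* linear = std::get_if<LinearTensor>(&param)) {
    std::printf("linear-tensor path: %zu channels\n", linear->data.size());
  } else {
    std::printf("hwc-tensor path\n");
  }
}

int main() {
  DispatchElementwise(4.0f, /*runtime_tensor_is_second=*/true);  // e.g. 4.0 - x
  DispatchElementwise(LinearTensor{{0.5f, 2.0f}}, false);        // x op per-channel const
  return 0;
}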

View File

@@ -32,25 +32,11 @@ GPUOperation CreateElementwiseOneInput(const OperationDef& definition,
                                        const OperationType& op_type);
 
 // Creates simple two input(first input is runtime tensor and second input is
-// scalar argument) operation, for example sub, div, pow, etc.
-GPUOperation CreateElementwiseOneRuntimeOneScalar(
-    const CreationContext& creation_context, const OperationDef& definition,
-    const OperationType& op_type, float scalar_parameter);
-
-// Creates simple two input(first input is runtime tensor and second input is
-// constant linear tensor) operation, for example sub, div and etc.
-absl::Status CreateElementwiseTwoInput(
-    const CreationContext& creation_context, const OperationDef& definition,
-    const OperationType& op_type,
-    const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& constant_tensor,
-    GPUOperation* result);
-
-// Creates simple two input(first input is runtime tensor and second input is
-// constant HWC tensor) operation, for example sub, div and etc.
-absl::Status CreateElementwiseTwoInput(
-    const CreationContext& creation_context, const OperationDef& definition,
-    const OperationType& op_type,
-    const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& constant_tensor,
-    GPUOperation* result);
+// constant or linear/hwc tensor) operation, for example sub, div and etc.
+absl::Status CreateElementwise(const CreationContext& creation_context,
+                               const OperationDef& definition,
+                               const OperationType& op_type,
+                               const ElementwiseAttributes& attr,
+                               GPUOperation* result);
 
 // Creates simple two input(2 runtime tensors) operation, for example

View File

@@ -546,9 +546,9 @@ TEST_F(OpenCLOperationTest, MaximumWithScalar) {
       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
-      const float* scalar = absl::get_if<float>(&attr.param);
-      GPUOperation operation = CreateElementwiseOneRuntimeOneScalar(
-          creation_context_, op_def, OperationType::MAXIMUM, *scalar);
+      GPUOperation operation;
+      ASSERT_OK(CreateElementwise(creation_context_, op_def,
+                                  OperationType::MAXIMUM, attr, &operation));
       ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
                                     BHWC(1, 4, 1, 1), &dst_tensor));
       EXPECT_THAT(dst_tensor.data,
@@ -578,9 +578,8 @@ TEST_F(OpenCLOperationTest, MaximumWithConstantLinearTensor) {
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
       GPUOperation operation;
-      ASSERT_OK(CreateElementwiseTwoInput(creation_context_, op_def,
-                                          OperationType::MAXIMUM, linear_tensor,
-                                          &operation));
+      ASSERT_OK(CreateElementwise(creation_context_, op_def,
+                                  OperationType::MAXIMUM, attr, &operation));
       ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
                                     BHWC(1, 2, 1, 2), &dst_tensor));
       EXPECT_THAT(dst_tensor.data,
@@ -597,6 +596,8 @@ TEST_F(OpenCLOperationTest, MaximumWithConstantHWCTensor) {
   ::tflite::gpu::Tensor<HWC, DataType::FLOAT32> hwc_tensor;
   hwc_tensor.shape = HWC(2, 1, 2);
   hwc_tensor.data = {0.5f, 2.0f, 0.7f, 4.7f};
+  ElementwiseAttributes attr;
+  attr.param = hwc_tensor;
 
   for (auto storage : env_.GetSupportedStorages()) {
     for (auto precision : env_.GetSupportedPrecisions()) {
@@ -608,9 +609,8 @@ TEST_F(OpenCLOperationTest, MaximumWithConstantHWCTensor) {
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
       GPUOperation operation;
-      ASSERT_OK(CreateElementwiseTwoInput(creation_context_, op_def,
-                                          OperationType::MAXIMUM, hwc_tensor,
-                                          &operation));
+      ASSERT_OK(CreateElementwise(creation_context_, op_def,
+                                  OperationType::MAXIMUM, attr, &operation));
       ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
                                     BHWC(1, 2, 1, 2), &dst_tensor));
       EXPECT_THAT(dst_tensor.data,
@@ -626,6 +626,8 @@ TEST_F(OpenCLOperationTest, MaximumWithConstantHWCTensorBroadcastChannels) {
   ::tflite::gpu::Tensor<HWC, DataType::FLOAT32> hwc_tensor;
   hwc_tensor.shape = HWC(2, 1, 1);
   hwc_tensor.data = {0.5f, 2.0f};
+  ElementwiseAttributes attr;
+  attr.param = hwc_tensor;
 
   for (auto storage : env_.GetSupportedStorages()) {
     for (auto precision : env_.GetSupportedPrecisions()) {
@@ -637,9 +639,8 @@ TEST_F(OpenCLOperationTest, MaximumWithConstantHWCTensorBroadcastChannels) {
      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
      TensorFloat32 dst_tensor;
      GPUOperation operation;
-      ASSERT_OK(CreateElementwiseTwoInput(creation_context_, op_def,
-                                          OperationType::MAXIMUM, hwc_tensor,
-                                          &operation));
+      ASSERT_OK(CreateElementwise(creation_context_, op_def,
+                                  OperationType::MAXIMUM, attr, &operation));
      ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
                                    BHWC(1, 2, 1, 2), &dst_tensor));
      EXPECT_THAT(dst_tensor.data,
@@ -693,9 +694,9 @@ TEST_F(OpenCLOperationTest, MinimumWithScalar) {
       op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
       op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
       TensorFloat32 dst_tensor;
-      const float* scalar = absl::get_if<float>(&attr.param);
-      GPUOperation operation = CreateElementwiseOneRuntimeOneScalar(
-          creation_context_, op_def, OperationType::MINIMUM, *scalar);
+      GPUOperation operation;
+      ASSERT_OK(CreateElementwise(creation_context_, op_def,
+                                  OperationType::MINIMUM, attr, &operation));
       ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
                                     BHWC(1, 4, 1, 1), &dst_tensor));
       EXPECT_THAT(dst_tensor.data,
@@ -788,6 +789,35 @@ TEST_F(OpenCLOperationTest, MulBroadcastChannels) {
   }
 }
 
+TEST_F(OpenCLOperationTest, SubWithScalarAtFirstPosition) {
+  TensorFloat32 src_tensor_0;
+  src_tensor_0.shape = BHWC(1, 4, 1, 1);
+  src_tensor_0.data = {0.0f, -6.2f, 2.0f, -3.0f};
+  ElementwiseAttributes attr;
+  attr.param = 4.0f;
+  attr.runtime_tensor_is_second = true;
+
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-2f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      GPUOperation operation;
+      ASSERT_OK(CreateElementwise(creation_context_, op_def, OperationType::SUB,
+                                  attr, &operation));
+      ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
+                                    BHWC(1, 4, 1, 1), &dst_tensor));
+      EXPECT_THAT(dst_tensor.data,
+                  Pointwise(FloatNear(eps), {4.0f, 10.2f, 2.0f, 7.0f}));
+    }
+  }
+}
+
 }  // namespace
 }  // namespace cl
 }  // namespace gpu

View File

@@ -159,31 +159,11 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
       } else if (inputs.size() == 1 && node.operation.attributes.has_value()) {
         auto attr =
             absl::any_cast<ElementwiseAttributes>(node.operation.attributes);
-        const float* scalar = absl::get_if<float>(&attr.param);
-        const auto* linear_tensor =
-            absl::get_if<tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
-                &attr.param);
-        const auto* hwc_tensor =
-            absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(
-                &attr.param);
-        if (scalar) {
-          GPUOperation operation = CreateElementwiseOneRuntimeOneScalar(
-              creation_context, op_def, op_type, *scalar);
-          *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
-          return absl::OkStatus();
-        } else if (linear_tensor) {
-          GPUOperation operation;
-          RETURN_IF_ERROR(CreateElementwiseTwoInput(
-              creation_context, op_def, op_type, *linear_tensor, &operation));
-          *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
-          return absl::OkStatus();
-        } else if (hwc_tensor) {
-          GPUOperation operation;
-          RETURN_IF_ERROR(CreateElementwiseTwoInput(
-              creation_context, op_def, op_type, *hwc_tensor, &operation));
-          *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
-          return absl::OkStatus();
-        }
+        GPUOperation operation;
+        RETURN_IF_ERROR(CreateElementwise(creation_context, op_def, op_type,
+                                          attr, &operation));
+        *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
+        return absl::OkStatus();
       }
       return absl::UnimplementedError(absl::StrCat(
           "No support of ", node.operation.type, " with this parameters"));
@@ -289,44 +269,6 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
           absl::make_unique<MeanStdDevNormalization>(std::move(operation));
       return absl::OkStatus();
     }
-    case OperationType::MUL: {
-      if (inputs.size() == 2) {
-        GPUOperation operation =
-            CreateElementwiseTwoInput(op_def, op_type, inputs[1]->tensor.shape);
-        *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
-        return absl::OkStatus();
-      } else if (inputs.size() == 1 && node.operation.attributes.has_value()) {
-        auto attr =
-            absl::any_cast<ElementwiseAttributes>(node.operation.attributes);
-        const float* scalar = absl::get_if<float>(&attr.param);
-        const auto* linear_tensor =
-            absl::get_if<tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
-                &attr.param);
-        const auto* hwc_tensor =
-            absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(
-                &attr.param);
-        if (scalar) {
-          GPUOperation operation = CreateElementwiseOneRuntimeOneScalar(
-              creation_context, op_def, op_type, *scalar);
-          *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
-          return absl::OkStatus();
-        } else if (linear_tensor) {
-          GPUOperation operation;
-          RETURN_IF_ERROR(CreateElementwiseTwoInput(
-              creation_context, op_def, op_type, *linear_tensor, &operation));
-          *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
-          return absl::OkStatus();
-        } else if (hwc_tensor) {
-          GPUOperation operation;
-          RETURN_IF_ERROR(CreateElementwiseTwoInput(
-              creation_context, op_def, op_type, *hwc_tensor, &operation));
-          *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
-          return absl::OkStatus();
-        }
-      }
-      return absl::UnimplementedError(absl::StrCat(
-          "No support of ", node.operation.type, " with this parameters"));
-    }
     case OperationType::PAD: {
       auto attr = absl::any_cast<PadAttributes>(node.operation.attributes);
       SelectPadding(attr, op_def, gpu_op);
@@ -404,6 +346,7 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
     case OperationType::DIV:
     case OperationType::MAXIMUM:
     case OperationType::MINIMUM:
+    case OperationType::MUL:
     case OperationType::POW:
     case OperationType::SQUARED_DIFF:
     case OperationType::SUB: {
@@ -415,31 +358,11 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
       } else if (inputs.size() == 1 && node.operation.attributes.has_value()) {
         auto attr =
             absl::any_cast<ElementwiseAttributes>(node.operation.attributes);
-        const float* scalar = absl::get_if<float>(&attr.param);
-        const auto* linear_tensor =
-            absl::get_if<tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
-                &attr.param);
-        const auto* hwc_tensor =
-            absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(
-                &attr.param);
-        if (scalar) {
-          GPUOperation operation = CreateElementwiseOneRuntimeOneScalar(
-              creation_context, op_def, op_type, *scalar);
-          *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
-          return absl::OkStatus();
-        } else if (linear_tensor) {
-          GPUOperation operation;
-          RETURN_IF_ERROR(CreateElementwiseTwoInput(
-              creation_context, op_def, op_type, *linear_tensor, &operation));
-          *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
-          return absl::OkStatus();
-        } else if (hwc_tensor) {
-          GPUOperation operation;
-          RETURN_IF_ERROR(CreateElementwiseTwoInput(
-              creation_context, op_def, op_type, *hwc_tensor, &operation));
-          *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
-          return absl::OkStatus();
-        }
+        GPUOperation operation;
+        RETURN_IF_ERROR(CreateElementwise(creation_context, op_def, op_type,
+                                          attr, &operation));
+        *gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
+        return absl::OkStatus();
       }
       return absl::UnimplementedError(absl::StrCat(
           "No support of ", node.operation.type, " with this parameters"));

View File

@@ -847,6 +847,8 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
                                          /*outputs=*/1));
       ElementwiseAttributes attr;
       RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
+      attr.runtime_tensor_is_second =
+          IsConstantTensor(reader->GetInputTensor(0));
       node->operation.attributes = std::move(attr);
     } else {
       return absl::InvalidArgumentError("Incorrect operation type passed");
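The parser change above derives the flag from which input carries constant data: if input 0 is the constant tensor, the runtime tensor must be the second operand. A tiny standalone sketch of that rule follows; TensorInfo and RuntimeTensorIsSecond are illustrative names for this note, not the model_builder API.

// Illustrative only; not the TFLite parser types.
#include <cassert>

struct TensorInfo {
  bool is_constant;  // true if the tensor's data is known at conversion time
};

// If input 0 is the constant operand, the runtime tensor sits at position 1,
// mirroring attr.runtime_tensor_is_second = IsConstantTensor(input 0).
bool RuntimeTensorIsSecond(const TensorInfo& input0, const TensorInfo& input1) {
  assert(input0.is_constant != input1.is_constant);  // exactly one constant input
  return input0.is_constant;
}

int main() {
  assert(RuntimeTensorIsSecond({true}, {false}));   // y = 4.0f - x  -> swap later
  assert(!RuntimeTensorIsSecond({false}, {true}));  // y = x - 4.0f  -> no swap
  return 0;
}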

View File

@@ -490,6 +490,10 @@ BHWC CalculateOutputShape(const BHWC& input, const MeanAttributes& attr);
 
 struct ElementwiseAttributes {
   TensorOrScalar param;
+  // For elementwise operation with 2 inputs op(A, B), runtime_tensor_is_second
+  // true when runtime tensor is B(on second position). this is important for
+  // ops that non commutative, for example substract.
+  bool runtime_tensor_is_second = false;
 };
 
 struct ReshapeAttributes {