Extend MUL operation with constant HWC tensor.
PiperOrigin-RevId: 326276881
Change-Id: Ib68142bdb9f69667543f13039ca0d3a699466ed6
This commit is contained in:
parent
fcba9e4c22
commit
f11f25daaa
@@ -1175,6 +1175,10 @@ class MulOperationParser : public TFLiteOperationParser {
|
||||
Tensor<Scalar, DataType::FLOAT32> tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||
attr.param = tensor.data[0];
|
||||
} else if (constant_dims->size == 3) {
|
||||
Tensor<HWC, DataType::FLOAT32> tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||
attr.param = std::move(tensor);
|
||||
} else {
|
||||
Tensor<Linear, DataType::FLOAT32> tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||
|
@@ -349,6 +349,7 @@ cc_library(
|
||||
srcs = ["mul.cc"],
|
||||
hdrs = ["mul.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/gpu/common:convert",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
|
@@ -23,6 +23,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/convert.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||
|
||||
@@ -81,18 +82,10 @@ absl::Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx,
|
||||
absl::Status GenerateMultiplyScalarCode(
|
||||
const NodeShader::GenerationContext& ctx, GeneratedCode* generated_code) {
|
||||
const auto& attr = absl::any_cast<const ElementwiseAttributes&>(ctx.op_attr);
|
||||
auto muls = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
|
||||
auto scalar = absl::get_if<float>(&attr.param);
|
||||
|
||||
const auto* hwc_tensor =
|
||||
absl::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr.param);
|
||||
if (hwc_tensor) {
|
||||
return absl::UnimplementedError("Mul does not support HWC constant tensor");
|
||||
}
|
||||
|
||||
if (scalar) {
|
||||
if (absl::holds_alternative<float>(attr.param)) {
|
||||
*generated_code = {
|
||||
/*parameters=*/{{"scalar", *scalar}},
|
||||
/*parameters=*/{{"scalar", absl::get<float>(attr.param)}},
|
||||
/*objects=*/{},
|
||||
/*shared_variables=*/{},
|
||||
/*workload=*/uint3(),
|
||||
@@ -101,13 +94,16 @@ absl::Status GenerateMultiplyScalarCode(
|
||||
/*input=*/IOStructure::AUTO,
|
||||
/*output=*/IOStructure::AUTO,
|
||||
};
|
||||
} else {
|
||||
if (!muls) {
|
||||
return absl::InvalidArgumentError("Empty parameters for Multiplication.");
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
if (absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(attr.param)) {
|
||||
*generated_code = {
|
||||
/*parameters=*/{},
|
||||
/*objects=*/{{"mul_buffer", MakeReadonlyObject(muls->data)}},
|
||||
/*objects=*/
|
||||
{{"mul_buffer",
|
||||
MakeReadonlyObject(
|
||||
absl::get<Tensor<Linear, DataType::FLOAT32>>(attr.param).data)}},
|
||||
/*shared_variables=*/{},
|
||||
// Declare workload explicitly because shader depends on gid.z.
|
||||
/*workload=*/
|
||||
@@ -119,9 +115,35 @@ absl::Status GenerateMultiplyScalarCode(
|
||||
/*input=*/IOStructure::AUTO,
|
||||
/*output=*/IOStructure::AUTO,
|
||||
};
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
if (absl::holds_alternative<Tensor<HWC, DataType::FLOAT32>>(attr.param)) {
|
||||
*generated_code = {
|
||||
/*parameters=*/{},
|
||||
/*objects=*/
|
||||
{{"hwc_buffer",
|
||||
MakeReadonlyObject(
|
||||
uint3(static_cast<int>(ctx.input_shapes[0][2]),
|
||||
static_cast<int>(ctx.input_shapes[0][1]),
|
||||
DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
|
||||
ConvertToPHWC4(
|
||||
absl::get<Tensor<HWC, DataType::FLOAT32>>(attr.param)))}},
|
||||
/*shared_variables=*/{},
|
||||
// Declare workload explicitly because shader depends on gid.z.
|
||||
/*workload=*/
|
||||
uint3(static_cast<int>(ctx.input_shapes[0][2]),
|
||||
static_cast<int>(ctx.input_shapes[0][1]),
|
||||
DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
|
||||
/*workgroup=*/uint3(),
|
||||
/*source_code=*/"value_0 *= $hwc_buffer[gid.x, gid.y, gid.z]$;",
|
||||
/*input=*/IOStructure::AUTO,
|
||||
/*output=*/IOStructure::AUTO,
|
||||
};
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
return absl::InvalidArgumentError("Unsupported Multiplication case.");
|
||||
}
|
||||
|
||||
class Multiply : public NodeShader {
|
||||
|
@@ -74,6 +74,32 @@ TEST(MulTest, Linear) {
|
||||
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 6, 6, 12}));
|
||||
}
|
||||
|
||||
// Checks MUL when the op attribute holds a constant 3-D (HWC) tensor: the
// input {1, 2, 3, 4} is expected to come out as {-2, 4, -9, 12}, i.e. each
// input element multiplied by the matching element of {-2, 2, -3, 3}.
TEST(MulTest, ConstTensor3D) {
|
||||
// Input tensor reference: ref 0, float32, shape BHWC(1, 1, 2, 2).
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
input.ref = 0;
|
||||
input.shape = BHWC(1, 1, 2, 2);
|
||||
|
||||
// Output tensor reference: ref 1, same type and shape as the input.
TensorRef<BHWC> output;
|
||||
output.type = DataType::FLOAT32;
|
||||
output.ref = 1;
|
||||
output.shape = BHWC(1, 1, 2, 2);
|
||||
|
||||
ElementwiseAttributes attr;
|
||||
// Constant multiplier with h=1, w=2, c=2 -- the same H, W and C extents as
// the input shape above.
Tensor<HWC, DataType::FLOAT32> tensor_3d;
|
||||
tensor_3d.shape.h = 1;
|
||||
tensor_3d.shape.w = 2;
|
||||
tensor_3d.shape.c = 2;
|
||||
tensor_3d.id = 2;
|
||||
tensor_3d.data = {-2, 2, -3, 3};
|
||||
// Hand the constant over to the attributes; tensor_3d is not used afterwards.
attr.param = std::move(tensor_3d);
|
||||
|
||||
// Build a model with a single MUL node, fill input 0, and run it through the
// shader returned by NewMultiplyNodeShader().
SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
|
||||
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
|
||||
ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
|
||||
// Each output element must match the expected product to within 1e-6.
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {-2, 4, -9, 12}));
|
||||
}
|
||||
|
||||
TEST(MulTest, MaskChannel1) {
|
||||
TensorRef<BHWC> input;
|
||||
input.type = DataType::FLOAT32;
|
||||
|
Loading…
x
Reference in New Issue
Block a user