Extend MUL operation with constant HWC tensor.

PiperOrigin-RevId: 326276881
Change-Id: Ib68142bdb9f69667543f13039ca0d3a699466ed6
This commit is contained in:
A. Unique TensorFlower 2020-08-12 11:40:57 -07:00 committed by TensorFlower Gardener
parent fcba9e4c22
commit f11f25daaa
4 changed files with 69 additions and 16 deletions

View File

@ -1175,6 +1175,10 @@ class MulOperationParser : public TFLiteOperationParser {
Tensor<Scalar, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = tensor.data[0];
} else if (constant_dims->size == 3) {
Tensor<HWC, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
attr.param = std::move(tensor);
} else {
Tensor<Linear, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));

View File

@ -349,6 +349,7 @@ cc_library(
srcs = ["mul.cc"],
hdrs = ["mul.h"],
deps = [
"//tensorflow/lite/delegates/gpu/common:convert",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:types",

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/common/convert.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
@ -81,18 +82,10 @@ absl::Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx,
absl::Status GenerateMultiplyScalarCode(
const NodeShader::GenerationContext& ctx, GeneratedCode* generated_code) {
const auto& attr = absl::any_cast<const ElementwiseAttributes&>(ctx.op_attr);
auto muls = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
auto scalar = absl::get_if<float>(&attr.param);
const auto* hwc_tensor =
absl::get_if<Tensor<HWC, DataType::FLOAT32>>(&attr.param);
if (hwc_tensor) {
return absl::UnimplementedError("Mul does not support HWC constant tensor");
}
if (scalar) {
if (absl::holds_alternative<float>(attr.param)) {
*generated_code = {
/*parameters=*/{{"scalar", *scalar}},
/*parameters=*/{{"scalar", absl::get<float>(attr.param)}},
/*objects=*/{},
/*shared_variables=*/{},
/*workload=*/uint3(),
@ -101,13 +94,16 @@ absl::Status GenerateMultiplyScalarCode(
/*input=*/IOStructure::AUTO,
/*output=*/IOStructure::AUTO,
};
} else {
if (!muls) {
return absl::InvalidArgumentError("Empty parameters for Multiplication.");
}
return absl::OkStatus();
}
if (absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(attr.param)) {
*generated_code = {
/*parameters=*/{},
/*objects=*/{{"mul_buffer", MakeReadonlyObject(muls->data)}},
/*objects=*/
{{"mul_buffer",
MakeReadonlyObject(
absl::get<Tensor<Linear, DataType::FLOAT32>>(attr.param).data)}},
/*shared_variables=*/{},
// Declare workload explicitly because shader depends on gid.z.
/*workload=*/
@ -119,9 +115,35 @@ absl::Status GenerateMultiplyScalarCode(
/*input=*/IOStructure::AUTO,
/*output=*/IOStructure::AUTO,
};
return absl::OkStatus();
}
return absl::OkStatus();
if (absl::holds_alternative<Tensor<HWC, DataType::FLOAT32>>(attr.param)) {
*generated_code = {
/*parameters=*/{},
/*objects=*/
{{"hwc_buffer",
MakeReadonlyObject(
uint3(static_cast<int>(ctx.input_shapes[0][2]),
static_cast<int>(ctx.input_shapes[0][1]),
DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
ConvertToPHWC4(
absl::get<Tensor<HWC, DataType::FLOAT32>>(attr.param)))}},
/*shared_variables=*/{},
// Declare workload explicitly because shader depends on gid.z.
/*workload=*/
uint3(static_cast<int>(ctx.input_shapes[0][2]),
static_cast<int>(ctx.input_shapes[0][1]),
DivideRoundUp(static_cast<int>(ctx.input_shapes[0][3]), 4)),
/*workgroup=*/uint3(),
/*source_code=*/"value_0 *= $hwc_buffer[gid.x, gid.y, gid.z]$;",
/*input=*/IOStructure::AUTO,
/*output=*/IOStructure::AUTO,
};
return absl::OkStatus();
}
return absl::InvalidArgumentError("Unsupported Multiplication case.");
}
class Multiply : public NodeShader {

View File

@ -74,6 +74,32 @@ TEST(MulTest, Linear) {
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 6, 6, 12}));
}
// Multiplies a 1x1x2x2 input element-wise by a constant HWC (1x2x2) tensor
// supplied through ElementwiseAttributes::param.
TEST(MulTest, ConstTensor3D) {
  // Constant multiplier: four values, one per element of the input.
  Tensor<HWC, DataType::FLOAT32> multiplier;
  multiplier.id = 2;
  multiplier.shape.h = 1;
  multiplier.shape.w = 2;
  multiplier.shape.c = 2;
  multiplier.data = {-2, 2, -3, 3};

  ElementwiseAttributes mul_attr;
  mul_attr.param = std::move(multiplier);

  TensorRef<BHWC> src;
  src.ref = 0;
  src.type = DataType::FLOAT32;
  src.shape = BHWC(1, 1, 2, 2);

  TensorRef<BHWC> dst;
  dst.ref = 1;
  dst.type = DataType::FLOAT32;
  dst.shape = BHWC(1, 1, 2, 2);

  SingleOpModel model({ToString(OperationType::MUL), mul_attr}, {src}, {dst});
  ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
  ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
  // Each expected value is the product of an input element and the
  // corresponding entry of the constant tensor's data.
  EXPECT_THAT(model.GetOutput(0),
              Pointwise(FloatNear(1e-6), {-2, 4, -9, 12}));
}
TEST(MulTest, MaskChannel1) {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;