Remove APPLY_MASK and MULTIPLY_SCALAR enum values, use MUL instead.

Rename MultiplyScalarAttributes to MultiplyAttributes.

PiperOrigin-RevId: 292228929
Change-Id: I459585d62e4161a93d6ec163bdbc0494c37eaf0e
A. Unique TensorFlower 2020-01-29 15:17:40 -08:00 committed by TensorFlower Gardener
parent abb256c1f8
commit 71c6f97e2d
25 changed files with 385 additions and 541 deletions
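As a reading aid, here is a minimal illustrative sketch (not part of this commit; names such as has_constant_multiplier are hypothetical) of how a multiplication is represented after this change: there is only an OperationType::MUL node, and the presence or absence of MultiplyAttributes distinguishes the former MULTIPLY_SCALAR case (one runtime input times a constant scalar or per-channel tensor) from the former APPLY_MASK case (two runtime inputs).

// Sketch only: building a MUL node under the new convention.
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::MUL);
if (has_constant_multiplier) {
  // Former MULTIPLY_SCALAR: attach attributes holding the constant.
  MultiplyAttributes attr;
  attr.param = 2.0f;  // or a Tensor<Linear, DataType::FLOAT32>
  node->operation.attributes = attr;
}
// Former APPLY_MASK: the node simply keeps two runtime inputs and no
// attributes; backends branch on node->operation.attributes.has_value().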

View File

@@ -39,39 +39,6 @@ cc_test(
    ],
)
cc_library(
name = "apply_mask",
srcs = ["apply_mask.cc"],
hdrs = ["apply_mask.h"],
deps = [
":gpu_operation",
":util",
":work_group_picking",
"//tensorflow/lite/delegates/gpu/cl:cl_kernel",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:types",
"@com_google_absl//absl/strings",
],
)
cc_test(
name = "apply_mask_test",
srcs = ["apply_mask_test.cc"],
linkstatic = True,
tags = tf_gpu_tests_tags() + [
"linux",
"local",
],
deps = [
":apply_mask",
":cl_test",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
    name = "cl_test",
    testonly = 1,
@@ -1328,7 +1295,6 @@ test_suite(
    name = "all_tests",
    tests = [
        "add_test",
-       "apply_mask_test",
        "concat_test",
        "conv_buffer_1x1_test",
        "conv_buffer_test",

View File

@@ -1,103 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/kernels/apply_mask.h"
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
namespace tflite {
namespace gpu {
namespace cl {
ApplyMask::ApplyMask(ApplyMask&& operation)
: ElementwiseOperation(std::move(operation)),
mask_type_(operation.mask_type_),
link_index_(operation.link_index_) {}
ApplyMask& ApplyMask::operator=(ApplyMask&& operation) {
if (this != &operation) {
mask_type_ = operation.mask_type_;
link_index_ = operation.link_index_;
ElementwiseOperation::operator=(std::move(operation));
}
return *this;
}
void ApplyMask::SetLinkIndex(int index) { link_index_ = index; }
std::string ApplyMask::GetCoreCode(const LinkingContext& context) const {
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
TensorCodeGenerator mask(
tensor_name,
WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
definition_.src_tensors[1]);
switch (mask_type_) {
case MaskType::TENSOR:
return context.var_name + " *= " +
mask.ReadWHS(context.x_coord, context.y_coord, context.s_coord) +
";\n";
case MaskType::CHANNELS:
return context.var_name +
" *= " + mask.ReadWHS("0", "0", context.s_coord) + ";\n";
case MaskType::LAYER:
return context.var_name +
" *= " + mask.ReadWHS(context.x_coord, context.y_coord, "0") +
".x;\n";
}
}
std::string ApplyMask::GetArgsDeclaration() const {
std::string args;
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
absl::StrAppend(&args, ",\n",
GetTensorDeclaration(AccessType::READ, tensor_name,
definition_.src_tensors[1]));
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
absl::StrAppend(&args, ",\n int4 ", size_name);
return args;
}
Status ApplyMask::BindArguments(CLKernel* kernel) {
RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr()));
RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB()));
return OkStatus();
}
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
const BHWC& mask_shape) {
ApplyMask::MaskType mask_type;
if (mask_shape == src_shape) {
mask_type = ApplyMask::MaskType::TENSOR;
} else if (mask_shape.c == 1) {
mask_type = ApplyMask::MaskType::LAYER;
} else {
mask_type = ApplyMask::MaskType::CHANNELS;
}
ApplyMask operation(definition, mask_type);
operation.SetLinkIndex(0);
return operation;
}
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@@ -1,63 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_APPLY_MASK_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_APPLY_MASK_H_
#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
namespace tflite {
namespace gpu {
namespace cl {
class ApplyMask : public ElementwiseOperation {
public:
// Move only
ApplyMask(ApplyMask&& operation);
ApplyMask& operator=(ApplyMask&& operation);
ApplyMask(const ApplyMask&) = delete;
ApplyMask& operator=(const ApplyMask&) = delete;
void SetLinkIndex(int index) override;
std::string GetCoreCode(const LinkingContext& context) const override;
std::string GetArgsDeclaration() const override;
Status BindArguments(CLKernel* kernel) override;
private:
friend ApplyMask CreateApplyMask(const OperationDef& definition,
const BHWC& src_shape,
const BHWC& mask_shape);
enum class MaskType { LAYER, CHANNELS, TENSOR };
explicit ApplyMask(const OperationDef& definition, MaskType mask_type)
: ElementwiseOperation(definition), mask_type_(mask_type) {}
MaskType mask_type_;
int link_index_;
};
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
const BHWC& mask_shape);
} // namespace cl
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_APPLY_MASK_H_

View File

@@ -1,127 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/kernels/apply_mask.h"
#include <memory>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
using ::testing::FloatNear;
using ::testing::Pointwise;
namespace tflite {
namespace gpu {
namespace cl {
namespace {
TEST_F(OpenCLOperationTest, ApplyMaskOneChannel) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
TensorFloat32 mask_tensor;
mask_tensor.shape = BHWC(1, 2, 2, 1);
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ApplyMask operation =
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
creation_context_, &operation,
BHWC(1, 2, 2, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-8.0f, -6.0f, -0.5f, 0.0f, 1.0f,
3.0f, 0.0f, 0.0f}));
}
}
}
TEST_F(OpenCLOperationTest, ApplyMaskEqualSizes) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
TensorFloat32 mask_tensor;
mask_tensor.shape = BHWC(1, 2, 2, 2);
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f, 2.0f, 0.5f, 1.0f, 0.0f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ApplyMask operation =
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
creation_context_, &operation,
BHWC(1, 2, 2, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -1.0f, 0.0f, 2.0f,
1.5f, 4.0f, 0.0f}));
}
}
}
TEST_F(OpenCLOperationTest, ApplyMaskVector) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
TensorFloat32 mask_tensor;
mask_tensor.shape = BHWC(1, 1, 1, 2);
mask_tensor.data = {2.0f, 0.5f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ApplyMask operation =
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
creation_context_, &operation,
BHWC(1, 2, 2, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -2.0f, 0.0f, 2.0f,
1.5f, 8.0f, 3.0f}));
}
}
}
} // namespace
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@@ -105,7 +105,7 @@ Status MultiplyAdd::BindArguments(CLKernel* kernel) {
  return OkStatus();
}
-Status MultiplyAdd::UploadMul(const MultiplyScalarAttributes& attr,
+Status MultiplyAdd::UploadMul(const MultiplyAttributes& attr,
                              CalculationsPrecision scalar_precision,
                              CLContext* context) {
  auto mul = absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
@@ -135,8 +135,7 @@ Status MultiplyAdd::UploadAdd(const AddAttributes& attr,
Status CreateMultiplyAdd(const CreationContext& creation_context,
                         const OperationDef& definition,
-                        const MultiplyScalarAttributes& attr,
-                        MultiplyAdd* result) {
+                        const MultiplyAttributes& attr, MultiplyAdd* result) {
  const auto scalar_precision = creation_context.device->IsPowerVR()
                                    ? CalculationsPrecision::F32
                                    : definition.precision;
@@ -162,7 +161,7 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
Status CreateMultiplyAdd(const CreationContext& creation_context,
                         const OperationDef& definition,
-                        const MultiplyScalarAttributes& mul_attr,
+                        const MultiplyAttributes& mul_attr,
                         const AddAttributes& add_attr, MultiplyAdd* result) {
  const auto scalar_precision = creation_context.device->IsPowerVR()
                                    ? CalculationsPrecision::F32
@@ -176,6 +175,76 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
  return OkStatus();
}
ApplyMask::ApplyMask(ApplyMask&& operation)
: ElementwiseOperation(std::move(operation)),
mask_type_(operation.mask_type_),
link_index_(operation.link_index_) {}
ApplyMask& ApplyMask::operator=(ApplyMask&& operation) {
if (this != &operation) {
mask_type_ = operation.mask_type_;
link_index_ = operation.link_index_;
ElementwiseOperation::operator=(std::move(operation));
}
return *this;
}
void ApplyMask::SetLinkIndex(int index) { link_index_ = index; }
std::string ApplyMask::GetCoreCode(const LinkingContext& context) const {
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
TensorCodeGenerator mask(
tensor_name,
WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
definition_.src_tensors[1]);
switch (mask_type_) {
case MaskType::TENSOR:
return context.var_name + " *= " +
mask.ReadWHS(context.x_coord, context.y_coord, context.s_coord) +
";\n";
case MaskType::CHANNELS:
return context.var_name +
" *= " + mask.ReadWHS("0", "0", context.s_coord) + ";\n";
case MaskType::LAYER:
return context.var_name +
" *= " + mask.ReadWHS(context.x_coord, context.y_coord, "0") +
".x;\n";
}
}
std::string ApplyMask::GetArgsDeclaration() const {
std::string args;
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
absl::StrAppend(&args, ",\n",
GetTensorDeclaration(AccessType::READ, tensor_name,
definition_.src_tensors[1]));
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
absl::StrAppend(&args, ",\n int4 ", size_name);
return args;
}
Status ApplyMask::BindArguments(CLKernel* kernel) {
RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr()));
RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB()));
return OkStatus();
}
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
const BHWC& mask_shape) {
ApplyMask::MaskType mask_type;
if (mask_shape == src_shape) {
mask_type = ApplyMask::MaskType::TENSOR;
} else if (mask_shape.c == 1) {
mask_type = ApplyMask::MaskType::LAYER;
} else {
mask_type = ApplyMask::MaskType::CHANNELS;
}
ApplyMask operation(definition, mask_type);
operation.SetLinkIndex(0);
return operation;
}
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@@ -40,7 +40,7 @@ class MultiplyAdd : public ElementwiseOperation {
  MultiplyAdd(const MultiplyAdd&) = delete;
  MultiplyAdd& operator=(const MultiplyAdd&) = delete;
-  Status UploadMul(const MultiplyScalarAttributes& attr,
+  Status UploadMul(const MultiplyAttributes& attr,
                   CalculationsPrecision scalar_precision, CLContext* context);
  Status UploadAdd(const AddAttributes& attr,
                   CalculationsPrecision scalar_precision, CLContext* context);
@@ -61,7 +61,7 @@ class MultiplyAdd : public ElementwiseOperation {
  friend Status CreateMultiplyAdd(const CreationContext& creation_context,
                                  const OperationDef& definition,
-                                 const MultiplyScalarAttributes& attr,
+                                 const MultiplyAttributes& attr,
                                  MultiplyAdd* result);
  friend Status CreateMultiplyAdd(const CreationContext& creation_context,
@@ -71,7 +71,7 @@ class MultiplyAdd : public ElementwiseOperation {
  friend Status CreateMultiplyAdd(const CreationContext& creation_context,
                                  const OperationDef& definition,
-                                 const MultiplyScalarAttributes& mul_attr,
+                                 const MultiplyAttributes& mul_attr,
                                  const AddAttributes& add_attr,
                                  MultiplyAdd* result);
@@ -91,8 +91,7 @@ class MultiplyAdd : public ElementwiseOperation {
Status CreateMultiplyAdd(const CreationContext& creation_context,
                         const OperationDef& definition,
-                        const MultiplyScalarAttributes& attr,
-                        MultiplyAdd* result);
+                        const MultiplyAttributes& attr, MultiplyAdd* result);
Status CreateMultiplyAdd(const CreationContext& creation_context,
                         const OperationDef& definition,
@@ -100,7 +99,7 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
Status CreateMultiplyAdd(const CreationContext& creation_context,
                         const OperationDef& definition,
-                        const MultiplyScalarAttributes& mul_attr,
+                        const MultiplyAttributes& mul_attr,
                         const AddAttributes& add_attr, MultiplyAdd* result);
template <DataType T>
@@ -127,6 +126,36 @@ Status MultiplyAdd::UploadAdd(const ::tflite::gpu::Tensor<Linear, T>& add,
  return OkStatus();
}
class ApplyMask : public ElementwiseOperation {
public:
// Move only
ApplyMask(ApplyMask&& operation);
ApplyMask& operator=(ApplyMask&& operation);
ApplyMask(const ApplyMask&) = delete;
ApplyMask& operator=(const ApplyMask&) = delete;
void SetLinkIndex(int index) override;
std::string GetCoreCode(const LinkingContext& context) const override;
std::string GetArgsDeclaration() const override;
Status BindArguments(CLKernel* kernel) override;
private:
friend ApplyMask CreateApplyMask(const OperationDef& definition,
const BHWC& src_shape,
const BHWC& mask_shape);
enum class MaskType { LAYER, CHANNELS, TENSOR };
explicit ApplyMask(const OperationDef& definition, MaskType mask_type)
: ElementwiseOperation(definition), mask_type_(mask_type) {}
MaskType mask_type_;
int link_index_;
};
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
const BHWC& mask_shape);
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@@ -37,7 +37,7 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMul) {
  src_tensor.shape = BHWC(1, 2, 1, 2);
  src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
-  MultiplyScalarAttributes attr;
+  MultiplyAttributes attr;
  ::tflite::gpu::Tensor<Linear, DataType::FLOAT32> parameters;
  parameters.shape = Linear(2);
  parameters.data = {0.5f, 2.0f};
@@ -97,7 +97,7 @@ TEST_F(OpenCLOperationTest, MultiplyAddScalarMul) {
  src_tensor.shape = BHWC(1, 2, 1, 2);
  src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
-  MultiplyScalarAttributes attr;
+  MultiplyAttributes attr;
  attr.param = 0.5f;
  for (auto storage : env_.GetSupportedStorages()) {
@@ -151,7 +151,7 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMad) {
  src_tensor.shape = BHWC(1, 2, 1, 2);
  src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
-  MultiplyScalarAttributes mul_attr;
+  MultiplyAttributes mul_attr;
  ::tflite::gpu::Tensor<Linear, DataType::FLOAT32> parameters;
  parameters.shape = Linear(2);
  parameters.data = {0.5f, 2.0f};
@@ -181,6 +181,96 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMad) {
  }
}
TEST_F(OpenCLOperationTest, ApplyMaskOneChannel) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
TensorFloat32 mask_tensor;
mask_tensor.shape = BHWC(1, 2, 2, 1);
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ApplyMask operation =
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
creation_context_, &operation,
BHWC(1, 2, 2, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-8.0f, -6.0f, -0.5f, 0.0f, 1.0f,
3.0f, 0.0f, 0.0f}));
}
}
}
TEST_F(OpenCLOperationTest, ApplyMaskEqualSizes) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
TensorFloat32 mask_tensor;
mask_tensor.shape = BHWC(1, 2, 2, 2);
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f, 2.0f, 0.5f, 1.0f, 0.0f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ApplyMask operation =
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
creation_context_, &operation,
BHWC(1, 2, 2, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -1.0f, 0.0f, 2.0f,
1.5f, 4.0f, 0.0f}));
}
}
}
TEST_F(OpenCLOperationTest, ApplyMaskVector) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
TensorFloat32 mask_tensor;
mask_tensor.shape = BHWC(1, 1, 1, 2);
mask_tensor.data = {2.0f, 0.5f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ApplyMask operation =
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
creation_context_, &operation,
BHWC(1, 2, 2, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -2.0f, 0.0f, 2.0f,
1.5f, 8.0f, 3.0f}));
}
}
}
} // namespace
} // namespace cl
} // namespace gpu

View File

@@ -103,7 +103,6 @@ cc_library(
    hdrs = ["simple_selectors.h"],
    deps = [
        "//tensorflow/lite/delegates/gpu/cl/kernels:add",
-       "//tensorflow/lite/delegates/gpu/cl/kernels:apply_mask",
        "//tensorflow/lite/delegates/gpu/cl/kernels:concat_xy",
        "//tensorflow/lite/delegates/gpu/cl/kernels:concat_z",
        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",

View File

@@ -68,11 +68,6 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
        return OkStatus();
      }
    }
-    case OperationType::APPLY_MASK: {
-      SelectApplyMask(op_def, inputs[0]->tensor.shape, inputs[1]->tensor.shape,
-                      gpu_op);
-      return OkStatus();
-    }
    case OperationType::CONCAT: {
      auto attr = absl::any_cast<ConcatAttributes>(node.operation.attributes);
      std::vector<int> channels(inputs.size());
@@ -119,10 +114,17 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
      auto attr = absl::any_cast<MeanAttributes>(node.operation.attributes);
      return SelectMean(attr, op_def, gpu_op);
    }
-    case OperationType::MULTIPLY_SCALAR: {
-      auto attr =
-          absl::any_cast<MultiplyScalarAttributes>(node.operation.attributes);
-      return SelectMultiplyScalar(attr, creation_context, op_def, gpu_op);
+    case OperationType::MUL: {
+      if (node.operation.attributes.has_value()) {
+        auto attr =
+            absl::any_cast<MultiplyAttributes>(node.operation.attributes);
+        return SelectMultiplyScalar(attr, creation_context, op_def, gpu_op);
+      } else {
+        SelectApplyMask(op_def, inputs[0]->tensor.shape,
+                        inputs[1]->tensor.shape, gpu_op);
+        return OkStatus();
+      }
    }
    case OperationType::PAD: {
      auto attr = absl::any_cast<PadAttributes>(node.operation.attributes);

View File

@@ -20,7 +20,6 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/add.h"
-#include "tensorflow/lite/delegates/gpu/cl/kernels/apply_mask.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/lstm.h"
@@ -155,7 +154,7 @@ Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def,
  return OkStatus();
}
-Status SelectMultiplyScalar(const MultiplyScalarAttributes& attr,
+Status SelectMultiplyScalar(const MultiplyAttributes& attr,
                            const CreationContext& creation_context,
                            const OperationDef& op_def,
                            std::unique_ptr<GPUOperation>* ptr) {

View File

@@ -73,7 +73,7 @@ void SelectStridedSlice(const SliceAttributes& attr, const OperationDef& op_def,
Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def,
                  std::unique_ptr<GPUOperation>* ptr);
-Status SelectMultiplyScalar(const MultiplyScalarAttributes& attr,
+Status SelectMultiplyScalar(const MultiplyAttributes& attr,
                            const CreationContext& creation_context,
                            const OperationDef& op_def,
                            std::unique_ptr<GPUOperation>* ptr);

View File

@@ -1395,8 +1395,11 @@ class MulOperationParser : public TFLiteOperationParser {
    const bool runtime_tensor0 = !constant_tensor0;
    const bool runtime_tensor1 = !constant_tensor1;
-    // Parse for APPLY_MASK. The "larger" input tensor must be bound to 1st
-    // input and the "smaller" input tensor ("mask") must be bound to 2nd input.
+    Node* node = graph->NewNode();
+    node->operation.type = ToString(OperationType::MUL);
+    // The "larger" input tensor must be bound to 1st input and the "smaller"
+    // input tensor ("mask") must be bound to 2nd input.
    if (runtime_tensor0 && runtime_tensor1) {
      BHWC shape0;
      RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
@@ -1409,11 +1412,11 @@ class MulOperationParser : public TFLiteOperationParser {
        input_tensor0 = 1;
        input_tensor1 = 0;
      }
-      return ParseApplyMask(input_tensor0, input_tensor1, graph, reader);
+      return ParseApplyMask(node, input_tensor0, input_tensor1, graph, reader);
    }
-    // Parse for MULTIPLY_SCALAR. The runtime input tensor must be bound to 1st
-    // input and the constant input tensor must be bound to 2nd input.
+    // The runtime input tensor must be bound to 1st input and the constant
+    // input tensor must be bound to 2nd input.
    int runtime_tensor = 0;
    int constant_tensor = 1;
    TfLiteIntArray* constant_dims = input1->dims;
@@ -1422,27 +1425,24 @@
      constant_tensor = 0;
      constant_dims = input0->dims;
    }
-    return ParseMultiplyScalar(runtime_tensor, constant_tensor, constant_dims,
-                               graph, reader);
+    return ParseMultiplyScalar(node, runtime_tensor, constant_tensor,
+                               constant_dims, graph, reader);
  }
 private:
-  Status ParseApplyMask(int input_tensor0, int input_tensor1,
+  Status ParseApplyMask(Node* node, int input_tensor0, int input_tensor1,
                        GraphFloat32* graph, ObjectReader* reader) {
-    Node* node = graph->NewNode();
-    node->operation.type = ToString(OperationType::APPLY_MASK);
    RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
    RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
    return reader->AddOutputs(node);
  }
-  Status ParseMultiplyScalar(int runtime_tensor, int constant_tensor,
+  Status ParseMultiplyScalar(Node* node, int runtime_tensor,
+                             int constant_tensor,
                             const TfLiteIntArray* constant_dims,
                             GraphFloat32* graph, ObjectReader* reader) {
-    Node* node = graph->NewNode();
-    node->operation.type = ToString(OperationType::MULTIPLY_SCALAR);
    RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
-    MultiplyScalarAttributes attr;
+    MultiplyAttributes attr;
    if (constant_dims->size <= 0) {
      Tensor<Scalar, DataType::FLOAT32> tensor;
      RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));

View File

@@ -72,8 +72,6 @@ std::string ToString(enum OperationType op) {
      return "abs";
    case OperationType::ADD:
      return "add";
-    case OperationType::APPLY_MASK:
-      return "apply_mask";
    case OperationType::BATCH_NORMALIZATION:
      return "batch_normalization";
    case OperationType::BATCH_TO_SPACE:
@@ -106,8 +104,6 @@ std::string ToString(enum OperationType op) {
      return "mean";
    case OperationType::MUL:
      return "mul";
-    case OperationType::MULTIPLY_SCALAR:
-      return "multiply_scalar";
    case OperationType::PAD:
      return "pad";
    case OperationType::POOLING_2D:
@@ -157,7 +153,6 @@ OperationType OperationTypeFromString(const std::string& name) {
      new std::unordered_map<std::string, OperationType>({
          {"abs", OperationType::ABS},
          {"add", OperationType::ADD},
-          {"apply_mask", OperationType::APPLY_MASK},
          {"batch_normalization", OperationType::BATCH_NORMALIZATION},
          {"concat", OperationType::CONCAT},
          {"const", OperationType::CONST},
@@ -173,7 +168,6 @@ OperationType OperationTypeFromString(const std::string& name) {
          {"max_unpooling", OperationType::MAX_UNPOOLING_2D},
          {"mean", OperationType::MEAN},
          {"mul", OperationType::MUL},
-          {"multiply_scalar", OperationType::MULTIPLY_SCALAR},
          {"pad", OperationType::PAD},
          {"pooling_2d", OperationType::POOLING_2D},
          {"pow", OperationType::POW},

View File

@@ -34,8 +34,6 @@ enum class OperationType {
  UNKNOWN = 0,
  ABS,
  ADD,
-  // TODO(eignasheva): remove APPLY_MASK operation, is should be just MUL
-  APPLY_MASK,
  BATCH_TO_SPACE,
  BATCH_NORMALIZATION,
  CONCAT,
@@ -52,7 +50,6 @@ enum class OperationType {
  MAX_UNPOOLING_2D,
  MEAN,
  MUL,
-  MULTIPLY_SCALAR,
  PAD,
  POOLING_2D,
  POW,
@@ -354,7 +351,7 @@ struct LstmAttributes {
  LstmKernelType kernel_type = LstmKernelType::BASIC;
};
-struct MultiplyScalarAttributes {
+struct MultiplyAttributes {
  absl::variant<absl::monostate, Tensor<Linear, DataType::FLOAT32>, float>
      param;
};

View File

@@ -33,15 +33,14 @@ class MergeConvolutionWithMul : public SequenceTransformation {
                              GraphFloat32* graph) final {
    auto& conv_node = *sequence[0];
    auto& mul_node = *sequence[1];
-    if (mul_node.operation.type != ToString(OperationType::MUL) &&
-        mul_node.operation.type != ToString(OperationType::MULTIPLY_SCALAR)) {
+    if (mul_node.operation.type != ToString(OperationType::MUL) ||
+        !mul_node.operation.attributes.has_value()) {
      return {TransformStatus::SKIPPED, ""};
    }
-    MultiplyScalarAttributes mul_attr =
-        absl::any_cast<MultiplyScalarAttributes>(mul_node.operation.attributes);
-    if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(
-            &mul_attr.param) &&
+    MultiplyAttributes mul_attr =
+        absl::any_cast<MultiplyAttributes>(mul_node.operation.attributes);
+    if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param) &&
        !absl::get_if<float>(&mul_attr.param)) {
      return {
          TransformStatus::DECLINED,
@@ -93,13 +92,13 @@ class MergeMulWithConvolution : public SequenceTransformation {
                              GraphFloat32* graph) final {
    auto& conv_node = *sequence[1];
    auto& mul_node = *sequence[0];
-    if (mul_node.operation.type != ToString(OperationType::MUL) &&
-        mul_node.operation.type != ToString(OperationType::MULTIPLY_SCALAR)) {
+    if (mul_node.operation.type != ToString(OperationType::MUL) ||
+        !mul_node.operation.attributes.has_value()) {
      return {TransformStatus::SKIPPED, ""};
    }
-    MultiplyScalarAttributes mul_attr =
-        absl::any_cast<MultiplyScalarAttributes>(mul_node.operation.attributes);
+    MultiplyAttributes mul_attr =
+        absl::any_cast<MultiplyAttributes>(mul_node.operation.attributes);
    if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(
            &mul_attr.param) &&
        !absl::get_if<float>(&mul_attr.param)) {
@@ -155,7 +154,7 @@ std::unique_ptr<SequenceTransformation> NewMergeMulWithConvolution() {
  return absl::make_unique<MergeMulWithConvolution>();
}
-void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr,
+void FuseConvolution2DWithMultiply(const MultiplyAttributes& mul_attr,
                                   Convolution2DAttributes* attr) {
  auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = absl::get_if<float>(&mul_attr.param);
@@ -176,7 +175,7 @@ void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr,
}
void FuseDepthwiseConvolution2DWithMultiply(
-    const MultiplyScalarAttributes& mul_attr,
+    const MultiplyAttributes& mul_attr,
    DepthwiseConvolution2DAttributes* attr) {
  auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = absl::get_if<float>(&mul_attr.param);
@@ -198,8 +197,7 @@ void FuseDepthwiseConvolution2DWithMultiply(
}
void FuseConvolutionTransposedWithMultiply(
-    const MultiplyScalarAttributes& mul_attr,
-    ConvolutionTransposedAttributes* attr) {
+    const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr) {
  auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = absl::get_if<float>(&mul_attr.param);
  for (int d = 0; d < attr->weights.shape.o; ++d) {
@@ -218,7 +216,7 @@ void FuseConvolutionTransposedWithMultiply(
  }
}
-void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr,
+void FuseFullyConnectedWithMultiply(const MultiplyAttributes& mul_attr,
                                    FullyConnectedAttributes* attr) {
  auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = absl::get_if<float>(&mul_attr.param);
@@ -234,7 +232,7 @@ void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr,
  }
}
-void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr,
+void FuseMultiplyWithConvolution2D(const MultiplyAttributes& mul_attr,
                                   Convolution2DAttributes* attr) {
  auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = absl::get_if<float>(&mul_attr.param);
@@ -252,7 +250,7 @@ void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr,
}
void FuseMultiplyWithDepthwiseConvolution2D(
-    const MultiplyScalarAttributes& mul_attr,
+    const MultiplyAttributes& mul_attr,
    DepthwiseConvolution2DAttributes* attr) {
  auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = absl::get_if<float>(&mul_attr.param);
@@ -270,8 +268,7 @@ void FuseMultiplyWithDepthwiseConvolution2D(
}
void FuseMultiplyWithConvolutionTransposed(
-    const MultiplyScalarAttributes& mul_attr,
-    ConvolutionTransposedAttributes* attr) {
+    const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr) {
  auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = absl::get_if<float>(&mul_attr.param);
  for (int s = 0; s < attr->weights.shape.i; ++s) {
@@ -287,7 +284,7 @@ void FuseMultiplyWithConvolutionTransposed(
  }
}
-void FuseMultiplyWithFullyConnected(const MultiplyScalarAttributes& mul_attr,
+void FuseMultiplyWithFullyConnected(const MultiplyAttributes& mul_attr,
                                    FullyConnectedAttributes* attr) {
  auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
  auto mul_scalar = absl::get_if<float>(&mul_attr.param);

View File

@@ -38,53 +38,49 @@ std::unique_ptr<SequenceTransformation> NewMergeMulWithConvolution();
// Modify Convolution2DAttributes so that after making convolution with
// modified attributes we will have the same result as convolution
// with old attributes and following multiply operation.
-void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr,
+void FuseConvolution2DWithMultiply(const MultiplyAttributes& mul_attr,
                                   Convolution2DAttributes* attr);
// Modify DepthwiseConvolution2DAttributes so that after making depth wise
// convolution with modified attributes we will have the same result as depth
// wise convolution with old attributes and following multiply operation.
void FuseDepthwiseConvolution2DWithMultiply(
-    const MultiplyScalarAttributes& mul_attr,
-    DepthwiseConvolution2DAttributes* attr);
+    const MultiplyAttributes& mul_attr, DepthwiseConvolution2DAttributes* attr);
// Modify ConvolutionTransposedAttributes so that after making convolution
// transposed with modified attributes we will have the same result as
// convolution transposed with old attributes and following multiply operation.
void FuseConvolutionTransposedWithMultiply(
-    const MultiplyScalarAttributes& mul_attr,
-    ConvolutionTransposedAttributes* attr);
+    const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr);
// Modify FullyConnectedAttributes so that after making fully connected with
// modified attributes we will have the same result as fully connected
// with old attributes and following multiply operation.
-void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr,
+void FuseFullyConnectedWithMultiply(const MultiplyAttributes& mul_attr,
                                    FullyConnectedAttributes* attr);
// Modify Convolution2DAttributes so that after making convolution with
// modified attributes we will have the same result as multiply operation and
// convolution with old attributes
-void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr,
+void FuseMultiplyWithConvolution2D(const MultiplyAttributes& mul_attr,
                                   Convolution2DAttributes* attr);
// Modify DepthwiseConvolution2DAttributes so that after making depth wise
// convolution with modified attributes we will have the same result as multiply
// operation and depth wise convolution with old attributes
void FuseMultiplyWithDepthwiseConvolution2D(
-    const MultiplyScalarAttributes& mul_attr,
-    DepthwiseConvolution2DAttributes* attr);
+    const MultiplyAttributes& mul_attr, DepthwiseConvolution2DAttributes* attr);
// Modify ConvolutionTransposedAttributes so that after making convolution
// transposed with modified attributes we will have the same result as multiply
// operation and convolution transposed with old attributes
void FuseMultiplyWithConvolutionTransposed(
-    const MultiplyScalarAttributes& mul_attr,
-    ConvolutionTransposedAttributes* attr);
+    const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr);
// Modify FullyConnectedAttributes so that after making fully connected
// with modified attributes we will have the same result as multiply
// operation and fully connected with old attributes
-void FuseMultiplyWithFullyConnected(const MultiplyScalarAttributes& mul_attr,
+void FuseMultiplyWithFullyConnected(const MultiplyAttributes& mul_attr,
                                    FullyConnectedAttributes* attr);
} // namespace gpu
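For illustration, a short usage sketch of the renamed fusion helpers (a sketch assuming the declarations above; the shapes and values are arbitrary and mirror the tests further below):

// Sketch: fold a per-channel multiply that follows a convolution into the
// convolution's weights and bias, so the separate multiply node can be removed.
MultiplyAttributes mul_attr;
Tensor<Linear, DataType::FLOAT32> scale;
scale.shape = Linear(2);      // one multiplier per output channel
scale.data = {0.5f, 2.0f};
mul_attr.param = scale;
Convolution2DAttributes conv_attr;  // filled in elsewhere with weights/bias
FuseConvolution2DWithMultiply(mul_attr, &conv_attr);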

View File

@@ -46,7 +46,7 @@ TEST(MergeConvolutionWithMulTest, Smoke) {
  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(16);
  mul_tensor.data.resize(16);
-  MultiplyScalarAttributes mul_attr;
+  MultiplyAttributes mul_attr;
  mul_attr.param = mul_tensor;
  auto conv_node = graph.NewNode();
@@ -87,7 +87,7 @@ TEST(MergeMulWithConvolutionTest, Smoke) {
  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(8);
  mul_tensor.data.resize(8);
-  MultiplyScalarAttributes mul_attr;
+  MultiplyAttributes mul_attr;
  mul_attr.param = mul_tensor;
  Convolution2DAttributes conv_attr;
@@ -140,7 +140,7 @@ TEST(FuseMulAfterConvolution2DTest, Smoke) {
  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(2);
  mul_tensor.data = {0.5f, 2.0f};
-  MultiplyScalarAttributes mul_attr;
+  MultiplyAttributes mul_attr;
  mul_attr.param = mul_tensor;
  FuseConvolution2DWithMultiply(mul_attr, &attr);
@@ -161,7 +161,7 @@ TEST(FuseMulAfterDepthwiseConvolution2DTest, Smoke) {
  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(4);
  mul_tensor.data = {0.5f, 2.0f, 4.0f, 0.25f};
-  MultiplyScalarAttributes mul_attr;
+  MultiplyAttributes mul_attr;
  mul_attr.param = mul_tensor;
  FuseDepthwiseConvolution2DWithMultiply(mul_attr, &attr);
@@ -183,7 +183,7 @@ TEST(FuseMulAfterConvolutionTransposedTest, Smoke) {
  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(2);
  mul_tensor.data = {0.5f, 2.0f};
-  MultiplyScalarAttributes mul_attr;
+  MultiplyAttributes mul_attr;
  mul_attr.param = mul_tensor;
  FuseConvolutionTransposedWithMultiply(mul_attr, &attr);
@@ -204,7 +204,7 @@ TEST(FuseMulAfterFullyConnectedTest, Smoke) {
  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(2);
  mul_tensor.data = {0.5f, 2.0f};
-  MultiplyScalarAttributes mul_attr;
+  MultiplyAttributes mul_attr;
  mul_attr.param = mul_tensor;
  FuseFullyConnectedWithMultiply(mul_attr, &attr);
@@ -224,7 +224,7 @@ TEST(FuseMulBeforeConvolution2DTest, Smoke) {
  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(2);
  mul_tensor.data = {0.5f, 2.0f};
-  MultiplyScalarAttributes mul_attr;
+  MultiplyAttributes mul_attr;
  mul_attr.param = mul_tensor;
  FuseMultiplyWithConvolution2D(mul_attr, &attr);
@@ -245,7 +245,7 @@ TEST(FuseMulBeforeDepthwiseConvolution2DTest, Smoke) {
  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(4);
  mul_tensor.data = {0.5f, 2.0f, 4.0f, 0.25f};
-  MultiplyScalarAttributes mul_attr;
+  MultiplyAttributes mul_attr;
  mul_attr.param = mul_tensor;
  FuseMultiplyWithDepthwiseConvolution2D(mul_attr, &attr);
@@ -267,7 +267,7 @@ TEST(FuseMulBeforeConvolutionTransposedTest, Smoke) {
  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(2);
  mul_tensor.data = {0.5f, 2.0f};
-  MultiplyScalarAttributes mul_attr;
+  MultiplyAttributes mul_attr;
  mul_attr.param = mul_tensor;
  FuseMultiplyWithConvolutionTransposed(mul_attr, &attr);
@@ -288,7 +288,7 @@ TEST(FuseMulBeforeFullyConnectedTest, Smoke) {
  Tensor<Linear, DataType::FLOAT32> mul_tensor;
  mul_tensor.shape = Linear(2);
  mul_tensor.data = {0.5f, 2.0f};
-  MultiplyScalarAttributes mul_attr;
+  MultiplyAttributes mul_attr;
  mul_attr.param = mul_tensor;
  FuseMultiplyWithFullyConnected(mul_attr, &attr);

View File

@ -29,115 +29,116 @@ limitations under the License.
namespace tflite { namespace tflite {
namespace gpu { namespace gpu {
namespace gl { namespace gl {
namespace { namespace {
class ApplyMask : public NodeShader { bool IsApplyMaskSupported(const NodeShader::GenerationContext& ctx) {
public: const auto inputs = ctx.graph->FindInputs(ctx.node->id);
static bool IsSupported(const GenerationContext& ctx) { if (inputs.size() != 2) return false;
  const auto inputs = ctx.graph->FindInputs(ctx.node->id);
  if (inputs.size() != 2) return false;
  const auto& shape0 = inputs[0]->tensor.shape;
  const auto& shape1 = inputs[1]->tensor.shape;

  // [H, W, C] x [H, W, 0][0]
-  if (shape1.c == 1) return true;
-
-  if (shape0.c != shape1.c) return false;
-
-  // [H, W, C] x [H, W, C]
-  if (shape0.h == shape1.h && shape0.w == shape1.w) return true;
-
-  // [H, W, C] x [0, 0, C]
-  return shape1.h == 1 && shape1.w == 1;
+  if (shape0.h == shape1.h && shape0.w == shape1.w && shape1.c == 1) {
+    return true;
+  }
+
+  // [H, W, C] x [H, W, C]
+  if (shape0 == shape1) {
+    return true;
+  }
+
+  // [H, W, C] x [0, 0, C]
+  return shape1.h == 1 && shape1.w == 1 && shape0.c == shape1.c;
}

-  Status GenerateCode(const GenerationContext& ctx,
-                      GeneratedCode* generated_code) const final {
-    if (!IsSupported(ctx)) {
-      return InvalidArgumentError(
-          "This case is not supported by apply mask operation");
-    }
-    const auto inputs = ctx.graph->FindInputs(ctx.node->id);
-    const auto& shape0 = inputs[0]->tensor.shape;
-    const auto& shape1 = inputs[1]->tensor.shape;
+Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx,
+                             GeneratedCode* generated_code) {
+  const auto inputs = ctx.graph->FindInputs(ctx.node->id);
+  const auto& shape0 = inputs[0]->tensor.shape;
+  const auto& shape1 = inputs[1]->tensor.shape;

  std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * ";
  if (shape1.c == 1) {
    // [H, W, C] x [H, W, 0][0]
    absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;");
  } else if (shape0.h == shape1.h && shape0.w == shape1.w) {
    // [H, W, C] x [H, W, C]
    absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;");
  } else {
    // [H, W, C] x [0, 0, C]
    absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;");
  }

  *generated_code = {
      /*parameters=*/{},
      /*objects=*/{},
      /*shared_variables=*/{},
      /*workload=*/uint3(),
      /*workgroup=*/uint3(),
      /*source_code=*/std::move(source),
      /*input=*/IOStructure::ONLY_DEFINITIONS,
      /*output=*/IOStructure::AUTO,
  };
  return OkStatus();
}
-};

+Status GenerateMultiplyScalarCode(const NodeShader::GenerationContext& ctx,
+                                  GeneratedCode* generated_code) {
+  auto attr =
+      absl::any_cast<MultiplyAttributes>(ctx.node->operation.attributes);
+  auto muls = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
+  auto scalar = absl::get_if<float>(&attr.param);
+  if (scalar) {
+    *generated_code = {
+        /*parameters=*/{{"scalar", *scalar}},
+        /*objects=*/{},
+        /*shared_variables=*/{},
+        /*workload=*/uint3(),
+        /*workgroup=*/uint3(),
+        /*source_code=*/"value_0 *= $scalar$;",
+        /*input=*/IOStructure::AUTO,
+        /*output=*/IOStructure::AUTO,
+    };
+  } else {
+    if (!muls) {
+      return InvalidArgumentError("Empty parameters for Multiplication.");
+    }
+    auto shape = ctx.graph->FindInputs(ctx.node->id)[0]->tensor.shape;
+    *generated_code = {
+        /*parameters=*/{},
+        /*objects=*/{{"mul_buffer", MakeReadonlyObject(muls->data)}},
+        /*shared_variables=*/{},
+        // Declare workload explicitly because shader depends on gid.z.
+        /*workload=*/
+        uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)),
+        /*workgroup=*/uint3(),
+        /*source_code=*/"value_0 *= $mul_buffer[gid.z]$;",
+        /*input=*/IOStructure::AUTO,
+        /*output=*/IOStructure::AUTO,
+    };
+  }
+  return OkStatus();
+}

-class MultiplyScalar : public NodeShader {
+class Multiply : public NodeShader {
 public:
  Status GenerateCode(const GenerationContext& ctx,
                      GeneratedCode* generated_code) const final {
-    auto attr = absl::any_cast<MultiplyScalarAttributes>(
-        ctx.node->operation.attributes);
-    auto muls = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
-    auto scalar = absl::get_if<float>(&attr.param);
-    if (scalar) {
-      *generated_code = {
-          /*parameters=*/{{"scalar", *scalar}},
-          /*objects=*/{},
-          /*shared_variables=*/{},
-          /*workload=*/uint3(),
-          /*workgroup=*/uint3(),
-          /*source_code=*/"value_0 *= $scalar$;",
-          /*input=*/IOStructure::AUTO,
-          /*output=*/IOStructure::AUTO,
-      };
-    } else {
-      if (!muls) {
-        return InvalidArgumentError("Empty parameters for Multiplication.");
-      }
-      auto shape = ctx.graph->FindInputs(ctx.node->id)[0]->tensor.shape;
-      *generated_code = {
-          /*parameters=*/{},
-          /*objects=*/{{"mul_buffer", MakeReadonlyObject(muls->data)}},
-          /*shared_variables=*/{},
-          // Declare workload explicitly because shader depends on gid.z.
-          /*workload=*/
-          uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)),
-          /*workgroup=*/uint3(),
-          /*source_code=*/"value_0 *= $mul_buffer[gid.z]$;",
-          /*input=*/IOStructure::AUTO,
-          /*output=*/IOStructure::AUTO,
-      };
-    }
-    return OkStatus();
+    if (IsApplyMaskSupported(ctx)) {
+      return GenerateApplyMaskCode(ctx, generated_code);
+    } else {
+      return GenerateMultiplyScalarCode(ctx, generated_code);
+    }
  }
};

}  // namespace

-std::unique_ptr<NodeShader> NewApplyMaskNodeShader() {
-  return absl::make_unique<ApplyMask>();
-}
-
-std::unique_ptr<NodeShader> NewMultiplyScalarNodeShader() {
-  return absl::make_unique<MultiplyScalar>();
+std::unique_ptr<NodeShader> NewMultiplyNodeShader() {
+  return absl::make_unique<Multiply>();
}

}  // namespace gl
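
For orientation, and not part of this commit's diff: the merged GL shader above picks its codegen path from the node's payload. A minimal sketch, assuming the includes and namespaces already present in this file, of the three ways a MUL node can now be configured:

// Scalar multiplier: absl::get_if<float>(&attr.param) succeeds, so
// GenerateMultiplyScalarCode emits "value_0 *= $scalar$;" with one uniform.
MultiplyAttributes scalar_attr;
scalar_attr.param = 2.0f;

// Per-channel multiplier: the Tensor<Linear, DataType::FLOAT32> alternative is
// taken, so codegen uploads the values as the read-only "mul_buffer" object
// and indexes it with gid.z.
MultiplyAttributes channel_attr;
Tensor<Linear, DataType::FLOAT32> muls;
muls.shape.v = 2;
muls.data = {2.0f, 3.0f};
channel_attr.param = std::move(muls);

// No attributes plus a second runtime input of a broadcastable shape:
// IsApplyMaskSupported() returns true and GenerateApplyMaskCode handles it.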

View File

@@ -25,9 +25,7 @@ namespace tflite {
namespace gpu {
namespace gl {

-std::unique_ptr<NodeShader> NewApplyMaskNodeShader();
-
-std::unique_ptr<NodeShader> NewMultiplyScalarNodeShader();
+std::unique_ptr<NodeShader> NewMultiplyNodeShader();

}  // namespace gl
}  // namespace gpu

View File

@@ -41,13 +41,12 @@ TEST(MulTest, Scalar) {
  output.ref = 1;
  output.shape = BHWC(1, 2, 2, 1);

-  MultiplyScalarAttributes attr;
+  MultiplyAttributes attr;
  attr.param = 2.f;
-  // TODO(eignasheva): change to MULTIPLY_SCALAR
  SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
  ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
-  ASSERT_OK(model.Invoke(*NewMultiplyScalarNodeShader()));
+  ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
  EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 6, 8}));
}

@@ -62,21 +61,20 @@ TEST(MulTest, Linear) {
  output.ref = 1;
  output.shape = BHWC(1, 1, 2, 2);

-  MultiplyScalarAttributes attr;
+  MultiplyAttributes attr;
  Tensor<Linear, DataType::FLOAT32> tensor;
  tensor.shape.v = 2;
  tensor.id = 1;
  tensor.data = {2, 3};
  attr.param = std::move(tensor);
-  // TODO(eignasheva): change to MULTIPLY_SCALAR
  SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
  ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
-  ASSERT_OK(model.Invoke(*NewMultiplyScalarNodeShader()));
+  ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
  EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 6, 6, 12}));
}

-TEST(ApplyMaskTest, MaskChannel1) {
+TEST(MulTest, MaskChannel1) {
  TensorRef<BHWC> input;
  input.type = DataType::FLOAT32;
  input.ref = 0;

@@ -92,15 +90,15 @@ TEST(ApplyMaskTest, MaskChannel1) {
  output.ref = 2;
  output.shape = BHWC(1, 1, 2, 2);

-  SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask},
+  SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask},
                      {output});
  ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
  ASSERT_TRUE(model.PopulateTensor(1, {2, 3}));
-  ASSERT_OK(model.Invoke(*NewApplyMaskNodeShader()));
+  ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
  EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 9, 12}));
}

-TEST(ApplyMaskTest, MaskChannelEqualsToInputChannel) {
+TEST(MulTest, MaskChannelEqualsToInputChannel) {
  TensorRef<BHWC> input;
  input.type = DataType::FLOAT32;
  input.ref = 0;

@@ -116,11 +114,11 @@ TEST(ApplyMaskTest, MaskChannelEqualsToInputChannel) {
  output.ref = 2;
  output.shape = BHWC(1, 1, 2, 2);

-  SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask},
+  SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask},
                      {output});
  ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
  ASSERT_TRUE(model.PopulateTensor(1, {1, 2, 3, 4}));
-  ASSERT_OK(model.Invoke(*NewApplyMaskNodeShader()));
+  ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
  EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1, 4, 9, 16}));
}
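
The renamed tests above cover the scalar, per-channel, [H, W, 0] mask, and equal-shape mask cases. A sketch of one more test, not in this commit, that would exercise the remaining [H, W, C] x [0, 0, C] broadcast branch through the same MUL entry point; it assumes the harness and includes of the file above, and the values are chosen so the expected output is easy to check by hand:

TEST(MulTest, MaskBroadcastOverHW) {
  TensorRef<BHWC> input;
  input.type = DataType::FLOAT32;
  input.ref = 0;
  input.shape = BHWC(1, 1, 2, 2);

  TensorRef<BHWC> mask;
  mask.type = DataType::FLOAT32;
  mask.ref = 1;
  mask.shape = BHWC(1, 1, 1, 2);  // one value per channel, broadcast over H and W

  TensorRef<BHWC> output;
  output.type = DataType::FLOAT32;
  output.ref = 2;
  output.shape = BHWC(1, 1, 2, 2);

  SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask},
                      {output});
  ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
  ASSERT_TRUE(model.PopulateTensor(1, {2, 3}));
  ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
  // Channel-wise: {1 * 2, 2 * 3, 3 * 2, 4 * 3}.
  EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 6, 6, 12}));
}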

View File

@@ -71,7 +71,6 @@ class Registry : public NodeShader {
    };

    insert_op(Type::ADD, NewAddNodeShader);
-    insert_op(Type::APPLY_MASK, NewApplyMaskNodeShader);
    insert_op(Type::CONCAT, NewAlignedConcatNodeShader);
    insert_op(Type::CONCAT, NewFlatConcatNodeShader);
    insert_op(Type::CONCAT, NewConcatNodeShader);

@@ -82,7 +81,7 @@
    insert_op(Type::FULLY_CONNECTED, NewFullyConnectedNodeShader);
    insert_op(Type::LSTM, NewLstmNodeShader);
    insert_op(Type::MEAN, NewMeanNodeShader);
-    insert_op(Type::MULTIPLY_SCALAR, NewMultiplyScalarNodeShader);
+    insert_op(Type::MUL, NewMultiplyNodeShader);
    insert_op(Type::PAD, NewPadNodeShader);
    insert_op(Type::POOLING_2D, NewPoolingNodeShader);
    insert_op(Type::PRELU, NewPReLUNodeShader);

View File

@@ -199,11 +199,15 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
      *tasks = Mean(node_id, inputs[0], outputs[0],
                    absl::any_cast<MeanAttributes>(node->operation.attributes));
      break;
-    case OperationType::MULTIPLY_SCALAR:
-      *tasks = Multiply(
-          node_id, inputs[0], outputs[0],
-          absl::any_cast<MultiplyScalarAttributes>(node->operation.attributes),
-          options);
+    case OperationType::MUL:
+      if (node->operation.attributes.has_value()) {
+        *tasks = Multiply(
+            node_id, inputs[0], outputs[0],
+            absl::any_cast<MultiplyAttributes>(node->operation.attributes),
+            options);
+      } else {
+        *tasks = ApplyMask(node_id, inputs[0], inputs[1], outputs[0], options);
+      }
      break;
    case OperationType::PAD: {
      auto attr = absl::any_cast<PadAttributes>(node->operation.attributes);

@@ -268,12 +272,10 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
    case OperationType::SQUARED_DIFF:
      *tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type);
      break;
-    case OperationType::APPLY_MASK:
    case OperationType::BATCH_NORMALIZATION:
    case OperationType::BATCH_TO_SPACE:
    case OperationType::CONST:
    case OperationType::LSTM:
-    case OperationType::MUL:
    case OperationType::SPACE_TO_BATCH:
    case OperationType::TRANSPOSE:
    case OperationType::UNKNOWN:
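
To make the new dispatch above concrete, here is a hedged sketch, not code from this commit, of how the two MUL flavors would be set up on a graph node before reaching RegisterPrimaryOps; `graph` stands for a GraphFloat32 being built elsewhere:

// Elementwise multiply with constant parameters: the attributes hold either a
// float or a Tensor<Linear, FLOAT32>, so attributes.has_value() is true and
// the Multiply(...) task is compiled.
Node* scaled = graph.NewNode();
scaled->operation.type = ToString(OperationType::MUL);
MultiplyAttributes attr;
attr.param = 2.0f;
scaled->operation.attributes = attr;

// Mask-style multiply of two runtime tensors: the attributes stay empty, so
// attributes.has_value() is false and the ApplyMask(...) task is used instead.
Node* masked = graph.NewNode();
masked->operation.type = ToString(OperationType::MUL);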

View File

@@ -128,9 +128,10 @@ std::vector<ComputeTaskDescriptorPtr> ApplyMask(int id, ValueId input_id_0,
  return {desc};
}

-std::vector<ComputeTaskDescriptorPtr> Multiply(
-    int id, ValueId input_id, ValueId output_id,
-    const MultiplyScalarAttributes& attr, const RuntimeOptions& options) {
+std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
+                                               ValueId output_id,
+                                               const MultiplyAttributes& attr,
+                                               const RuntimeOptions& options) {
  auto desc = std::make_shared<ComputeTaskDescriptor>();
  desc->id = id;
  desc->is_linkable = true;

View File

@@ -26,9 +26,10 @@ namespace gpu {
namespace metal {

// Multiply operation, supports scalar and vector broadcast.
-std::vector<ComputeTaskDescriptorPtr> Multiply(
-    int id, ValueId input_id, ValueId output_id,
-    const MultiplyScalarAttributes& attr, const RuntimeOptions& options);
+std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
+                                               ValueId output_id,
+                                               const MultiplyAttributes& attr,
+                                               const RuntimeOptions& options);

std::vector<ComputeTaskDescriptorPtr> ApplyMask(int id, ValueId input_id_0,
                                                ValueId input_id_1,
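
For illustration only, and not part of the diff, calling the renamed Metal factory with the renamed attribute type might look like the sketch below; the ids and the default-constructed RuntimeOptions are placeholders:

MultiplyAttributes attr;
Tensor<Linear, DataType::FLOAT32> muls;
muls.shape.v = 2;
muls.data = {2.0f, 3.0f};
attr.param = std::move(muls);

RuntimeOptions options;
// Returns a single linkable task (see desc->is_linkable in the .cc hunk above)
// that the Metal inference context can fuse with the producing op.
auto tasks = Multiply(/*id=*/1, /*input_id=*/0, /*output_id=*/1, attr, options);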

View File

@@ -31,7 +31,7 @@ limitations under the License.
using ::tflite::gpu::DataType;
using ::tflite::gpu::BHWC;
using ::tflite::gpu::Linear;
-using ::tflite::gpu::MultiplyScalarAttributes;
+using ::tflite::gpu::MultiplyAttributes;
using ::tflite::gpu::OperationType;
using ::tflite::gpu::Tensor;
using ::tflite::gpu::TensorRef;

@@ -57,10 +57,10 @@ using ::tflite::gpu::metal::SingleOpModel;
  output.ref = 1;
  output.shape = BHWC(1, 2, 2, 1);

-  MultiplyScalarAttributes attr;
+  MultiplyAttributes attr;
  attr.param = 2;

-  SingleOpModel model({ToString(OperationType::MULTIPLY_SCALAR), attr}, {input}, {output});
+  SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
  XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());

@@ -79,14 +79,14 @@ using ::tflite::gpu::metal::SingleOpModel;
  output.ref = 1;
  output.shape = BHWC(1, 1, 2, 2);

-  MultiplyScalarAttributes attr;
+  MultiplyAttributes attr;
  Tensor<Linear, DataType::FLOAT32> tensor;
  tensor.shape.v = 2;
  tensor.id = 1;
  tensor.data = {2, 3};
  attr.param = std::move(tensor);

-  SingleOpModel model({ToString(OperationType::MULTIPLY_SCALAR), attr}, {input}, {output});
+  SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
  XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());

@@ -111,7 +111,7 @@ using ::tflite::gpu::metal::SingleOpModel;
  output.ref = 2;
  output.shape = BHWC(1, 1, 2, 2);

-  SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask}, {output});
+  SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output});
  XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
  XCTAssertTrue(model.PopulateTensor(1, {2, 3}));
  auto status = model.Invoke();

@@ -136,13 +136,12 @@ using ::tflite::gpu::metal::SingleOpModel;
  output.ref = 2;
  output.shape = BHWC(1, 1, 2, 2);

-  SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask}, {output});
+  SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output});
  XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
  XCTAssertTrue(model.PopulateTensor(1, {1, 2, 3, 4}));
  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
-  // Disable test for now.
-  // status = CompareVectors({1, 4, 9, 16}, model.GetOutput(0), 1e-6f);
+  status = CompareVectors({1, 4, 9, 16}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
}