Remove APPLY_MASK and MULTIPLY_SCALAR enum values; use MUL instead.

Rename MultiplyScalarAttributes to MultiplyAttributes.

PiperOrigin-RevId: 292228929
Change-Id: I459585d62e4161a93d6ec163bdbc0494c37eaf0e
A. Unique TensorFlower 2020-01-29 15:17:40 -08:00 committed by TensorFlower Gardener
parent abb256c1f8
commit 71c6f97e2d
25 changed files with 385 additions and 541 deletions
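In practice, a single MUL node now covers both removed ops: if it carries MultiplyAttributes it behaves like the old MULTIPLY_SCALAR (scalar or per-channel multiplier), and if it has no attributes its second runtime input acts as the old APPLY_MASK. Below is a minimal, stand-alone sketch of that dispatch rule; EvalMul and the simplified MultiplyAttributes here are hypothetical illustrations, not the delegate code changed in this commit.

#include <cassert>
#include <cstdio>
#include <optional>
#include <variant>
#include <vector>

// Simplified stand-in for the renamed attribute struct: either one scalar or a
// per-channel multiplier (the real struct also allows an empty monostate).
struct MultiplyAttributes {
  std::variant<float, std::vector<float>> param;
};

// attributes present -> former MULTIPLY_SCALAR path,
// attributes absent  -> former APPLY_MASK path (mask = second runtime input).
std::vector<float> EvalMul(const std::vector<float>& input, int channels,
                           const std::optional<MultiplyAttributes>& attr,
                           const std::vector<float>* mask) {
  std::vector<float> out(input.size());
  if (attr.has_value()) {
    if (const float* scalar = std::get_if<float>(&attr->param)) {
      for (size_t i = 0; i < input.size(); ++i) out[i] = input[i] * *scalar;
    } else {
      const auto& per_channel = std::get<std::vector<float>>(attr->param);
      // HWC layout: the channel index is the innermost dimension.
      for (size_t i = 0; i < input.size(); ++i)
        out[i] = input[i] * per_channel[i % channels];
    }
  } else {
    assert(mask && mask->size() == input.size());
    for (size_t i = 0; i < input.size(); ++i) out[i] = input[i] * (*mask)[i];
  }
  return out;
}

int main() {
  const std::vector<float> src = {1.0f, 2.0f, 3.0f, 4.0f};
  MultiplyAttributes attr;
  attr.param = 0.5f;  // former MULTIPLY_SCALAR
  const auto scaled = EvalMul(src, /*channels=*/2, attr, nullptr);
  const std::vector<float> mask = {2.0f, 3.0f, 2.0f, 3.0f};  // former APPLY_MASK
  const auto masked = EvalMul(src, /*channels=*/2, std::nullopt, &mask);
  std::printf("%.1f %.1f\n", scaled[0], masked[1]);  // prints: 0.5 6.0
  return 0;
}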


@@ -39,39 +39,6 @@ cc_test(
],
)
cc_library(
name = "apply_mask",
srcs = ["apply_mask.cc"],
hdrs = ["apply_mask.h"],
deps = [
":gpu_operation",
":util",
":work_group_picking",
"//tensorflow/lite/delegates/gpu/cl:cl_kernel",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:types",
"@com_google_absl//absl/strings",
],
)
cc_test(
name = "apply_mask_test",
srcs = ["apply_mask_test.cc"],
linkstatic = True,
tags = tf_gpu_tests_tags() + [
"linux",
"local",
],
deps = [
":apply_mask",
":cl_test",
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:status",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "cl_test",
testonly = 1,
@@ -1328,7 +1295,6 @@ test_suite(
name = "all_tests",
tests = [
"add_test",
"apply_mask_test",
"concat_test",
"conv_buffer_1x1_test",
"conv_buffer_test",


@@ -1,103 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/kernels/apply_mask.h"
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
namespace tflite {
namespace gpu {
namespace cl {
ApplyMask::ApplyMask(ApplyMask&& operation)
: ElementwiseOperation(std::move(operation)),
mask_type_(operation.mask_type_),
link_index_(operation.link_index_) {}
ApplyMask& ApplyMask::operator=(ApplyMask&& operation) {
if (this != &operation) {
mask_type_ = operation.mask_type_;
link_index_ = operation.link_index_;
ElementwiseOperation::operator=(std::move(operation));
}
return *this;
}
void ApplyMask::SetLinkIndex(int index) { link_index_ = index; }
std::string ApplyMask::GetCoreCode(const LinkingContext& context) const {
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
TensorCodeGenerator mask(
tensor_name,
WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
definition_.src_tensors[1]);
switch (mask_type_) {
case MaskType::TENSOR:
return context.var_name + " *= " +
mask.ReadWHS(context.x_coord, context.y_coord, context.s_coord) +
";\n";
case MaskType::CHANNELS:
return context.var_name +
" *= " + mask.ReadWHS("0", "0", context.s_coord) + ";\n";
case MaskType::LAYER:
return context.var_name +
" *= " + mask.ReadWHS(context.x_coord, context.y_coord, "0") +
".x;\n";
}
}
std::string ApplyMask::GetArgsDeclaration() const {
std::string args;
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
absl::StrAppend(&args, ",\n",
GetTensorDeclaration(AccessType::READ, tensor_name,
definition_.src_tensors[1]));
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
absl::StrAppend(&args, ",\n int4 ", size_name);
return args;
}
Status ApplyMask::BindArguments(CLKernel* kernel) {
RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr()));
RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB()));
return OkStatus();
}
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
const BHWC& mask_shape) {
ApplyMask::MaskType mask_type;
if (mask_shape == src_shape) {
mask_type = ApplyMask::MaskType::TENSOR;
} else if (mask_shape.c == 1) {
mask_type = ApplyMask::MaskType::LAYER;
} else {
mask_type = ApplyMask::MaskType::CHANNELS;
}
ApplyMask operation(definition, mask_type);
operation.SetLinkIndex(0);
return operation;
}
} // namespace cl
} // namespace gpu
} // namespace tflite


@@ -1,63 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_APPLY_MASK_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_APPLY_MASK_H_
#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
namespace tflite {
namespace gpu {
namespace cl {
class ApplyMask : public ElementwiseOperation {
public:
// Move only
ApplyMask(ApplyMask&& operation);
ApplyMask& operator=(ApplyMask&& operation);
ApplyMask(const ApplyMask&) = delete;
ApplyMask& operator=(const ApplyMask&) = delete;
void SetLinkIndex(int index) override;
std::string GetCoreCode(const LinkingContext& context) const override;
std::string GetArgsDeclaration() const override;
Status BindArguments(CLKernel* kernel) override;
private:
friend ApplyMask CreateApplyMask(const OperationDef& definition,
const BHWC& src_shape,
const BHWC& mask_shape);
enum class MaskType { LAYER, CHANNELS, TENSOR };
explicit ApplyMask(const OperationDef& definition, MaskType mask_type)
: ElementwiseOperation(definition), mask_type_(mask_type) {}
MaskType mask_type_;
int link_index_;
};
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
const BHWC& mask_shape);
} // namespace cl
} // namespace gpu
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_APPLY_MASK_H_


@@ -1,127 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/kernels/apply_mask.h"
#include <memory>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
using ::testing::FloatNear;
using ::testing::Pointwise;
namespace tflite {
namespace gpu {
namespace cl {
namespace {
TEST_F(OpenCLOperationTest, ApplyMaskOneChannel) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
TensorFloat32 mask_tensor;
mask_tensor.shape = BHWC(1, 2, 2, 1);
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ApplyMask operation =
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
creation_context_, &operation,
BHWC(1, 2, 2, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-8.0f, -6.0f, -0.5f, 0.0f, 1.0f,
3.0f, 0.0f, 0.0f}));
}
}
}
TEST_F(OpenCLOperationTest, ApplyMaskEqualSizes) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
TensorFloat32 mask_tensor;
mask_tensor.shape = BHWC(1, 2, 2, 2);
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f, 2.0f, 0.5f, 1.0f, 0.0f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ApplyMask operation =
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
creation_context_, &operation,
BHWC(1, 2, 2, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -1.0f, 0.0f, 2.0f,
1.5f, 4.0f, 0.0f}));
}
}
}
TEST_F(OpenCLOperationTest, ApplyMaskVector) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
TensorFloat32 mask_tensor;
mask_tensor.shape = BHWC(1, 1, 1, 2);
mask_tensor.data = {2.0f, 0.5f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ApplyMask operation =
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
creation_context_, &operation,
BHWC(1, 2, 2, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -2.0f, 0.0f, 2.0f,
1.5f, 8.0f, 3.0f}));
}
}
}
} // namespace
} // namespace cl
} // namespace gpu
} // namespace tflite


@@ -105,7 +105,7 @@ Status MultiplyAdd::BindArguments(CLKernel* kernel) {
return OkStatus();
}
Status MultiplyAdd::UploadMul(const MultiplyScalarAttributes& attr,
Status MultiplyAdd::UploadMul(const MultiplyAttributes& attr,
CalculationsPrecision scalar_precision,
CLContext* context) {
auto mul = absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
@@ -135,8 +135,7 @@ Status MultiplyAdd::UploadAdd(const AddAttributes& attr,
Status CreateMultiplyAdd(const CreationContext& creation_context,
const OperationDef& definition,
const MultiplyScalarAttributes& attr,
MultiplyAdd* result) {
const MultiplyAttributes& attr, MultiplyAdd* result) {
const auto scalar_precision = creation_context.device->IsPowerVR()
? CalculationsPrecision::F32
: definition.precision;
@@ -162,7 +161,7 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
Status CreateMultiplyAdd(const CreationContext& creation_context,
const OperationDef& definition,
const MultiplyScalarAttributes& mul_attr,
const MultiplyAttributes& mul_attr,
const AddAttributes& add_attr, MultiplyAdd* result) {
const auto scalar_precision = creation_context.device->IsPowerVR()
? CalculationsPrecision::F32
@@ -176,6 +175,76 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
return OkStatus();
}
ApplyMask::ApplyMask(ApplyMask&& operation)
: ElementwiseOperation(std::move(operation)),
mask_type_(operation.mask_type_),
link_index_(operation.link_index_) {}
ApplyMask& ApplyMask::operator=(ApplyMask&& operation) {
if (this != &operation) {
mask_type_ = operation.mask_type_;
link_index_ = operation.link_index_;
ElementwiseOperation::operator=(std::move(operation));
}
return *this;
}
void ApplyMask::SetLinkIndex(int index) { link_index_ = index; }
std::string ApplyMask::GetCoreCode(const LinkingContext& context) const {
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
TensorCodeGenerator mask(
tensor_name,
WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
definition_.src_tensors[1]);
switch (mask_type_) {
case MaskType::TENSOR:
return context.var_name + " *= " +
mask.ReadWHS(context.x_coord, context.y_coord, context.s_coord) +
";\n";
case MaskType::CHANNELS:
return context.var_name +
" *= " + mask.ReadWHS("0", "0", context.s_coord) + ";\n";
case MaskType::LAYER:
return context.var_name +
" *= " + mask.ReadWHS(context.x_coord, context.y_coord, "0") +
".x;\n";
}
}
std::string ApplyMask::GetArgsDeclaration() const {
std::string args;
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
absl::StrAppend(&args, ",\n",
GetTensorDeclaration(AccessType::READ, tensor_name,
definition_.src_tensors[1]));
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
absl::StrAppend(&args, ",\n int4 ", size_name);
return args;
}
Status ApplyMask::BindArguments(CLKernel* kernel) {
RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr()));
RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB()));
return OkStatus();
}
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
const BHWC& mask_shape) {
ApplyMask::MaskType mask_type;
if (mask_shape == src_shape) {
mask_type = ApplyMask::MaskType::TENSOR;
} else if (mask_shape.c == 1) {
mask_type = ApplyMask::MaskType::LAYER;
} else {
mask_type = ApplyMask::MaskType::CHANNELS;
}
ApplyMask operation(definition, mask_type);
operation.SetLinkIndex(0);
return operation;
}
} // namespace cl
} // namespace gpu
} // namespace tflite


@@ -40,7 +40,7 @@ class MultiplyAdd : public ElementwiseOperation {
MultiplyAdd(const MultiplyAdd&) = delete;
MultiplyAdd& operator=(const MultiplyAdd&) = delete;
Status UploadMul(const MultiplyScalarAttributes& attr,
Status UploadMul(const MultiplyAttributes& attr,
CalculationsPrecision scalar_precision, CLContext* context);
Status UploadAdd(const AddAttributes& attr,
CalculationsPrecision scalar_precision, CLContext* context);
@@ -61,7 +61,7 @@ class MultiplyAdd : public ElementwiseOperation {
friend Status CreateMultiplyAdd(const CreationContext& creation_context,
const OperationDef& definition,
const MultiplyScalarAttributes& attr,
const MultiplyAttributes& attr,
MultiplyAdd* result);
friend Status CreateMultiplyAdd(const CreationContext& creation_context,
@@ -71,7 +71,7 @@ class MultiplyAdd : public ElementwiseOperation {
friend Status CreateMultiplyAdd(const CreationContext& creation_context,
const OperationDef& definition,
const MultiplyScalarAttributes& mul_attr,
const MultiplyAttributes& mul_attr,
const AddAttributes& add_attr,
MultiplyAdd* result);
@@ -91,8 +91,7 @@ class MultiplyAdd : public ElementwiseOperation {
Status CreateMultiplyAdd(const CreationContext& creation_context,
const OperationDef& definition,
const MultiplyScalarAttributes& attr,
MultiplyAdd* result);
const MultiplyAttributes& attr, MultiplyAdd* result);
Status CreateMultiplyAdd(const CreationContext& creation_context,
const OperationDef& definition,
@@ -100,7 +99,7 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
Status CreateMultiplyAdd(const CreationContext& creation_context,
const OperationDef& definition,
const MultiplyScalarAttributes& mul_attr,
const MultiplyAttributes& mul_attr,
const AddAttributes& add_attr, MultiplyAdd* result);
template <DataType T>
@@ -127,6 +126,36 @@ Status MultiplyAdd::UploadAdd(const ::tflite::gpu::Tensor<Linear, T>& add,
return OkStatus();
}
class ApplyMask : public ElementwiseOperation {
public:
// Move only
ApplyMask(ApplyMask&& operation);
ApplyMask& operator=(ApplyMask&& operation);
ApplyMask(const ApplyMask&) = delete;
ApplyMask& operator=(const ApplyMask&) = delete;
void SetLinkIndex(int index) override;
std::string GetCoreCode(const LinkingContext& context) const override;
std::string GetArgsDeclaration() const override;
Status BindArguments(CLKernel* kernel) override;
private:
friend ApplyMask CreateApplyMask(const OperationDef& definition,
const BHWC& src_shape,
const BHWC& mask_shape);
enum class MaskType { LAYER, CHANNELS, TENSOR };
explicit ApplyMask(const OperationDef& definition, MaskType mask_type)
: ElementwiseOperation(definition), mask_type_(mask_type) {}
MaskType mask_type_;
int link_index_;
};
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
const BHWC& mask_shape);
} // namespace cl
} // namespace gpu
} // namespace tflite


@@ -37,7 +37,7 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMul) {
src_tensor.shape = BHWC(1, 2, 1, 2);
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
MultiplyScalarAttributes attr;
MultiplyAttributes attr;
::tflite::gpu::Tensor<Linear, DataType::FLOAT32> parameters;
parameters.shape = Linear(2);
parameters.data = {0.5f, 2.0f};
@@ -97,7 +97,7 @@ TEST_F(OpenCLOperationTest, MultiplyAddScalarMul) {
src_tensor.shape = BHWC(1, 2, 1, 2);
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
MultiplyScalarAttributes attr;
MultiplyAttributes attr;
attr.param = 0.5f;
for (auto storage : env_.GetSupportedStorages()) {
@@ -151,7 +151,7 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMad) {
src_tensor.shape = BHWC(1, 2, 1, 2);
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
MultiplyScalarAttributes mul_attr;
MultiplyAttributes mul_attr;
::tflite::gpu::Tensor<Linear, DataType::FLOAT32> parameters;
parameters.shape = Linear(2);
parameters.data = {0.5f, 2.0f};
@@ -181,6 +181,96 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMad) {
}
}
TEST_F(OpenCLOperationTest, ApplyMaskOneChannel) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
TensorFloat32 mask_tensor;
mask_tensor.shape = BHWC(1, 2, 2, 1);
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ApplyMask operation =
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
creation_context_, &operation,
BHWC(1, 2, 2, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-8.0f, -6.0f, -0.5f, 0.0f, 1.0f,
3.0f, 0.0f, 0.0f}));
}
}
}
TEST_F(OpenCLOperationTest, ApplyMaskEqualSizes) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
TensorFloat32 mask_tensor;
mask_tensor.shape = BHWC(1, 2, 2, 2);
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f, 2.0f, 0.5f, 1.0f, 0.0f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ApplyMask operation =
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
creation_context_, &operation,
BHWC(1, 2, 2, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -1.0f, 0.0f, 2.0f,
1.5f, 4.0f, 0.0f}));
}
}
}
TEST_F(OpenCLOperationTest, ApplyMaskVector) {
TensorFloat32 src_tensor;
src_tensor.shape = BHWC(1, 2, 2, 2);
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
TensorFloat32 mask_tensor;
mask_tensor.shape = BHWC(1, 1, 1, 2);
mask_tensor.data = {2.0f, 0.5f};
for (auto storage : env_.GetSupportedStorages()) {
for (auto precision : env_.GetSupportedPrecisions()) {
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
OperationDef op_def;
op_def.precision = precision;
auto data_type = DeduceDataTypeFromPrecision(precision);
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
TensorFloat32 dst_tensor;
ApplyMask operation =
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
creation_context_, &operation,
BHWC(1, 2, 2, 2), &dst_tensor));
EXPECT_THAT(dst_tensor.data,
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -2.0f, 0.0f, 2.0f,
1.5f, 8.0f, 3.0f}));
}
}
}
} // namespace
} // namespace cl
} // namespace gpu


@@ -103,7 +103,6 @@ cc_library(
hdrs = ["simple_selectors.h"],
deps = [
"//tensorflow/lite/delegates/gpu/cl/kernels:add",
"//tensorflow/lite/delegates/gpu/cl/kernels:apply_mask",
"//tensorflow/lite/delegates/gpu/cl/kernels:concat_xy",
"//tensorflow/lite/delegates/gpu/cl/kernels:concat_z",
"//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",


@@ -68,11 +68,6 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
return OkStatus();
}
}
case OperationType::APPLY_MASK: {
SelectApplyMask(op_def, inputs[0]->tensor.shape, inputs[1]->tensor.shape,
gpu_op);
return OkStatus();
}
case OperationType::CONCAT: {
auto attr = absl::any_cast<ConcatAttributes>(node.operation.attributes);
std::vector<int> channels(inputs.size());
@@ -119,10 +114,17 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
auto attr = absl::any_cast<MeanAttributes>(node.operation.attributes);
return SelectMean(attr, op_def, gpu_op);
}
case OperationType::MULTIPLY_SCALAR: {
auto attr =
absl::any_cast<MultiplyScalarAttributes>(node.operation.attributes);
return SelectMultiplyScalar(attr, creation_context, op_def, gpu_op);
case OperationType::MUL: {
if (node.operation.attributes.has_value()) {
auto attr =
absl::any_cast<MultiplyAttributes>(node.operation.attributes);
return SelectMultiplyScalar(attr, creation_context, op_def, gpu_op);
} else {
SelectApplyMask(op_def, inputs[0]->tensor.shape,
inputs[1]->tensor.shape, gpu_op);
return OkStatus();
}
}
case OperationType::PAD: {
auto attr = absl::any_cast<PadAttributes>(node.operation.attributes);


@@ -20,7 +20,6 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/add.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/apply_mask.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/lstm.h"
@@ -155,7 +154,7 @@ Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def,
return OkStatus();
}
Status SelectMultiplyScalar(const MultiplyScalarAttributes& attr,
Status SelectMultiplyScalar(const MultiplyAttributes& attr,
const CreationContext& creation_context,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr) {


@@ -73,7 +73,7 @@ void SelectStridedSlice(const SliceAttributes& attr, const OperationDef& op_def,
Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);
Status SelectMultiplyScalar(const MultiplyScalarAttributes& attr,
Status SelectMultiplyScalar(const MultiplyAttributes& attr,
const CreationContext& creation_context,
const OperationDef& op_def,
std::unique_ptr<GPUOperation>* ptr);


@@ -1395,8 +1395,11 @@ class MulOperationParser : public TFLiteOperationParser {
const bool runtime_tensor0 = !constant_tensor0;
const bool runtime_tensor1 = !constant_tensor1;
// Parse for APPLY_MASK. The "larger" input tensor must be bound to 1st
// input and the "smaller" input tensor ("mask") must be bound to 2nd input.
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::MUL);
// The "larger" input tensor must be bound to 1st input and the "smaller"
// input tensor ("mask") must be bound to 2nd input.
if (runtime_tensor0 && runtime_tensor1) {
BHWC shape0;
RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
@@ -1409,11 +1412,11 @@ class MulOperationParser : public TFLiteOperationParser {
input_tensor0 = 1;
input_tensor1 = 0;
}
return ParseApplyMask(input_tensor0, input_tensor1, graph, reader);
return ParseApplyMask(node, input_tensor0, input_tensor1, graph, reader);
}
// Parse for MULTIPLY_SCALAR. The runtime input tensor must be bound to 1st
// input and the constant input tensor must be bound to 2nd input.
// The runtime input tensor must be bound to 1st input and the constant
// input tensor must be bound to 2nd input.
int runtime_tensor = 0;
int constant_tensor = 1;
TfLiteIntArray* constant_dims = input1->dims;
@@ -1422,27 +1425,24 @@ class MulOperationParser : public TFLiteOperationParser {
constant_tensor = 0;
constant_dims = input0->dims;
}
return ParseMultiplyScalar(runtime_tensor, constant_tensor, constant_dims,
graph, reader);
return ParseMultiplyScalar(node, runtime_tensor, constant_tensor,
constant_dims, graph, reader);
}
private:
Status ParseApplyMask(int input_tensor0, int input_tensor1,
Status ParseApplyMask(Node* node, int input_tensor0, int input_tensor1,
GraphFloat32* graph, ObjectReader* reader) {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::APPLY_MASK);
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
return reader->AddOutputs(node);
}
Status ParseMultiplyScalar(int runtime_tensor, int constant_tensor,
Status ParseMultiplyScalar(Node* node, int runtime_tensor,
int constant_tensor,
const TfLiteIntArray* constant_dims,
GraphFloat32* graph, ObjectReader* reader) {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::MULTIPLY_SCALAR);
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
MultiplyScalarAttributes attr;
MultiplyAttributes attr;
if (constant_dims->size <= 0) {
Tensor<Scalar, DataType::FLOAT32> tensor;
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));


@@ -72,8 +72,6 @@ std::string ToString(enum OperationType op) {
return "abs";
case OperationType::ADD:
return "add";
case OperationType::APPLY_MASK:
return "apply_mask";
case OperationType::BATCH_NORMALIZATION:
return "batch_normalization";
case OperationType::BATCH_TO_SPACE:
@@ -106,8 +104,6 @@ std::string ToString(enum OperationType op) {
return "mean";
case OperationType::MUL:
return "mul";
case OperationType::MULTIPLY_SCALAR:
return "multiply_scalar";
case OperationType::PAD:
return "pad";
case OperationType::POOLING_2D:
@@ -157,7 +153,6 @@ OperationType OperationTypeFromString(const std::string& name) {
new std::unordered_map<std::string, OperationType>({
{"abs", OperationType::ABS},
{"add", OperationType::ADD},
{"apply_mask", OperationType::APPLY_MASK},
{"batch_normalization", OperationType::BATCH_NORMALIZATION},
{"concat", OperationType::CONCAT},
{"const", OperationType::CONST},
@@ -173,7 +168,6 @@ OperationType OperationTypeFromString(const std::string& name) {
{"max_unpooling", OperationType::MAX_UNPOOLING_2D},
{"mean", OperationType::MEAN},
{"mul", OperationType::MUL},
{"multiply_scalar", OperationType::MULTIPLY_SCALAR},
{"pad", OperationType::PAD},
{"pooling_2d", OperationType::POOLING_2D},
{"pow", OperationType::POW},


@@ -34,8 +34,6 @@ enum class OperationType {
UNKNOWN = 0,
ABS,
ADD,
// TODO(eignasheva): remove APPLY_MASK operation, is should be just MUL
APPLY_MASK,
BATCH_TO_SPACE,
BATCH_NORMALIZATION,
CONCAT,
@@ -52,7 +50,6 @@ enum class OperationType {
MAX_UNPOOLING_2D,
MEAN,
MUL,
MULTIPLY_SCALAR,
PAD,
POOLING_2D,
POW,
@@ -354,7 +351,7 @@ struct LstmAttributes {
LstmKernelType kernel_type = LstmKernelType::BASIC;
};
struct MultiplyScalarAttributes {
struct MultiplyAttributes {
absl::variant<absl::monostate, Tensor<Linear, DataType::FLOAT32>, float>
param;
};


@@ -33,15 +33,14 @@ class MergeConvolutionWithMul : public SequenceTransformation {
GraphFloat32* graph) final {
auto& conv_node = *sequence[0];
auto& mul_node = *sequence[1];
if (mul_node.operation.type != ToString(OperationType::MUL) &&
mul_node.operation.type != ToString(OperationType::MULTIPLY_SCALAR)) {
if (mul_node.operation.type != ToString(OperationType::MUL) ||
!mul_node.operation.attributes.has_value()) {
return {TransformStatus::SKIPPED, ""};
}
MultiplyScalarAttributes mul_attr =
absl::any_cast<MultiplyScalarAttributes>(mul_node.operation.attributes);
if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(
&mul_attr.param) &&
MultiplyAttributes mul_attr =
absl::any_cast<MultiplyAttributes>(mul_node.operation.attributes);
if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param) &&
!absl::get_if<float>(&mul_attr.param)) {
return {
TransformStatus::DECLINED,
@@ -93,13 +92,13 @@ class MergeMulWithConvolution : public SequenceTransformation {
GraphFloat32* graph) final {
auto& conv_node = *sequence[1];
auto& mul_node = *sequence[0];
if (mul_node.operation.type != ToString(OperationType::MUL) &&
mul_node.operation.type != ToString(OperationType::MULTIPLY_SCALAR)) {
if (mul_node.operation.type != ToString(OperationType::MUL) ||
!mul_node.operation.attributes.has_value()) {
return {TransformStatus::SKIPPED, ""};
}
MultiplyScalarAttributes mul_attr =
absl::any_cast<MultiplyScalarAttributes>(mul_node.operation.attributes);
MultiplyAttributes mul_attr =
absl::any_cast<MultiplyAttributes>(mul_node.operation.attributes);
if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(
&mul_attr.param) &&
!absl::get_if<float>(&mul_attr.param)) {
@@ -155,7 +154,7 @@ std::unique_ptr<SequenceTransformation> NewMergeMulWithConvolution() {
return absl::make_unique<MergeMulWithConvolution>();
}
void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr,
void FuseConvolution2DWithMultiply(const MultiplyAttributes& mul_attr,
Convolution2DAttributes* attr) {
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
@@ -176,7 +175,7 @@ void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr,
}
void FuseDepthwiseConvolution2DWithMultiply(
const MultiplyScalarAttributes& mul_attr,
const MultiplyAttributes& mul_attr,
DepthwiseConvolution2DAttributes* attr) {
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
@@ -198,8 +197,7 @@ void FuseDepthwiseConvolution2DWithMultiply(
}
void FuseConvolutionTransposedWithMultiply(
const MultiplyScalarAttributes& mul_attr,
ConvolutionTransposedAttributes* attr) {
const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr) {
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
for (int d = 0; d < attr->weights.shape.o; ++d) {
@@ -218,7 +216,7 @@ void FuseConvolutionTransposedWithMultiply(
}
}
void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr,
void FuseFullyConnectedWithMultiply(const MultiplyAttributes& mul_attr,
FullyConnectedAttributes* attr) {
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
@@ -234,7 +232,7 @@ void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr,
}
}
void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr,
void FuseMultiplyWithConvolution2D(const MultiplyAttributes& mul_attr,
Convolution2DAttributes* attr) {
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
@@ -252,7 +250,7 @@ void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr,
}
void FuseMultiplyWithDepthwiseConvolution2D(
const MultiplyScalarAttributes& mul_attr,
const MultiplyAttributes& mul_attr,
DepthwiseConvolution2DAttributes* attr) {
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
@@ -270,8 +268,7 @@ void FuseMultiplyWithDepthwiseConvolution2D(
}
void FuseMultiplyWithConvolutionTransposed(
const MultiplyScalarAttributes& mul_attr,
ConvolutionTransposedAttributes* attr) {
const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr) {
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
for (int s = 0; s < attr->weights.shape.i; ++s) {
@@ -287,7 +284,7 @@ void FuseMultiplyWithConvolutionTransposed(
}
}
void FuseMultiplyWithFullyConnected(const MultiplyScalarAttributes& mul_attr,
void FuseMultiplyWithFullyConnected(const MultiplyAttributes& mul_attr,
FullyConnectedAttributes* attr) {
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
auto mul_scalar = absl::get_if<float>(&mul_attr.param);


@@ -38,53 +38,49 @@ std::unique_ptr<SequenceTransformation> NewMergeMulWithConvolution();
// Modify Convolution2DAttributes so that after making convolution with
// modified attributes we will have the same result as convolution
// with old attributes and following multiply operation.
void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr,
void FuseConvolution2DWithMultiply(const MultiplyAttributes& mul_attr,
Convolution2DAttributes* attr);
// Modify DepthwiseConvolution2DAttributes so that after making depth wise
// convolution with modified attributes we will have the same result as depth
// wise convolution with old attributes and following multiply operation.
void FuseDepthwiseConvolution2DWithMultiply(
const MultiplyScalarAttributes& mul_attr,
DepthwiseConvolution2DAttributes* attr);
const MultiplyAttributes& mul_attr, DepthwiseConvolution2DAttributes* attr);
// Modify ConvolutionTransposedAttributes so that after making convolution
// transposed with modified attributes we will have the same result as
// convolution transposed with old attributes and following multiply operation.
void FuseConvolutionTransposedWithMultiply(
const MultiplyScalarAttributes& mul_attr,
ConvolutionTransposedAttributes* attr);
const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr);
// Modify FullyConnectedAttributes so that after making fully connected with
// modified attributes we will have the same result as fully connected
// with old attributes and following multiply operation.
void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr,
void FuseFullyConnectedWithMultiply(const MultiplyAttributes& mul_attr,
FullyConnectedAttributes* attr);
// Modify Convolution2DAttributes so that after making convolution with
// modified attributes we will have the same result as multiply operation and
// convolution with old attributes
void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr,
void FuseMultiplyWithConvolution2D(const MultiplyAttributes& mul_attr,
Convolution2DAttributes* attr);
// Modify DepthwiseConvolution2DAttributes so that after making depth wise
// convolution with modified attributes we will have the same result as multiply
// operation and depth wise convolution with old attributes
void FuseMultiplyWithDepthwiseConvolution2D(
const MultiplyScalarAttributes& mul_attr,
DepthwiseConvolution2DAttributes* attr);
const MultiplyAttributes& mul_attr, DepthwiseConvolution2DAttributes* attr);
// Modify ConvolutionTransposedAttributes so that after making convolution
// transposed with modified attributes we will have the same result as multiply
// operation and convolution transposed with old attributes
void FuseMultiplyWithConvolutionTransposed(
const MultiplyScalarAttributes& mul_attr,
ConvolutionTransposedAttributes* attr);
const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr);
// Modify FullyConnectedAttributes so that after making fully connected
// with modified attributes we will have the same result as multiply
// operation and fully connected with old attributes
void FuseMultiplyWithFullyConnected(const MultiplyScalarAttributes& mul_attr,
void FuseMultiplyWithFullyConnected(const MultiplyAttributes& mul_attr,
FullyConnectedAttributes* attr);
} // namespace gpu
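For context, the Fuse*WithMultiply helpers declared above fold a per-channel (or scalar) multiplier into the adjacent convolution so the separate elementwise op disappears: a multiply applied after the convolution scales both the weights and the bias along the output-channel axis, while a multiply applied before it scales the weights along the input-channel axis and leaves the bias untouched. A rough sketch of the "multiply after conv" case with hypothetical, simplified types (not the TFLite attribute structs):

#include <cstddef>
#include <vector>

// Hypothetical flattened view: one weight row per output channel.
struct SimpleConv {
  std::vector<std::vector<float>> weights;  // weights[o][k]
  std::vector<float> bias;                  // bias[o]
};

// (conv(x) + b) * m  ==  conv'(x) + b'  with  w'[o][k] = w[o][k] * m[o]
//                                       and   b'[o]    = b[o]    * m[o].
void FuseMulAfterConv(const std::vector<float>& m, SimpleConv* conv) {
  for (std::size_t o = 0; o < conv->weights.size(); ++o) {
    for (float& w : conv->weights[o]) w *= m[o];
    conv->bias[o] *= m[o];
  }
}

The "multiply before conv" variant would instead scale each weight by the multiplier of its input channel, which matches the loops over weights.shape.i in the .cc diff above.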


@@ -46,7 +46,7 @@ TEST(MergeConvolutionWithMulTest, Smoke) {
Tensor<Linear, DataType::FLOAT32> mul_tensor;
mul_tensor.shape = Linear(16);
mul_tensor.data.resize(16);
MultiplyScalarAttributes mul_attr;
MultiplyAttributes mul_attr;
mul_attr.param = mul_tensor;
auto conv_node = graph.NewNode();
@@ -87,7 +87,7 @@ TEST(MergeMulWithConvolutionTest, Smoke) {
Tensor<Linear, DataType::FLOAT32> mul_tensor;
mul_tensor.shape = Linear(8);
mul_tensor.data.resize(8);
MultiplyScalarAttributes mul_attr;
MultiplyAttributes mul_attr;
mul_attr.param = mul_tensor;
Convolution2DAttributes conv_attr;
@@ -140,7 +140,7 @@ TEST(FuseMulAfterConvolution2DTest, Smoke) {
Tensor<Linear, DataType::FLOAT32> mul_tensor;
mul_tensor.shape = Linear(2);
mul_tensor.data = {0.5f, 2.0f};
MultiplyScalarAttributes mul_attr;
MultiplyAttributes mul_attr;
mul_attr.param = mul_tensor;
FuseConvolution2DWithMultiply(mul_attr, &attr);
@@ -161,7 +161,7 @@ TEST(FuseMulAfterDepthwiseConvolution2DTest, Smoke) {
Tensor<Linear, DataType::FLOAT32> mul_tensor;
mul_tensor.shape = Linear(4);
mul_tensor.data = {0.5f, 2.0f, 4.0f, 0.25f};
MultiplyScalarAttributes mul_attr;
MultiplyAttributes mul_attr;
mul_attr.param = mul_tensor;
FuseDepthwiseConvolution2DWithMultiply(mul_attr, &attr);
@@ -183,7 +183,7 @@ TEST(FuseMulAfterConvolutionTransposedTest, Smoke) {
Tensor<Linear, DataType::FLOAT32> mul_tensor;
mul_tensor.shape = Linear(2);
mul_tensor.data = {0.5f, 2.0f};
MultiplyScalarAttributes mul_attr;
MultiplyAttributes mul_attr;
mul_attr.param = mul_tensor;
FuseConvolutionTransposedWithMultiply(mul_attr, &attr);
@@ -204,7 +204,7 @@ TEST(FuseMulAfterFullyConnectedTest, Smoke) {
Tensor<Linear, DataType::FLOAT32> mul_tensor;
mul_tensor.shape = Linear(2);
mul_tensor.data = {0.5f, 2.0f};
MultiplyScalarAttributes mul_attr;
MultiplyAttributes mul_attr;
mul_attr.param = mul_tensor;
FuseFullyConnectedWithMultiply(mul_attr, &attr);
@@ -224,7 +224,7 @@ TEST(FuseMulBeforeConvolution2DTest, Smoke) {
Tensor<Linear, DataType::FLOAT32> mul_tensor;
mul_tensor.shape = Linear(2);
mul_tensor.data = {0.5f, 2.0f};
MultiplyScalarAttributes mul_attr;
MultiplyAttributes mul_attr;
mul_attr.param = mul_tensor;
FuseMultiplyWithConvolution2D(mul_attr, &attr);
@@ -245,7 +245,7 @@ TEST(FuseMulBeforeDepthwiseConvolution2DTest, Smoke) {
Tensor<Linear, DataType::FLOAT32> mul_tensor;
mul_tensor.shape = Linear(4);
mul_tensor.data = {0.5f, 2.0f, 4.0f, 0.25f};
MultiplyScalarAttributes mul_attr;
MultiplyAttributes mul_attr;
mul_attr.param = mul_tensor;
FuseMultiplyWithDepthwiseConvolution2D(mul_attr, &attr);
@@ -267,7 +267,7 @@ TEST(FuseMulBeforeConvolutionTransposedTest, Smoke) {
Tensor<Linear, DataType::FLOAT32> mul_tensor;
mul_tensor.shape = Linear(2);
mul_tensor.data = {0.5f, 2.0f};
MultiplyScalarAttributes mul_attr;
MultiplyAttributes mul_attr;
mul_attr.param = mul_tensor;
FuseMultiplyWithConvolutionTransposed(mul_attr, &attr);
@@ -288,7 +288,7 @@ TEST(FuseMulBeforeFullyConnectedTest, Smoke) {
Tensor<Linear, DataType::FLOAT32> mul_tensor;
mul_tensor.shape = Linear(2);
mul_tensor.data = {0.5f, 2.0f};
MultiplyScalarAttributes mul_attr;
MultiplyAttributes mul_attr;
mul_attr.param = mul_tensor;
FuseMultiplyWithFullyConnected(mul_attr, &attr);


@@ -29,115 +29,116 @@ limitations under the License.
namespace tflite {
namespace gpu {
namespace gl {
namespace {
class ApplyMask : public NodeShader {
public:
static bool IsSupported(const GenerationContext& ctx) {
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
if (inputs.size() != 2) return false;
const auto& shape0 = inputs[0]->tensor.shape;
const auto& shape1 = inputs[1]->tensor.shape;
bool IsApplyMaskSupported(const NodeShader::GenerationContext& ctx) {
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
if (inputs.size() != 2) return false;
const auto& shape0 = inputs[0]->tensor.shape;
const auto& shape1 = inputs[1]->tensor.shape;
// [H, W, C] x [H, W, 0][0]
if (shape1.c == 1) return true;
if (shape0.c != shape1.c) return false;
// [H, W, C] x [H, W, C]
if (shape0.h == shape1.h && shape0.w == shape1.w) return true;
// [H, W, C] x [0, 0, C]
return shape1.h == 1 && shape1.w == 1;
// [H, W, C] x [H, W, 0][0]
if (shape0.h == shape1.h && shape0.w == shape1.w && shape1.c == 1) {
return true;
}
Status GenerateCode(const GenerationContext& ctx,
GeneratedCode* generated_code) const final {
if (!IsSupported(ctx)) {
return InvalidArgumentError(
"This case is not supported by apply mask operation");
}
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
const auto& shape0 = inputs[0]->tensor.shape;
const auto& shape1 = inputs[1]->tensor.shape;
// [H, W, C] x [H, W, C]
if (shape0 == shape1) {
return true;
}
std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * ";
if (shape1.c == 1) {
// [H, W, C] x [H, W, 0][0]
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;");
} else if (shape0.h == shape1.h && shape0.w == shape1.w) {
// [H, W, C] x [H, W, C]
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;");
} else {
// [H, W, C] x [0, 0, C]
absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;");
}
// [H, W, C] x [0, 0, C]
return shape1.h == 1 && shape1.w == 1 && shape0.c == shape1.c;
}
Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx,
GeneratedCode* generated_code) {
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
const auto& shape0 = inputs[0]->tensor.shape;
const auto& shape1 = inputs[1]->tensor.shape;
std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * ";
if (shape1.c == 1) {
// [H, W, C] x [H, W, 0][0]
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;");
} else if (shape0.h == shape1.h && shape0.w == shape1.w) {
// [H, W, C] x [H, W, C]
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;");
} else {
// [H, W, C] x [0, 0, C]
absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;");
}
*generated_code = {
/*parameters=*/{},
/*objects=*/{},
/*shared_variables=*/{},
/*workload=*/uint3(),
/*workgroup=*/uint3(),
/*source_code=*/std::move(source),
/*input=*/IOStructure::ONLY_DEFINITIONS,
/*output=*/IOStructure::AUTO,
};
return OkStatus();
}
Status GenerateMultiplyScalarCode(const NodeShader::GenerationContext& ctx,
GeneratedCode* generated_code) {
auto attr =
absl::any_cast<MultiplyAttributes>(ctx.node->operation.attributes);
auto muls = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
auto scalar = absl::get_if<float>(&attr.param);
if (scalar) {
*generated_code = {
/*parameters=*/{},
/*parameters=*/{{"scalar", *scalar}},
/*objects=*/{},
/*shared_variables=*/{},
/*workload=*/uint3(),
/*workgroup=*/uint3(),
/*source_code=*/std::move(source),
/*input=*/IOStructure::ONLY_DEFINITIONS,
/*source_code=*/"value_0 *= $scalar$;",
/*input=*/IOStructure::AUTO,
/*output=*/IOStructure::AUTO,
};
} else {
if (!muls) {
return InvalidArgumentError("Empty parameters for Multiplication.");
}
auto shape = ctx.graph->FindInputs(ctx.node->id)[0]->tensor.shape;
*generated_code = {
/*parameters=*/{},
/*objects=*/{{"mul_buffer", MakeReadonlyObject(muls->data)}},
/*shared_variables=*/{},
// Declare workload explicitly because shader depends on gid.z.
/*workload=*/
uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)),
/*workgroup=*/uint3(),
/*source_code=*/"value_0 *= $mul_buffer[gid.z]$;",
/*input=*/IOStructure::AUTO,
/*output=*/IOStructure::AUTO,
};
return OkStatus();
}
};
class MultiplyScalar : public NodeShader {
return OkStatus();
}
class Multiply : public NodeShader {
public:
Status GenerateCode(const GenerationContext& ctx,
GeneratedCode* generated_code) const final {
auto attr = absl::any_cast<MultiplyScalarAttributes>(
ctx.node->operation.attributes);
auto muls = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
auto scalar = absl::get_if<float>(&attr.param);
if (scalar) {
*generated_code = {
/*parameters=*/{{"scalar", *scalar}},
/*objects=*/{},
/*shared_variables=*/{},
/*workload=*/uint3(),
/*workgroup=*/uint3(),
/*source_code=*/"value_0 *= $scalar$;",
/*input=*/IOStructure::AUTO,
/*output=*/IOStructure::AUTO,
};
if (IsApplyMaskSupported(ctx)) {
return GenerateApplyMaskCode(ctx, generated_code);
} else {
if (!muls) {
return InvalidArgumentError("Empty parameters for Multiplication.");
}
auto shape = ctx.graph->FindInputs(ctx.node->id)[0]->tensor.shape;
*generated_code = {
/*parameters=*/{},
/*objects=*/{{"mul_buffer", MakeReadonlyObject(muls->data)}},
/*shared_variables=*/{},
// Declare workload explicitly because shader depends on gid.z.
/*workload=*/
uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)),
/*workgroup=*/uint3(),
/*source_code=*/"value_0 *= $mul_buffer[gid.z]$;",
/*input=*/IOStructure::AUTO,
/*output=*/IOStructure::AUTO,
};
return GenerateMultiplyScalarCode(ctx, generated_code);
}
return OkStatus();
}
};
} // namespace
std::unique_ptr<NodeShader> NewApplyMaskNodeShader() {
return absl::make_unique<ApplyMask>();
}
std::unique_ptr<NodeShader> NewMultiplyScalarNodeShader() {
return absl::make_unique<MultiplyScalar>();
std::unique_ptr<NodeShader> NewMultiplyNodeShader() {
return absl::make_unique<Multiply>();
}
} // namespace gl


@@ -25,9 +25,7 @@ namespace tflite {
namespace gpu {
namespace gl {
std::unique_ptr<NodeShader> NewApplyMaskNodeShader();
std::unique_ptr<NodeShader> NewMultiplyScalarNodeShader();
std::unique_ptr<NodeShader> NewMultiplyNodeShader();
} // namespace gl
} // namespace gpu


@@ -41,13 +41,12 @@ TEST(MulTest, Scalar) {
output.ref = 1;
output.shape = BHWC(1, 2, 2, 1);
MultiplyScalarAttributes attr;
MultiplyAttributes attr;
attr.param = 2.f;
// TODO(eignasheva): change to MULTIPLY_SCALAR
SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
ASSERT_OK(model.Invoke(*NewMultiplyScalarNodeShader()));
ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 6, 8}));
}
@@ -62,21 +61,20 @@ TEST(MulTest, Linear) {
output.ref = 1;
output.shape = BHWC(1, 1, 2, 2);
MultiplyScalarAttributes attr;
MultiplyAttributes attr;
Tensor<Linear, DataType::FLOAT32> tensor;
tensor.shape.v = 2;
tensor.id = 1;
tensor.data = {2, 3};
attr.param = std::move(tensor);
// TODO(eignasheva): change to MULTIPLY_SCALAR
SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
ASSERT_OK(model.Invoke(*NewMultiplyScalarNodeShader()));
ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 6, 6, 12}));
}
TEST(ApplyMaskTest, MaskChannel1) {
TEST(MulTest, MaskChannel1) {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
@@ -92,15 +90,15 @@ TEST(ApplyMaskTest, MaskChannel1) {
output.ref = 2;
output.shape = BHWC(1, 1, 2, 2);
SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask},
SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask},
{output});
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
ASSERT_TRUE(model.PopulateTensor(1, {2, 3}));
ASSERT_OK(model.Invoke(*NewApplyMaskNodeShader()));
ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 9, 12}));
}
TEST(ApplyMaskTest, MaskChannelEqualsToInputChannel) {
TEST(MulTest, MaskChannelEqualsToInputChannel) {
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
@@ -116,11 +114,11 @@ TEST(ApplyMaskTest, MaskChannelEqualsToInputChannel) {
output.ref = 2;
output.shape = BHWC(1, 1, 2, 2);
SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask},
SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask},
{output});
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
ASSERT_TRUE(model.PopulateTensor(1, {1, 2, 3, 4}));
ASSERT_OK(model.Invoke(*NewApplyMaskNodeShader()));
ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1, 4, 9, 16}));
}


@@ -71,7 +71,6 @@ class Registry : public NodeShader {
};
insert_op(Type::ADD, NewAddNodeShader);
insert_op(Type::APPLY_MASK, NewApplyMaskNodeShader);
insert_op(Type::CONCAT, NewAlignedConcatNodeShader);
insert_op(Type::CONCAT, NewFlatConcatNodeShader);
insert_op(Type::CONCAT, NewConcatNodeShader);
@@ -82,7 +81,7 @@ class Registry : public NodeShader {
insert_op(Type::FULLY_CONNECTED, NewFullyConnectedNodeShader);
insert_op(Type::LSTM, NewLstmNodeShader);
insert_op(Type::MEAN, NewMeanNodeShader);
insert_op(Type::MULTIPLY_SCALAR, NewMultiplyScalarNodeShader);
insert_op(Type::MUL, NewMultiplyNodeShader);
insert_op(Type::PAD, NewPadNodeShader);
insert_op(Type::POOLING_2D, NewPoolingNodeShader);
insert_op(Type::PRELU, NewPReLUNodeShader);


@@ -199,11 +199,15 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
*tasks = Mean(node_id, inputs[0], outputs[0],
absl::any_cast<MeanAttributes>(node->operation.attributes));
break;
case OperationType::MULTIPLY_SCALAR:
*tasks = Multiply(
node_id, inputs[0], outputs[0],
absl::any_cast<MultiplyScalarAttributes>(node->operation.attributes),
options);
case OperationType::MUL:
if (node->operation.attributes.has_value()) {
*tasks = Multiply(
node_id, inputs[0], outputs[0],
absl::any_cast<MultiplyAttributes>(node->operation.attributes),
options);
} else {
*tasks = ApplyMask(node_id, inputs[0], inputs[1], outputs[0], options);
}
break;
case OperationType::PAD: {
auto attr = absl::any_cast<PadAttributes>(node->operation.attributes);
@@ -268,12 +272,10 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
case OperationType::SQUARED_DIFF:
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type);
break;
case OperationType::APPLY_MASK:
case OperationType::BATCH_NORMALIZATION:
case OperationType::BATCH_TO_SPACE:
case OperationType::CONST:
case OperationType::LSTM:
case OperationType::MUL:
case OperationType::SPACE_TO_BATCH:
case OperationType::TRANSPOSE:
case OperationType::UNKNOWN:


@@ -128,9 +128,10 @@ std::vector<ComputeTaskDescriptorPtr> ApplyMask(int id, ValueId input_id_0,
return {desc};
}
std::vector<ComputeTaskDescriptorPtr> Multiply(
int id, ValueId input_id, ValueId output_id,
const MultiplyScalarAttributes& attr, const RuntimeOptions& options) {
std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
ValueId output_id,
const MultiplyAttributes& attr,
const RuntimeOptions& options) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = true;


@@ -26,9 +26,10 @@ namespace gpu {
namespace metal {
// Multiply operation, supports scalar and vector broadcast.
std::vector<ComputeTaskDescriptorPtr> Multiply(
int id, ValueId input_id, ValueId output_id,
const MultiplyScalarAttributes& attr, const RuntimeOptions& options);
std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
ValueId output_id,
const MultiplyAttributes& attr,
const RuntimeOptions& options);
std::vector<ComputeTaskDescriptorPtr> ApplyMask(int id, ValueId input_id_0,
ValueId input_id_1,


@@ -31,7 +31,7 @@ limitations under the License.
using ::tflite::gpu::DataType;
using ::tflite::gpu::BHWC;
using ::tflite::gpu::Linear;
using ::tflite::gpu::MultiplyScalarAttributes;
using ::tflite::gpu::MultiplyAttributes;
using ::tflite::gpu::OperationType;
using ::tflite::gpu::Tensor;
using ::tflite::gpu::TensorRef;
@@ -57,10 +57,10 @@ using ::tflite::gpu::metal::SingleOpModel;
output.ref = 1;
output.shape = BHWC(1, 2, 2, 1);
MultiplyScalarAttributes attr;
MultiplyAttributes attr;
attr.param = 2;
SingleOpModel model({ToString(OperationType::MULTIPLY_SCALAR), attr}, {input}, {output});
SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
@@ -79,14 +79,14 @@ using ::tflite::gpu::metal::SingleOpModel;
output.ref = 1;
output.shape = BHWC(1, 1, 2, 2);
MultiplyScalarAttributes attr;
MultiplyAttributes attr;
Tensor<Linear, DataType::FLOAT32> tensor;
tensor.shape.v = 2;
tensor.id = 1;
tensor.data = {2, 3};
attr.param = std::move(tensor);
SingleOpModel model({ToString(OperationType::MULTIPLY_SCALAR), attr}, {input}, {output});
SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
@@ -111,7 +111,7 @@ using ::tflite::gpu::metal::SingleOpModel;
output.ref = 2;
output.shape = BHWC(1, 1, 2, 2);
SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask}, {output});
SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
XCTAssertTrue(model.PopulateTensor(1, {2, 3}));
auto status = model.Invoke();
@@ -136,13 +136,12 @@ using ::tflite::gpu::metal::SingleOpModel;
output.ref = 2;
output.shape = BHWC(1, 1, 2, 2);
SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask}, {output});
SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output});
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
XCTAssertTrue(model.PopulateTensor(1, {1, 2, 3, 4}));
auto status = model.Invoke();
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
// Disable test for now.
// status = CompareVectors({1, 4, 9, 16}, model.GetOutput(0), 1e-6f);
status = CompareVectors({1, 4, 9, 16}, model.GetOutput(0), 1e-6f);
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
}