Remove APPLY_MASK and MULTIPLY_SCALAR enum values, use MUL instead.
Rename MultiplyScalarAttributes to MultiplyAttributes. PiperOrigin-RevId: 292228929 Change-Id: I459585d62e4161a93d6ec163bdbc0494c37eaf0e
This commit is contained in:
parent
abb256c1f8
commit
71c6f97e2d
@ -39,39 +39,6 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "apply_mask",
|
||||
srcs = ["apply_mask.cc"],
|
||||
hdrs = ["apply_mask.h"],
|
||||
deps = [
|
||||
":gpu_operation",
|
||||
":util",
|
||||
":work_group_picking",
|
||||
"//tensorflow/lite/delegates/gpu/cl:cl_kernel",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "apply_mask_test",
|
||||
srcs = ["apply_mask_test.cc"],
|
||||
linkstatic = True,
|
||||
tags = tf_gpu_tests_tags() + [
|
||||
"linux",
|
||||
"local",
|
||||
],
|
||||
deps = [
|
||||
":apply_mask",
|
||||
":cl_test",
|
||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cl_test",
|
||||
testonly = 1,
|
||||
@ -1328,7 +1295,6 @@ test_suite(
|
||||
name = "all_tests",
|
||||
tests = [
|
||||
"add_test",
|
||||
"apply_mask_test",
|
||||
"concat_test",
|
||||
"conv_buffer_1x1_test",
|
||||
"conv_buffer_test",
|
||||
|
@ -1,103 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/apply_mask.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
ApplyMask::ApplyMask(ApplyMask&& operation)
|
||||
: ElementwiseOperation(std::move(operation)),
|
||||
mask_type_(operation.mask_type_),
|
||||
link_index_(operation.link_index_) {}
|
||||
|
||||
ApplyMask& ApplyMask::operator=(ApplyMask&& operation) {
|
||||
if (this != &operation) {
|
||||
mask_type_ = operation.mask_type_;
|
||||
link_index_ = operation.link_index_;
|
||||
ElementwiseOperation::operator=(std::move(operation));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void ApplyMask::SetLinkIndex(int index) { link_index_ = index; }
|
||||
|
||||
std::string ApplyMask::GetCoreCode(const LinkingContext& context) const {
|
||||
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
|
||||
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
|
||||
TensorCodeGenerator mask(
|
||||
tensor_name,
|
||||
WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
|
||||
definition_.src_tensors[1]);
|
||||
switch (mask_type_) {
|
||||
case MaskType::TENSOR:
|
||||
return context.var_name + " *= " +
|
||||
mask.ReadWHS(context.x_coord, context.y_coord, context.s_coord) +
|
||||
";\n";
|
||||
case MaskType::CHANNELS:
|
||||
return context.var_name +
|
||||
" *= " + mask.ReadWHS("0", "0", context.s_coord) + ";\n";
|
||||
case MaskType::LAYER:
|
||||
return context.var_name +
|
||||
" *= " + mask.ReadWHS(context.x_coord, context.y_coord, "0") +
|
||||
".x;\n";
|
||||
}
|
||||
}
|
||||
|
||||
std::string ApplyMask::GetArgsDeclaration() const {
|
||||
std::string args;
|
||||
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
|
||||
absl::StrAppend(&args, ",\n",
|
||||
GetTensorDeclaration(AccessType::READ, tensor_name,
|
||||
definition_.src_tensors[1]));
|
||||
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
|
||||
absl::StrAppend(&args, ",\n int4 ", size_name);
|
||||
return args;
|
||||
}
|
||||
|
||||
Status ApplyMask::BindArguments(CLKernel* kernel) {
|
||||
RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB()));
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
|
||||
const BHWC& mask_shape) {
|
||||
ApplyMask::MaskType mask_type;
|
||||
if (mask_shape == src_shape) {
|
||||
mask_type = ApplyMask::MaskType::TENSOR;
|
||||
} else if (mask_shape.c == 1) {
|
||||
mask_type = ApplyMask::MaskType::LAYER;
|
||||
} else {
|
||||
mask_type = ApplyMask::MaskType::CHANNELS;
|
||||
}
|
||||
ApplyMask operation(definition, mask_type);
|
||||
operation.SetLinkIndex(0);
|
||||
return operation;
|
||||
}
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
@ -1,63 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_APPLY_MASK_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_APPLY_MASK_H_
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
class ApplyMask : public ElementwiseOperation {
|
||||
public:
|
||||
// Move only
|
||||
ApplyMask(ApplyMask&& operation);
|
||||
ApplyMask& operator=(ApplyMask&& operation);
|
||||
ApplyMask(const ApplyMask&) = delete;
|
||||
ApplyMask& operator=(const ApplyMask&) = delete;
|
||||
|
||||
void SetLinkIndex(int index) override;
|
||||
std::string GetCoreCode(const LinkingContext& context) const override;
|
||||
std::string GetArgsDeclaration() const override;
|
||||
Status BindArguments(CLKernel* kernel) override;
|
||||
|
||||
private:
|
||||
friend ApplyMask CreateApplyMask(const OperationDef& definition,
|
||||
const BHWC& src_shape,
|
||||
const BHWC& mask_shape);
|
||||
|
||||
enum class MaskType { LAYER, CHANNELS, TENSOR };
|
||||
|
||||
explicit ApplyMask(const OperationDef& definition, MaskType mask_type)
|
||||
: ElementwiseOperation(definition), mask_type_(mask_type) {}
|
||||
|
||||
MaskType mask_type_;
|
||||
int link_index_;
|
||||
};
|
||||
|
||||
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
|
||||
const BHWC& mask_shape);
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_APPLY_MASK_H_
|
@ -1,127 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/apply_mask.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
|
||||
using ::testing::FloatNear;
|
||||
using ::testing::Pointwise;
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
namespace {
|
||||
|
||||
TEST_F(OpenCLOperationTest, ApplyMaskOneChannel) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 2, 2, 2);
|
||||
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
|
||||
TensorFloat32 mask_tensor;
|
||||
mask_tensor.shape = BHWC(1, 2, 2, 1);
|
||||
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f};
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
ApplyMask operation =
|
||||
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
|
||||
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
|
||||
creation_context_, &operation,
|
||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {-8.0f, -6.0f, -0.5f, 0.0f, 1.0f,
|
||||
3.0f, 0.0f, 0.0f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, ApplyMaskEqualSizes) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 2, 2, 2);
|
||||
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
|
||||
TensorFloat32 mask_tensor;
|
||||
mask_tensor.shape = BHWC(1, 2, 2, 2);
|
||||
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f, 2.0f, 0.5f, 1.0f, 0.0f};
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
ApplyMask operation =
|
||||
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
|
||||
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
|
||||
creation_context_, &operation,
|
||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -1.0f, 0.0f, 2.0f,
|
||||
1.5f, 4.0f, 0.0f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, ApplyMaskVector) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 2, 2, 2);
|
||||
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
|
||||
TensorFloat32 mask_tensor;
|
||||
mask_tensor.shape = BHWC(1, 1, 1, 2);
|
||||
mask_tensor.data = {2.0f, 0.5f};
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
ApplyMask operation =
|
||||
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
|
||||
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
|
||||
creation_context_, &operation,
|
||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -2.0f, 0.0f, 2.0f,
|
||||
1.5f, 8.0f, 3.0f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
@ -105,7 +105,7 @@ Status MultiplyAdd::BindArguments(CLKernel* kernel) {
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
Status MultiplyAdd::UploadMul(const MultiplyScalarAttributes& attr,
|
||||
Status MultiplyAdd::UploadMul(const MultiplyAttributes& attr,
|
||||
CalculationsPrecision scalar_precision,
|
||||
CLContext* context) {
|
||||
auto mul = absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
|
||||
@ -135,8 +135,7 @@ Status MultiplyAdd::UploadAdd(const AddAttributes& attr,
|
||||
|
||||
Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const MultiplyScalarAttributes& attr,
|
||||
MultiplyAdd* result) {
|
||||
const MultiplyAttributes& attr, MultiplyAdd* result) {
|
||||
const auto scalar_precision = creation_context.device->IsPowerVR()
|
||||
? CalculationsPrecision::F32
|
||||
: definition.precision;
|
||||
@ -162,7 +161,7 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||
|
||||
Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const MultiplyScalarAttributes& mul_attr,
|
||||
const MultiplyAttributes& mul_attr,
|
||||
const AddAttributes& add_attr, MultiplyAdd* result) {
|
||||
const auto scalar_precision = creation_context.device->IsPowerVR()
|
||||
? CalculationsPrecision::F32
|
||||
@ -176,6 +175,76 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
ApplyMask::ApplyMask(ApplyMask&& operation)
|
||||
: ElementwiseOperation(std::move(operation)),
|
||||
mask_type_(operation.mask_type_),
|
||||
link_index_(operation.link_index_) {}
|
||||
|
||||
ApplyMask& ApplyMask::operator=(ApplyMask&& operation) {
|
||||
if (this != &operation) {
|
||||
mask_type_ = operation.mask_type_;
|
||||
link_index_ = operation.link_index_;
|
||||
ElementwiseOperation::operator=(std::move(operation));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void ApplyMask::SetLinkIndex(int index) { link_index_ = index; }
|
||||
|
||||
std::string ApplyMask::GetCoreCode(const LinkingContext& context) const {
|
||||
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
|
||||
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
|
||||
TensorCodeGenerator mask(
|
||||
tensor_name,
|
||||
WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
|
||||
definition_.src_tensors[1]);
|
||||
switch (mask_type_) {
|
||||
case MaskType::TENSOR:
|
||||
return context.var_name + " *= " +
|
||||
mask.ReadWHS(context.x_coord, context.y_coord, context.s_coord) +
|
||||
";\n";
|
||||
case MaskType::CHANNELS:
|
||||
return context.var_name +
|
||||
" *= " + mask.ReadWHS("0", "0", context.s_coord) + ";\n";
|
||||
case MaskType::LAYER:
|
||||
return context.var_name +
|
||||
" *= " + mask.ReadWHS(context.x_coord, context.y_coord, "0") +
|
||||
".x;\n";
|
||||
}
|
||||
}
|
||||
|
||||
std::string ApplyMask::GetArgsDeclaration() const {
|
||||
std::string args;
|
||||
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
|
||||
absl::StrAppend(&args, ",\n",
|
||||
GetTensorDeclaration(AccessType::READ, tensor_name,
|
||||
definition_.src_tensors[1]));
|
||||
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
|
||||
absl::StrAppend(&args, ",\n int4 ", size_name);
|
||||
return args;
|
||||
}
|
||||
|
||||
Status ApplyMask::BindArguments(CLKernel* kernel) {
|
||||
RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB()));
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
|
||||
const BHWC& mask_shape) {
|
||||
ApplyMask::MaskType mask_type;
|
||||
if (mask_shape == src_shape) {
|
||||
mask_type = ApplyMask::MaskType::TENSOR;
|
||||
} else if (mask_shape.c == 1) {
|
||||
mask_type = ApplyMask::MaskType::LAYER;
|
||||
} else {
|
||||
mask_type = ApplyMask::MaskType::CHANNELS;
|
||||
}
|
||||
ApplyMask operation(definition, mask_type);
|
||||
operation.SetLinkIndex(0);
|
||||
return operation;
|
||||
}
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -40,7 +40,7 @@ class MultiplyAdd : public ElementwiseOperation {
|
||||
MultiplyAdd(const MultiplyAdd&) = delete;
|
||||
MultiplyAdd& operator=(const MultiplyAdd&) = delete;
|
||||
|
||||
Status UploadMul(const MultiplyScalarAttributes& attr,
|
||||
Status UploadMul(const MultiplyAttributes& attr,
|
||||
CalculationsPrecision scalar_precision, CLContext* context);
|
||||
Status UploadAdd(const AddAttributes& attr,
|
||||
CalculationsPrecision scalar_precision, CLContext* context);
|
||||
@ -61,7 +61,7 @@ class MultiplyAdd : public ElementwiseOperation {
|
||||
|
||||
friend Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const MultiplyScalarAttributes& attr,
|
||||
const MultiplyAttributes& attr,
|
||||
MultiplyAdd* result);
|
||||
|
||||
friend Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||
@ -71,7 +71,7 @@ class MultiplyAdd : public ElementwiseOperation {
|
||||
|
||||
friend Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const MultiplyScalarAttributes& mul_attr,
|
||||
const MultiplyAttributes& mul_attr,
|
||||
const AddAttributes& add_attr,
|
||||
MultiplyAdd* result);
|
||||
|
||||
@ -91,8 +91,7 @@ class MultiplyAdd : public ElementwiseOperation {
|
||||
|
||||
Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const MultiplyScalarAttributes& attr,
|
||||
MultiplyAdd* result);
|
||||
const MultiplyAttributes& attr, MultiplyAdd* result);
|
||||
|
||||
Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
@ -100,7 +99,7 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||
|
||||
Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
const MultiplyScalarAttributes& mul_attr,
|
||||
const MultiplyAttributes& mul_attr,
|
||||
const AddAttributes& add_attr, MultiplyAdd* result);
|
||||
|
||||
template <DataType T>
|
||||
@ -127,6 +126,36 @@ Status MultiplyAdd::UploadAdd(const ::tflite::gpu::Tensor<Linear, T>& add,
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
class ApplyMask : public ElementwiseOperation {
|
||||
public:
|
||||
// Move only
|
||||
ApplyMask(ApplyMask&& operation);
|
||||
ApplyMask& operator=(ApplyMask&& operation);
|
||||
ApplyMask(const ApplyMask&) = delete;
|
||||
ApplyMask& operator=(const ApplyMask&) = delete;
|
||||
|
||||
void SetLinkIndex(int index) override;
|
||||
std::string GetCoreCode(const LinkingContext& context) const override;
|
||||
std::string GetArgsDeclaration() const override;
|
||||
Status BindArguments(CLKernel* kernel) override;
|
||||
|
||||
private:
|
||||
friend ApplyMask CreateApplyMask(const OperationDef& definition,
|
||||
const BHWC& src_shape,
|
||||
const BHWC& mask_shape);
|
||||
|
||||
enum class MaskType { LAYER, CHANNELS, TENSOR };
|
||||
|
||||
explicit ApplyMask(const OperationDef& definition, MaskType mask_type)
|
||||
: ElementwiseOperation(definition), mask_type_(mask_type) {}
|
||||
|
||||
MaskType mask_type_;
|
||||
int link_index_;
|
||||
};
|
||||
|
||||
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
|
||||
const BHWC& mask_shape);
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -37,7 +37,7 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMul) {
|
||||
src_tensor.shape = BHWC(1, 2, 1, 2);
|
||||
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||
|
||||
MultiplyScalarAttributes attr;
|
||||
MultiplyAttributes attr;
|
||||
::tflite::gpu::Tensor<Linear, DataType::FLOAT32> parameters;
|
||||
parameters.shape = Linear(2);
|
||||
parameters.data = {0.5f, 2.0f};
|
||||
@ -97,7 +97,7 @@ TEST_F(OpenCLOperationTest, MultiplyAddScalarMul) {
|
||||
src_tensor.shape = BHWC(1, 2, 1, 2);
|
||||
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||
|
||||
MultiplyScalarAttributes attr;
|
||||
MultiplyAttributes attr;
|
||||
attr.param = 0.5f;
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
@ -151,7 +151,7 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMad) {
|
||||
src_tensor.shape = BHWC(1, 2, 1, 2);
|
||||
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||
|
||||
MultiplyScalarAttributes mul_attr;
|
||||
MultiplyAttributes mul_attr;
|
||||
::tflite::gpu::Tensor<Linear, DataType::FLOAT32> parameters;
|
||||
parameters.shape = Linear(2);
|
||||
parameters.data = {0.5f, 2.0f};
|
||||
@ -181,6 +181,96 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMad) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, ApplyMaskOneChannel) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 2, 2, 2);
|
||||
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
|
||||
TensorFloat32 mask_tensor;
|
||||
mask_tensor.shape = BHWC(1, 2, 2, 1);
|
||||
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f};
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
ApplyMask operation =
|
||||
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
|
||||
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
|
||||
creation_context_, &operation,
|
||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {-8.0f, -6.0f, -0.5f, 0.0f, 1.0f,
|
||||
3.0f, 0.0f, 0.0f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, ApplyMaskEqualSizes) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 2, 2, 2);
|
||||
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
|
||||
TensorFloat32 mask_tensor;
|
||||
mask_tensor.shape = BHWC(1, 2, 2, 2);
|
||||
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f, 2.0f, 0.5f, 1.0f, 0.0f};
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
ApplyMask operation =
|
||||
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
|
||||
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
|
||||
creation_context_, &operation,
|
||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -1.0f, 0.0f, 2.0f,
|
||||
1.5f, 4.0f, 0.0f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpenCLOperationTest, ApplyMaskVector) {
|
||||
TensorFloat32 src_tensor;
|
||||
src_tensor.shape = BHWC(1, 2, 2, 2);
|
||||
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
|
||||
TensorFloat32 mask_tensor;
|
||||
mask_tensor.shape = BHWC(1, 1, 1, 2);
|
||||
mask_tensor.data = {2.0f, 0.5f};
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||
OperationDef op_def;
|
||||
op_def.precision = precision;
|
||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
ApplyMask operation =
|
||||
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
|
||||
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
|
||||
creation_context_, &operation,
|
||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -2.0f, 0.0f, 2.0f,
|
||||
1.5f, 8.0f, 3.0f}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
|
@ -103,7 +103,6 @@ cc_library(
|
||||
hdrs = ["simple_selectors.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:add",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:apply_mask",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:concat_xy",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:concat_z",
|
||||
"//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
|
||||
|
@ -68,11 +68,6 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
|
||||
return OkStatus();
|
||||
}
|
||||
}
|
||||
case OperationType::APPLY_MASK: {
|
||||
SelectApplyMask(op_def, inputs[0]->tensor.shape, inputs[1]->tensor.shape,
|
||||
gpu_op);
|
||||
return OkStatus();
|
||||
}
|
||||
case OperationType::CONCAT: {
|
||||
auto attr = absl::any_cast<ConcatAttributes>(node.operation.attributes);
|
||||
std::vector<int> channels(inputs.size());
|
||||
@ -119,10 +114,17 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
|
||||
auto attr = absl::any_cast<MeanAttributes>(node.operation.attributes);
|
||||
return SelectMean(attr, op_def, gpu_op);
|
||||
}
|
||||
case OperationType::MULTIPLY_SCALAR: {
|
||||
auto attr =
|
||||
absl::any_cast<MultiplyScalarAttributes>(node.operation.attributes);
|
||||
return SelectMultiplyScalar(attr, creation_context, op_def, gpu_op);
|
||||
case OperationType::MUL: {
|
||||
if (node.operation.attributes.has_value()) {
|
||||
auto attr =
|
||||
absl::any_cast<MultiplyAttributes>(node.operation.attributes);
|
||||
|
||||
return SelectMultiplyScalar(attr, creation_context, op_def, gpu_op);
|
||||
} else {
|
||||
SelectApplyMask(op_def, inputs[0]->tensor.shape,
|
||||
inputs[1]->tensor.shape, gpu_op);
|
||||
return OkStatus();
|
||||
}
|
||||
}
|
||||
case OperationType::PAD: {
|
||||
auto attr = absl::any_cast<PadAttributes>(node.operation.attributes);
|
||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/add.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/apply_mask.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/lstm.h"
|
||||
@ -155,7 +154,7 @@ Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def,
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
Status SelectMultiplyScalar(const MultiplyScalarAttributes& attr,
|
||||
Status SelectMultiplyScalar(const MultiplyAttributes& attr,
|
||||
const CreationContext& creation_context,
|
||||
const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr) {
|
||||
|
@ -73,7 +73,7 @@ void SelectStridedSlice(const SliceAttributes& attr, const OperationDef& op_def,
|
||||
Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr);
|
||||
|
||||
Status SelectMultiplyScalar(const MultiplyScalarAttributes& attr,
|
||||
Status SelectMultiplyScalar(const MultiplyAttributes& attr,
|
||||
const CreationContext& creation_context,
|
||||
const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr);
|
||||
|
@ -1395,8 +1395,11 @@ class MulOperationParser : public TFLiteOperationParser {
|
||||
const bool runtime_tensor0 = !constant_tensor0;
|
||||
const bool runtime_tensor1 = !constant_tensor1;
|
||||
|
||||
// Parse for APPLY_MASK. The "larger" input tensor must be bound to 1st
|
||||
// input and the "smaller" input tensor ("mask") must be bound to 2nd input.
|
||||
Node* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::MUL);
|
||||
|
||||
// The "larger" input tensor must be bound to 1st input and the "smaller"
|
||||
// input tensor ("mask") must be bound to 2nd input.
|
||||
if (runtime_tensor0 && runtime_tensor1) {
|
||||
BHWC shape0;
|
||||
RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
|
||||
@ -1409,11 +1412,11 @@ class MulOperationParser : public TFLiteOperationParser {
|
||||
input_tensor0 = 1;
|
||||
input_tensor1 = 0;
|
||||
}
|
||||
return ParseApplyMask(input_tensor0, input_tensor1, graph, reader);
|
||||
return ParseApplyMask(node, input_tensor0, input_tensor1, graph, reader);
|
||||
}
|
||||
|
||||
// Parse for MULTIPLY_SCALAR. The runtime input tensor must be bound to 1st
|
||||
// input and the constant input tensor must be bound to 2nd input.
|
||||
// The runtime input tensor must be bound to 1st input and the constant
|
||||
// input tensor must be bound to 2nd input.
|
||||
int runtime_tensor = 0;
|
||||
int constant_tensor = 1;
|
||||
TfLiteIntArray* constant_dims = input1->dims;
|
||||
@ -1422,27 +1425,24 @@ class MulOperationParser : public TFLiteOperationParser {
|
||||
constant_tensor = 0;
|
||||
constant_dims = input0->dims;
|
||||
}
|
||||
return ParseMultiplyScalar(runtime_tensor, constant_tensor, constant_dims,
|
||||
graph, reader);
|
||||
return ParseMultiplyScalar(node, runtime_tensor, constant_tensor,
|
||||
constant_dims, graph, reader);
|
||||
}
|
||||
|
||||
private:
|
||||
Status ParseApplyMask(int input_tensor0, int input_tensor1,
|
||||
Status ParseApplyMask(Node* node, int input_tensor0, int input_tensor1,
|
||||
GraphFloat32* graph, ObjectReader* reader) {
|
||||
Node* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::APPLY_MASK);
|
||||
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
|
||||
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
|
||||
return reader->AddOutputs(node);
|
||||
}
|
||||
|
||||
Status ParseMultiplyScalar(int runtime_tensor, int constant_tensor,
|
||||
Status ParseMultiplyScalar(Node* node, int runtime_tensor,
|
||||
int constant_tensor,
|
||||
const TfLiteIntArray* constant_dims,
|
||||
GraphFloat32* graph, ObjectReader* reader) {
|
||||
Node* node = graph->NewNode();
|
||||
node->operation.type = ToString(OperationType::MULTIPLY_SCALAR);
|
||||
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
|
||||
MultiplyScalarAttributes attr;
|
||||
MultiplyAttributes attr;
|
||||
if (constant_dims->size <= 0) {
|
||||
Tensor<Scalar, DataType::FLOAT32> tensor;
|
||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||
|
@ -72,8 +72,6 @@ std::string ToString(enum OperationType op) {
|
||||
return "abs";
|
||||
case OperationType::ADD:
|
||||
return "add";
|
||||
case OperationType::APPLY_MASK:
|
||||
return "apply_mask";
|
||||
case OperationType::BATCH_NORMALIZATION:
|
||||
return "batch_normalization";
|
||||
case OperationType::BATCH_TO_SPACE:
|
||||
@ -106,8 +104,6 @@ std::string ToString(enum OperationType op) {
|
||||
return "mean";
|
||||
case OperationType::MUL:
|
||||
return "mul";
|
||||
case OperationType::MULTIPLY_SCALAR:
|
||||
return "multiply_scalar";
|
||||
case OperationType::PAD:
|
||||
return "pad";
|
||||
case OperationType::POOLING_2D:
|
||||
@ -157,7 +153,6 @@ OperationType OperationTypeFromString(const std::string& name) {
|
||||
new std::unordered_map<std::string, OperationType>({
|
||||
{"abs", OperationType::ABS},
|
||||
{"add", OperationType::ADD},
|
||||
{"apply_mask", OperationType::APPLY_MASK},
|
||||
{"batch_normalization", OperationType::BATCH_NORMALIZATION},
|
||||
{"concat", OperationType::CONCAT},
|
||||
{"const", OperationType::CONST},
|
||||
@ -173,7 +168,6 @@ OperationType OperationTypeFromString(const std::string& name) {
|
||||
{"max_unpooling", OperationType::MAX_UNPOOLING_2D},
|
||||
{"mean", OperationType::MEAN},
|
||||
{"mul", OperationType::MUL},
|
||||
{"multiply_scalar", OperationType::MULTIPLY_SCALAR},
|
||||
{"pad", OperationType::PAD},
|
||||
{"pooling_2d", OperationType::POOLING_2D},
|
||||
{"pow", OperationType::POW},
|
||||
|
@ -34,8 +34,6 @@ enum class OperationType {
|
||||
UNKNOWN = 0,
|
||||
ABS,
|
||||
ADD,
|
||||
// TODO(eignasheva): remove APPLY_MASK operation, is should be just MUL
|
||||
APPLY_MASK,
|
||||
BATCH_TO_SPACE,
|
||||
BATCH_NORMALIZATION,
|
||||
CONCAT,
|
||||
@ -52,7 +50,6 @@ enum class OperationType {
|
||||
MAX_UNPOOLING_2D,
|
||||
MEAN,
|
||||
MUL,
|
||||
MULTIPLY_SCALAR,
|
||||
PAD,
|
||||
POOLING_2D,
|
||||
POW,
|
||||
@ -354,7 +351,7 @@ struct LstmAttributes {
|
||||
LstmKernelType kernel_type = LstmKernelType::BASIC;
|
||||
};
|
||||
|
||||
struct MultiplyScalarAttributes {
|
||||
struct MultiplyAttributes {
|
||||
absl::variant<absl::monostate, Tensor<Linear, DataType::FLOAT32>, float>
|
||||
param;
|
||||
};
|
||||
|
@ -33,15 +33,14 @@ class MergeConvolutionWithMul : public SequenceTransformation {
|
||||
GraphFloat32* graph) final {
|
||||
auto& conv_node = *sequence[0];
|
||||
auto& mul_node = *sequence[1];
|
||||
if (mul_node.operation.type != ToString(OperationType::MUL) &&
|
||||
mul_node.operation.type != ToString(OperationType::MULTIPLY_SCALAR)) {
|
||||
if (mul_node.operation.type != ToString(OperationType::MUL) ||
|
||||
!mul_node.operation.attributes.has_value()) {
|
||||
return {TransformStatus::SKIPPED, ""};
|
||||
}
|
||||
|
||||
MultiplyScalarAttributes mul_attr =
|
||||
absl::any_cast<MultiplyScalarAttributes>(mul_node.operation.attributes);
|
||||
if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(
|
||||
&mul_attr.param) &&
|
||||
MultiplyAttributes mul_attr =
|
||||
absl::any_cast<MultiplyAttributes>(mul_node.operation.attributes);
|
||||
if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param) &&
|
||||
!absl::get_if<float>(&mul_attr.param)) {
|
||||
return {
|
||||
TransformStatus::DECLINED,
|
||||
@ -93,13 +92,13 @@ class MergeMulWithConvolution : public SequenceTransformation {
|
||||
GraphFloat32* graph) final {
|
||||
auto& conv_node = *sequence[1];
|
||||
auto& mul_node = *sequence[0];
|
||||
if (mul_node.operation.type != ToString(OperationType::MUL) &&
|
||||
mul_node.operation.type != ToString(OperationType::MULTIPLY_SCALAR)) {
|
||||
if (mul_node.operation.type != ToString(OperationType::MUL) ||
|
||||
!mul_node.operation.attributes.has_value()) {
|
||||
return {TransformStatus::SKIPPED, ""};
|
||||
}
|
||||
|
||||
MultiplyScalarAttributes mul_attr =
|
||||
absl::any_cast<MultiplyScalarAttributes>(mul_node.operation.attributes);
|
||||
MultiplyAttributes mul_attr =
|
||||
absl::any_cast<MultiplyAttributes>(mul_node.operation.attributes);
|
||||
if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(
|
||||
&mul_attr.param) &&
|
||||
!absl::get_if<float>(&mul_attr.param)) {
|
||||
@ -155,7 +154,7 @@ std::unique_ptr<SequenceTransformation> NewMergeMulWithConvolution() {
|
||||
return absl::make_unique<MergeMulWithConvolution>();
|
||||
}
|
||||
|
||||
void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr,
|
||||
void FuseConvolution2DWithMultiply(const MultiplyAttributes& mul_attr,
|
||||
Convolution2DAttributes* attr) {
|
||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||
@ -176,7 +175,7 @@ void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr,
|
||||
}
|
||||
|
||||
void FuseDepthwiseConvolution2DWithMultiply(
|
||||
const MultiplyScalarAttributes& mul_attr,
|
||||
const MultiplyAttributes& mul_attr,
|
||||
DepthwiseConvolution2DAttributes* attr) {
|
||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||
@ -198,8 +197,7 @@ void FuseDepthwiseConvolution2DWithMultiply(
|
||||
}
|
||||
|
||||
void FuseConvolutionTransposedWithMultiply(
|
||||
const MultiplyScalarAttributes& mul_attr,
|
||||
ConvolutionTransposedAttributes* attr) {
|
||||
const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr) {
|
||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||
for (int d = 0; d < attr->weights.shape.o; ++d) {
|
||||
@ -218,7 +216,7 @@ void FuseConvolutionTransposedWithMultiply(
|
||||
}
|
||||
}
|
||||
|
||||
void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr,
|
||||
void FuseFullyConnectedWithMultiply(const MultiplyAttributes& mul_attr,
|
||||
FullyConnectedAttributes* attr) {
|
||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||
@ -234,7 +232,7 @@ void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr,
|
||||
}
|
||||
}
|
||||
|
||||
void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr,
|
||||
void FuseMultiplyWithConvolution2D(const MultiplyAttributes& mul_attr,
|
||||
Convolution2DAttributes* attr) {
|
||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||
@ -252,7 +250,7 @@ void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr,
|
||||
}
|
||||
|
||||
void FuseMultiplyWithDepthwiseConvolution2D(
|
||||
const MultiplyScalarAttributes& mul_attr,
|
||||
const MultiplyAttributes& mul_attr,
|
||||
DepthwiseConvolution2DAttributes* attr) {
|
||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||
@ -270,8 +268,7 @@ void FuseMultiplyWithDepthwiseConvolution2D(
|
||||
}
|
||||
|
||||
void FuseMultiplyWithConvolutionTransposed(
|
||||
const MultiplyScalarAttributes& mul_attr,
|
||||
ConvolutionTransposedAttributes* attr) {
|
||||
const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr) {
|
||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||
for (int s = 0; s < attr->weights.shape.i; ++s) {
|
||||
@ -287,7 +284,7 @@ void FuseMultiplyWithConvolutionTransposed(
|
||||
}
|
||||
}
|
||||
|
||||
void FuseMultiplyWithFullyConnected(const MultiplyScalarAttributes& mul_attr,
|
||||
void FuseMultiplyWithFullyConnected(const MultiplyAttributes& mul_attr,
|
||||
FullyConnectedAttributes* attr) {
|
||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||
|
@ -38,53 +38,49 @@ std::unique_ptr<SequenceTransformation> NewMergeMulWithConvolution();
|
||||
// Modify Convolution2DAttributes so that after making convolution with
|
||||
// modified attributes we will have the same result as convolution
|
||||
// with old attributes and following multiply operation.
|
||||
void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr,
|
||||
void FuseConvolution2DWithMultiply(const MultiplyAttributes& mul_attr,
|
||||
Convolution2DAttributes* attr);
|
||||
|
||||
// Modify DepthwiseConvolution2DAttributes so that after making depth wise
|
||||
// convolution with modified attributes we will have the same result as depth
|
||||
// wise convolution with old attributes and following multiply operation.
|
||||
void FuseDepthwiseConvolution2DWithMultiply(
|
||||
const MultiplyScalarAttributes& mul_attr,
|
||||
DepthwiseConvolution2DAttributes* attr);
|
||||
const MultiplyAttributes& mul_attr, DepthwiseConvolution2DAttributes* attr);
|
||||
|
||||
// Modify ConvolutionTransposedAttributes so that after making convolution
|
||||
// transposed with modified attributes we will have the same result as
|
||||
// convolution transposed with old attributes and following multiply operation.
|
||||
void FuseConvolutionTransposedWithMultiply(
|
||||
const MultiplyScalarAttributes& mul_attr,
|
||||
ConvolutionTransposedAttributes* attr);
|
||||
const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr);
|
||||
|
||||
// Modify FullyConnectedAttributes so that after making fully connected with
|
||||
// modified attributes we will have the same result as fully connected
|
||||
// with old attributes and following multiply operation.
|
||||
void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr,
|
||||
void FuseFullyConnectedWithMultiply(const MultiplyAttributes& mul_attr,
|
||||
FullyConnectedAttributes* attr);
|
||||
|
||||
// Modify Convolution2DAttributes so that after making convolution with
|
||||
// modified attributes we will have the same result as multiply operation and
|
||||
// convolution with old attributes
|
||||
void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr,
|
||||
void FuseMultiplyWithConvolution2D(const MultiplyAttributes& mul_attr,
|
||||
Convolution2DAttributes* attr);
|
||||
|
||||
// Modify DepthwiseConvolution2DAttributes so that after making depth wise
|
||||
// convolution with modified attributes we will have the same result as multiply
|
||||
// operation and depth wise convolution with old attributes
|
||||
void FuseMultiplyWithDepthwiseConvolution2D(
|
||||
const MultiplyScalarAttributes& mul_attr,
|
||||
DepthwiseConvolution2DAttributes* attr);
|
||||
const MultiplyAttributes& mul_attr, DepthwiseConvolution2DAttributes* attr);
|
||||
|
||||
// Modify ConvolutionTransposedAttributes so that after making convolution
|
||||
// transposed with modified attributes we will have the same result as multiply
|
||||
// operation and convolution transposed with old attributes
|
||||
void FuseMultiplyWithConvolutionTransposed(
|
||||
const MultiplyScalarAttributes& mul_attr,
|
||||
ConvolutionTransposedAttributes* attr);
|
||||
const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr);
|
||||
|
||||
// Modify FullyConnectedAttributes so that after making fully connected
|
||||
// with modified attributes we will have the same result as multiply
|
||||
// operation and fully connected with old attributes
|
||||
void FuseMultiplyWithFullyConnected(const MultiplyScalarAttributes& mul_attr,
|
||||
void FuseMultiplyWithFullyConnected(const MultiplyAttributes& mul_attr,
|
||||
FullyConnectedAttributes* attr);
|
||||
|
||||
} // namespace gpu
|
||||
|
@ -46,7 +46,7 @@ TEST(MergeConvolutionWithMulTest, Smoke) {
|
||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||
mul_tensor.shape = Linear(16);
|
||||
mul_tensor.data.resize(16);
|
||||
MultiplyScalarAttributes mul_attr;
|
||||
MultiplyAttributes mul_attr;
|
||||
mul_attr.param = mul_tensor;
|
||||
|
||||
auto conv_node = graph.NewNode();
|
||||
@ -87,7 +87,7 @@ TEST(MergeMulWithConvolutionTest, Smoke) {
|
||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||
mul_tensor.shape = Linear(8);
|
||||
mul_tensor.data.resize(8);
|
||||
MultiplyScalarAttributes mul_attr;
|
||||
MultiplyAttributes mul_attr;
|
||||
mul_attr.param = mul_tensor;
|
||||
|
||||
Convolution2DAttributes conv_attr;
|
||||
@ -140,7 +140,7 @@ TEST(FuseMulAfterConvolution2DTest, Smoke) {
|
||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||
mul_tensor.shape = Linear(2);
|
||||
mul_tensor.data = {0.5f, 2.0f};
|
||||
MultiplyScalarAttributes mul_attr;
|
||||
MultiplyAttributes mul_attr;
|
||||
mul_attr.param = mul_tensor;
|
||||
|
||||
FuseConvolution2DWithMultiply(mul_attr, &attr);
|
||||
@ -161,7 +161,7 @@ TEST(FuseMulAfterDepthwiseConvolution2DTest, Smoke) {
|
||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||
mul_tensor.shape = Linear(4);
|
||||
mul_tensor.data = {0.5f, 2.0f, 4.0f, 0.25f};
|
||||
MultiplyScalarAttributes mul_attr;
|
||||
MultiplyAttributes mul_attr;
|
||||
mul_attr.param = mul_tensor;
|
||||
|
||||
FuseDepthwiseConvolution2DWithMultiply(mul_attr, &attr);
|
||||
@ -183,7 +183,7 @@ TEST(FuseMulAfterConvolutionTransposedTest, Smoke) {
|
||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||
mul_tensor.shape = Linear(2);
|
||||
mul_tensor.data = {0.5f, 2.0f};
|
||||
MultiplyScalarAttributes mul_attr;
|
||||
MultiplyAttributes mul_attr;
|
||||
mul_attr.param = mul_tensor;
|
||||
|
||||
FuseConvolutionTransposedWithMultiply(mul_attr, &attr);
|
||||
@ -204,7 +204,7 @@ TEST(FuseMulAfterFullyConnectedTest, Smoke) {
|
||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||
mul_tensor.shape = Linear(2);
|
||||
mul_tensor.data = {0.5f, 2.0f};
|
||||
MultiplyScalarAttributes mul_attr;
|
||||
MultiplyAttributes mul_attr;
|
||||
mul_attr.param = mul_tensor;
|
||||
|
||||
FuseFullyConnectedWithMultiply(mul_attr, &attr);
|
||||
@ -224,7 +224,7 @@ TEST(FuseMulBeforeConvolution2DTest, Smoke) {
|
||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||
mul_tensor.shape = Linear(2);
|
||||
mul_tensor.data = {0.5f, 2.0f};
|
||||
MultiplyScalarAttributes mul_attr;
|
||||
MultiplyAttributes mul_attr;
|
||||
mul_attr.param = mul_tensor;
|
||||
|
||||
FuseMultiplyWithConvolution2D(mul_attr, &attr);
|
||||
@ -245,7 +245,7 @@ TEST(FuseMulBeforeDepthwiseConvolution2DTest, Smoke) {
|
||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||
mul_tensor.shape = Linear(4);
|
||||
mul_tensor.data = {0.5f, 2.0f, 4.0f, 0.25f};
|
||||
MultiplyScalarAttributes mul_attr;
|
||||
MultiplyAttributes mul_attr;
|
||||
mul_attr.param = mul_tensor;
|
||||
|
||||
FuseMultiplyWithDepthwiseConvolution2D(mul_attr, &attr);
|
||||
@ -267,7 +267,7 @@ TEST(FuseMulBeforeConvolutionTransposedTest, Smoke) {
|
||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||
mul_tensor.shape = Linear(2);
|
||||
mul_tensor.data = {0.5f, 2.0f};
|
||||
MultiplyScalarAttributes mul_attr;
|
||||
MultiplyAttributes mul_attr;
|
||||
mul_attr.param = mul_tensor;
|
||||
|
||||
FuseMultiplyWithConvolutionTransposed(mul_attr, &attr);
|
||||
@ -288,7 +288,7 @@ TEST(FuseMulBeforeFullyConnectedTest, Smoke) {
|
||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||
mul_tensor.shape = Linear(2);
|
||||
mul_tensor.data = {0.5f, 2.0f};
|
||||
MultiplyScalarAttributes mul_attr;
|
||||
MultiplyAttributes mul_attr;
|
||||
mul_attr.param = mul_tensor;
|
||||
|
||||
FuseMultiplyWithFullyConnected(mul_attr, &attr);
|
||||
|
@ -29,115 +29,116 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace gl {
|
||||
|
||||
namespace {
|
||||
|
||||
class ApplyMask : public NodeShader {
|
||||
public:
|
||||
static bool IsSupported(const GenerationContext& ctx) {
|
||||
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||
if (inputs.size() != 2) return false;
|
||||
const auto& shape0 = inputs[0]->tensor.shape;
|
||||
const auto& shape1 = inputs[1]->tensor.shape;
|
||||
bool IsApplyMaskSupported(const NodeShader::GenerationContext& ctx) {
|
||||
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||
if (inputs.size() != 2) return false;
|
||||
const auto& shape0 = inputs[0]->tensor.shape;
|
||||
const auto& shape1 = inputs[1]->tensor.shape;
|
||||
|
||||
// [H, W, C] x [H, W, 0][0]
|
||||
if (shape1.c == 1) return true;
|
||||
|
||||
if (shape0.c != shape1.c) return false;
|
||||
|
||||
// [H, W, C] x [H, W, C]
|
||||
if (shape0.h == shape1.h && shape0.w == shape1.w) return true;
|
||||
|
||||
// [H, W, C] x [0, 0, C]
|
||||
return shape1.h == 1 && shape1.w == 1;
|
||||
// [H, W, C] x [H, W, 0][0]
|
||||
if (shape0.h == shape1.h && shape0.w == shape1.w && shape1.c == 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
Status GenerateCode(const GenerationContext& ctx,
|
||||
GeneratedCode* generated_code) const final {
|
||||
if (!IsSupported(ctx)) {
|
||||
return InvalidArgumentError(
|
||||
"This case is not supported by apply mask operation");
|
||||
}
|
||||
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||
const auto& shape0 = inputs[0]->tensor.shape;
|
||||
const auto& shape1 = inputs[1]->tensor.shape;
|
||||
// [H, W, C] x [H, W, C]
|
||||
if (shape0 == shape1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * ";
|
||||
if (shape1.c == 1) {
|
||||
// [H, W, C] x [H, W, 0][0]
|
||||
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;");
|
||||
} else if (shape0.h == shape1.h && shape0.w == shape1.w) {
|
||||
// [H, W, C] x [H, W, C]
|
||||
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;");
|
||||
} else {
|
||||
// [H, W, C] x [0, 0, C]
|
||||
absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;");
|
||||
}
|
||||
// [H, W, C] x [0, 0, C]
|
||||
return shape1.h == 1 && shape1.w == 1 && shape0.c == shape1.c;
|
||||
}
|
||||
|
||||
Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx,
|
||||
GeneratedCode* generated_code) {
|
||||
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||
const auto& shape0 = inputs[0]->tensor.shape;
|
||||
const auto& shape1 = inputs[1]->tensor.shape;
|
||||
|
||||
std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * ";
|
||||
if (shape1.c == 1) {
|
||||
// [H, W, C] x [H, W, 0][0]
|
||||
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;");
|
||||
} else if (shape0.h == shape1.h && shape0.w == shape1.w) {
|
||||
// [H, W, C] x [H, W, C]
|
||||
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;");
|
||||
} else {
|
||||
// [H, W, C] x [0, 0, C]
|
||||
absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;");
|
||||
}
|
||||
|
||||
*generated_code = {
|
||||
/*parameters=*/{},
|
||||
/*objects=*/{},
|
||||
/*shared_variables=*/{},
|
||||
/*workload=*/uint3(),
|
||||
/*workgroup=*/uint3(),
|
||||
/*source_code=*/std::move(source),
|
||||
/*input=*/IOStructure::ONLY_DEFINITIONS,
|
||||
/*output=*/IOStructure::AUTO,
|
||||
};
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
Status GenerateMultiplyScalarCode(const NodeShader::GenerationContext& ctx,
|
||||
GeneratedCode* generated_code) {
|
||||
auto attr =
|
||||
absl::any_cast<MultiplyAttributes>(ctx.node->operation.attributes);
|
||||
auto muls = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
|
||||
auto scalar = absl::get_if<float>(&attr.param);
|
||||
|
||||
if (scalar) {
|
||||
*generated_code = {
|
||||
/*parameters=*/{},
|
||||
/*parameters=*/{{"scalar", *scalar}},
|
||||
/*objects=*/{},
|
||||
/*shared_variables=*/{},
|
||||
/*workload=*/uint3(),
|
||||
/*workgroup=*/uint3(),
|
||||
/*source_code=*/std::move(source),
|
||||
/*input=*/IOStructure::ONLY_DEFINITIONS,
|
||||
/*source_code=*/"value_0 *= $scalar$;",
|
||||
/*input=*/IOStructure::AUTO,
|
||||
/*output=*/IOStructure::AUTO,
|
||||
};
|
||||
} else {
|
||||
if (!muls) {
|
||||
return InvalidArgumentError("Empty parameters for Multiplication.");
|
||||
}
|
||||
auto shape = ctx.graph->FindInputs(ctx.node->id)[0]->tensor.shape;
|
||||
*generated_code = {
|
||||
/*parameters=*/{},
|
||||
/*objects=*/{{"mul_buffer", MakeReadonlyObject(muls->data)}},
|
||||
/*shared_variables=*/{},
|
||||
// Declare workload explicitly because shader depends on gid.z.
|
||||
/*workload=*/
|
||||
uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)),
|
||||
/*workgroup=*/uint3(),
|
||||
/*source_code=*/"value_0 *= $mul_buffer[gid.z]$;",
|
||||
/*input=*/IOStructure::AUTO,
|
||||
/*output=*/IOStructure::AUTO,
|
||||
};
|
||||
return OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
class MultiplyScalar : public NodeShader {
|
||||
return OkStatus();
|
||||
}
|
||||
|
||||
class Multiply : public NodeShader {
|
||||
public:
|
||||
Status GenerateCode(const GenerationContext& ctx,
|
||||
GeneratedCode* generated_code) const final {
|
||||
auto attr = absl::any_cast<MultiplyScalarAttributes>(
|
||||
ctx.node->operation.attributes);
|
||||
auto muls = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
|
||||
auto scalar = absl::get_if<float>(&attr.param);
|
||||
|
||||
if (scalar) {
|
||||
*generated_code = {
|
||||
/*parameters=*/{{"scalar", *scalar}},
|
||||
/*objects=*/{},
|
||||
/*shared_variables=*/{},
|
||||
/*workload=*/uint3(),
|
||||
/*workgroup=*/uint3(),
|
||||
/*source_code=*/"value_0 *= $scalar$;",
|
||||
/*input=*/IOStructure::AUTO,
|
||||
/*output=*/IOStructure::AUTO,
|
||||
};
|
||||
if (IsApplyMaskSupported(ctx)) {
|
||||
return GenerateApplyMaskCode(ctx, generated_code);
|
||||
} else {
|
||||
if (!muls) {
|
||||
return InvalidArgumentError("Empty parameters for Multiplication.");
|
||||
}
|
||||
auto shape = ctx.graph->FindInputs(ctx.node->id)[0]->tensor.shape;
|
||||
*generated_code = {
|
||||
/*parameters=*/{},
|
||||
/*objects=*/{{"mul_buffer", MakeReadonlyObject(muls->data)}},
|
||||
/*shared_variables=*/{},
|
||||
// Declare workload explicitly because shader depends on gid.z.
|
||||
/*workload=*/
|
||||
uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)),
|
||||
/*workgroup=*/uint3(),
|
||||
/*source_code=*/"value_0 *= $mul_buffer[gid.z]$;",
|
||||
/*input=*/IOStructure::AUTO,
|
||||
/*output=*/IOStructure::AUTO,
|
||||
};
|
||||
return GenerateMultiplyScalarCode(ctx, generated_code);
|
||||
}
|
||||
|
||||
return OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<NodeShader> NewApplyMaskNodeShader() {
|
||||
return absl::make_unique<ApplyMask>();
|
||||
}
|
||||
|
||||
std::unique_ptr<NodeShader> NewMultiplyScalarNodeShader() {
|
||||
return absl::make_unique<MultiplyScalar>();
|
||||
std::unique_ptr<NodeShader> NewMultiplyNodeShader() {
|
||||
return absl::make_unique<Multiply>();
|
||||
}
|
||||
|
||||
} // namespace gl
|
||||
|
@ -25,9 +25,7 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace gl {
|
||||
|
||||
std::unique_ptr<NodeShader> NewApplyMaskNodeShader();
|
||||
|
||||
std::unique_ptr<NodeShader> NewMultiplyScalarNodeShader();
|
||||
std::unique_ptr<NodeShader> NewMultiplyNodeShader();
|
||||
|
||||
} // namespace gl
|
||||
} // namespace gpu
|
||||
|
@ -41,13 +41,12 @@ TEST(MulTest, Scalar) {
  output.ref = 1;
  output.shape = BHWC(1, 2, 2, 1);

  MultiplyScalarAttributes attr;
  MultiplyAttributes attr;
  attr.param = 2.f;

  // TODO(eignasheva): change to MULTIPLY_SCALAR
  SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
  ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
  ASSERT_OK(model.Invoke(*NewMultiplyScalarNodeShader()));
  ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
  EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 6, 8}));
}

@ -62,21 +61,20 @@ TEST(MulTest, Linear) {
  output.ref = 1;
  output.shape = BHWC(1, 1, 2, 2);

  MultiplyScalarAttributes attr;
  MultiplyAttributes attr;
  Tensor<Linear, DataType::FLOAT32> tensor;
  tensor.shape.v = 2;
  tensor.id = 1;
  tensor.data = {2, 3};
  attr.param = std::move(tensor);

  // TODO(eignasheva): change to MULTIPLY_SCALAR
  SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
  ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
  ASSERT_OK(model.Invoke(*NewMultiplyScalarNodeShader()));
  ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
  EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 6, 6, 12}));
}

TEST(ApplyMaskTest, MaskChannel1) {
TEST(MulTest, MaskChannel1) {
  TensorRef<BHWC> input;
  input.type = DataType::FLOAT32;
  input.ref = 0;
@ -92,15 +90,15 @@ TEST(ApplyMaskTest, MaskChannel1) {
  output.ref = 2;
  output.shape = BHWC(1, 1, 2, 2);

  SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask},
  SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask},
                      {output});
  ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
  ASSERT_TRUE(model.PopulateTensor(1, {2, 3}));
  ASSERT_OK(model.Invoke(*NewApplyMaskNodeShader()));
  ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
  EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 9, 12}));
}

TEST(ApplyMaskTest, MaskChannelEqualsToInputChannel) {
TEST(MulTest, MaskChannelEqualsToInputChannel) {
  TensorRef<BHWC> input;
  input.type = DataType::FLOAT32;
  input.ref = 0;
@ -116,11 +114,11 @@ TEST(ApplyMaskTest, MaskChannelEqualsToInputChannel) {
  output.ref = 2;
  output.shape = BHWC(1, 1, 2, 2);

  SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask},
  SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask},
                      {output});
  ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
  ASSERT_TRUE(model.PopulateTensor(1, {1, 2, 3, 4}));
  ASSERT_OK(model.Invoke(*NewApplyMaskNodeShader()));
  ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
  EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1, 4, 9, 16}));
}

@ -71,7 +71,6 @@ class Registry : public NodeShader {
  };

  insert_op(Type::ADD, NewAddNodeShader);
  insert_op(Type::APPLY_MASK, NewApplyMaskNodeShader);
  insert_op(Type::CONCAT, NewAlignedConcatNodeShader);
  insert_op(Type::CONCAT, NewFlatConcatNodeShader);
  insert_op(Type::CONCAT, NewConcatNodeShader);
@ -82,7 +81,7 @@ class Registry : public NodeShader {
  insert_op(Type::FULLY_CONNECTED, NewFullyConnectedNodeShader);
  insert_op(Type::LSTM, NewLstmNodeShader);
  insert_op(Type::MEAN, NewMeanNodeShader);
  insert_op(Type::MULTIPLY_SCALAR, NewMultiplyScalarNodeShader);
  insert_op(Type::MUL, NewMultiplyNodeShader);
  insert_op(Type::PAD, NewPadNodeShader);
  insert_op(Type::POOLING_2D, NewPoolingNodeShader);
  insert_op(Type::PRELU, NewPReLUNodeShader);

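With APPLY_MASK and MULTIPLY_SCALAR gone from the GL registry, both multiplication flavours resolve to NewMultiplyNodeShader. The sketch below shows the former apply-mask pattern under the new op type; the input and mask TensorRef setup is an assumption of this note (those lines are elided in the test hunks above) and is modeled on the MaskChannel1 expectations.

// Hedged sketch: attribute-free MUL with a second runtime input (the mask).
TensorRef<BHWC> input;
input.type = DataType::FLOAT32;
input.ref = 0;
input.shape = BHWC(1, 1, 2, 2);  // assumed; output.shape above is 1x1x2x2

TensorRef<BHWC> mask;
mask.type = DataType::FLOAT32;
mask.ref = 1;
mask.shape = BHWC(1, 1, 2, 1);  // assumed single-channel mask

TensorRef<BHWC> output;
output.type = DataType::FLOAT32;
output.ref = 2;
output.shape = BHWC(1, 1, 2, 2);

// No attributes: the shader selects the mask path from the second input.
SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output});
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
ASSERT_TRUE(model.PopulateTensor(1, {2, 3}));
ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 9, 12}));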
@ -199,11 +199,15 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
      *tasks = Mean(node_id, inputs[0], outputs[0],
                    absl::any_cast<MeanAttributes>(node->operation.attributes));
      break;
    case OperationType::MULTIPLY_SCALAR:
      *tasks = Multiply(
          node_id, inputs[0], outputs[0],
          absl::any_cast<MultiplyScalarAttributes>(node->operation.attributes),
          options);
    case OperationType::MUL:
      if (node->operation.attributes.has_value()) {
        *tasks = Multiply(
            node_id, inputs[0], outputs[0],
            absl::any_cast<MultiplyAttributes>(node->operation.attributes),
            options);
      } else {
        *tasks = ApplyMask(node_id, inputs[0], inputs[1], outputs[0], options);
      }
      break;
    case OperationType::PAD: {
      auto attr = absl::any_cast<PadAttributes>(node->operation.attributes);
@ -268,12 +272,10 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
    case OperationType::SQUARED_DIFF:
      *tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type);
      break;
    case OperationType::APPLY_MASK:
    case OperationType::BATCH_NORMALIZATION:
    case OperationType::BATCH_TO_SPACE:
    case OperationType::CONST:
    case OperationType::LSTM:
    case OperationType::MUL:
    case OperationType::SPACE_TO_BATCH:
    case OperationType::TRANSPOSE:
    case OperationType::UNKNOWN:

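On the Metal side, the dispatch above keys off whether attributes were attached to the MUL node. A hedged sketch of the graph-side setup that drives each branch: node->operation.attributes appears in the diff, while the remaining field and helper names (operation.type, GraphFloat32::NewNode) are assumptions of this note.

// Branch 1: attributes present -> Multiply(...) compute task.
Node* scaled = graph.NewNode();  // graph is a GraphFloat32 (assumed helper API)
scaled->operation.type = ToString(OperationType::MUL);
MultiplyAttributes attr;
attr.param = 2.0f;  // or a Tensor<Linear, DataType::FLOAT32> for per-channel
scaled->operation.attributes = attr;

// Branch 2: no attributes, two runtime inputs -> ApplyMask(...) compute task.
Node* masked = graph.NewNode();
masked->operation.type = ToString(OperationType::MUL);
// masked->operation.attributes stays empty, so has_value() is false.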
@ -128,9 +128,10 @@ std::vector<ComputeTaskDescriptorPtr> ApplyMask(int id, ValueId input_id_0,
  return {desc};
}

std::vector<ComputeTaskDescriptorPtr> Multiply(
    int id, ValueId input_id, ValueId output_id,
    const MultiplyScalarAttributes& attr, const RuntimeOptions& options) {
std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
                                               ValueId output_id,
                                               const MultiplyAttributes& attr,
                                               const RuntimeOptions& options) {
  auto desc = std::make_shared<ComputeTaskDescriptor>();
  desc->id = id;
  desc->is_linkable = true;

@ -26,9 +26,10 @@ namespace gpu {
namespace metal {

// Multiply operation, supports scalar and vector broadcast.
std::vector<ComputeTaskDescriptorPtr> Multiply(
    int id, ValueId input_id, ValueId output_id,
    const MultiplyScalarAttributes& attr, const RuntimeOptions& options);
std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
                                               ValueId output_id,
                                               const MultiplyAttributes& attr,
                                               const RuntimeOptions& options);

std::vector<ComputeTaskDescriptorPtr> ApplyMask(int id, ValueId input_id_0,
                                                ValueId input_id_1,

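For Metal clients of the header above, only the attribute type in the Multiply signature changes. A minimal call sketch follows; the RuntimeOptions value is assumed to be default-constructible and acceptable as-is, since its fields are not part of this diff.

// Hedged usage of the renamed signature declared above.
MultiplyAttributes attr;
attr.param = 2.0f;  // scalar broadcast; a Tensor<Linear, DataType::FLOAT32> also works
RuntimeOptions options;  // assumption: defaults suffice for this sketch
std::vector<ComputeTaskDescriptorPtr> tasks =
    Multiply(/*id=*/1, /*input_id=*/0, /*output_id=*/1, attr, options);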
@ -31,7 +31,7 @@ limitations under the License.
using ::tflite::gpu::DataType;
using ::tflite::gpu::BHWC;
using ::tflite::gpu::Linear;
using ::tflite::gpu::MultiplyScalarAttributes;
using ::tflite::gpu::MultiplyAttributes;
using ::tflite::gpu::OperationType;
using ::tflite::gpu::Tensor;
using ::tflite::gpu::TensorRef;
@ -57,10 +57,10 @@ using ::tflite::gpu::metal::SingleOpModel;
  output.ref = 1;
  output.shape = BHWC(1, 2, 2, 1);

  MultiplyScalarAttributes attr;
  MultiplyAttributes attr;
  attr.param = 2;

  SingleOpModel model({ToString(OperationType::MULTIPLY_SCALAR), attr}, {input}, {output});
  SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
  XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
@ -79,14 +79,14 @@ using ::tflite::gpu::metal::SingleOpModel;
  output.ref = 1;
  output.shape = BHWC(1, 1, 2, 2);

  MultiplyScalarAttributes attr;
  MultiplyAttributes attr;
  Tensor<Linear, DataType::FLOAT32> tensor;
  tensor.shape.v = 2;
  tensor.id = 1;
  tensor.data = {2, 3};
  attr.param = std::move(tensor);

  SingleOpModel model({ToString(OperationType::MULTIPLY_SCALAR), attr}, {input}, {output});
  SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
  XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
@ -111,7 +111,7 @@ using ::tflite::gpu::metal::SingleOpModel;
  output.ref = 2;
  output.shape = BHWC(1, 1, 2, 2);

  SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask}, {output});
  SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output});
  XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
  XCTAssertTrue(model.PopulateTensor(1, {2, 3}));
  auto status = model.Invoke();
@ -136,13 +136,12 @@ using ::tflite::gpu::metal::SingleOpModel;
  output.ref = 2;
  output.shape = BHWC(1, 1, 2, 2);

  SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask}, {output});
  SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output});
  XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
  XCTAssertTrue(model.PopulateTensor(1, {1, 2, 3, 4}));
  auto status = model.Invoke();
  XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
  // Disable test for now.
  // status = CompareVectors({1, 4, 9, 16}, model.GetOutput(0), 1e-6f);
  status = CompareVectors({1, 4, 9, 16}, model.GetOutput(0), 1e-6f);
  XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
}