Remove APPLY_MASK and MULTIPLY_SCALAR enum values, use MUL instead.
Rename MultiplyScalarAttributes to MultiplyAttributes. PiperOrigin-RevId: 292228929 Change-Id: I459585d62e4161a93d6ec163bdbc0494c37eaf0e
This commit is contained in:
parent
abb256c1f8
commit
71c6f97e2d
@ -39,39 +39,6 @@ cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "apply_mask",
|
|
||||||
srcs = ["apply_mask.cc"],
|
|
||||||
hdrs = ["apply_mask.h"],
|
|
||||||
deps = [
|
|
||||||
":gpu_operation",
|
|
||||||
":util",
|
|
||||||
":work_group_picking",
|
|
||||||
"//tensorflow/lite/delegates/gpu/cl:cl_kernel",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:status",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:types",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_test(
|
|
||||||
name = "apply_mask_test",
|
|
||||||
srcs = ["apply_mask_test.cc"],
|
|
||||||
linkstatic = True,
|
|
||||||
tags = tf_gpu_tests_tags() + [
|
|
||||||
"linux",
|
|
||||||
"local",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
":apply_mask",
|
|
||||||
":cl_test",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:operations",
|
|
||||||
"//tensorflow/lite/delegates/gpu/common:status",
|
|
||||||
"@com_google_googletest//:gtest_main",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "cl_test",
|
name = "cl_test",
|
||||||
testonly = 1,
|
testonly = 1,
|
||||||
@ -1328,7 +1295,6 @@ test_suite(
|
|||||||
name = "all_tests",
|
name = "all_tests",
|
||||||
tests = [
|
tests = [
|
||||||
"add_test",
|
"add_test",
|
||||||
"apply_mask_test",
|
|
||||||
"concat_test",
|
"concat_test",
|
||||||
"conv_buffer_1x1_test",
|
"conv_buffer_1x1_test",
|
||||||
"conv_buffer_test",
|
"conv_buffer_test",
|
||||||
|
@ -1,103 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/apply_mask.h"
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "absl/strings/str_cat.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/util.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/work_group_picking.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
namespace gpu {
|
|
||||||
namespace cl {
|
|
||||||
|
|
||||||
ApplyMask::ApplyMask(ApplyMask&& operation)
|
|
||||||
: ElementwiseOperation(std::move(operation)),
|
|
||||||
mask_type_(operation.mask_type_),
|
|
||||||
link_index_(operation.link_index_) {}
|
|
||||||
|
|
||||||
ApplyMask& ApplyMask::operator=(ApplyMask&& operation) {
|
|
||||||
if (this != &operation) {
|
|
||||||
mask_type_ = operation.mask_type_;
|
|
||||||
link_index_ = operation.link_index_;
|
|
||||||
ElementwiseOperation::operator=(std::move(operation));
|
|
||||||
}
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
void ApplyMask::SetLinkIndex(int index) { link_index_ = index; }
|
|
||||||
|
|
||||||
std::string ApplyMask::GetCoreCode(const LinkingContext& context) const {
|
|
||||||
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
|
|
||||||
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
|
|
||||||
TensorCodeGenerator mask(
|
|
||||||
tensor_name,
|
|
||||||
WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
|
|
||||||
definition_.src_tensors[1]);
|
|
||||||
switch (mask_type_) {
|
|
||||||
case MaskType::TENSOR:
|
|
||||||
return context.var_name + " *= " +
|
|
||||||
mask.ReadWHS(context.x_coord, context.y_coord, context.s_coord) +
|
|
||||||
";\n";
|
|
||||||
case MaskType::CHANNELS:
|
|
||||||
return context.var_name +
|
|
||||||
" *= " + mask.ReadWHS("0", "0", context.s_coord) + ";\n";
|
|
||||||
case MaskType::LAYER:
|
|
||||||
return context.var_name +
|
|
||||||
" *= " + mask.ReadWHS(context.x_coord, context.y_coord, "0") +
|
|
||||||
".x;\n";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string ApplyMask::GetArgsDeclaration() const {
|
|
||||||
std::string args;
|
|
||||||
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
|
|
||||||
absl::StrAppend(&args, ",\n",
|
|
||||||
GetTensorDeclaration(AccessType::READ, tensor_name,
|
|
||||||
definition_.src_tensors[1]));
|
|
||||||
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
|
|
||||||
absl::StrAppend(&args, ",\n int4 ", size_name);
|
|
||||||
return args;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status ApplyMask::BindArguments(CLKernel* kernel) {
|
|
||||||
RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr()));
|
|
||||||
RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB()));
|
|
||||||
return OkStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
|
|
||||||
const BHWC& mask_shape) {
|
|
||||||
ApplyMask::MaskType mask_type;
|
|
||||||
if (mask_shape == src_shape) {
|
|
||||||
mask_type = ApplyMask::MaskType::TENSOR;
|
|
||||||
} else if (mask_shape.c == 1) {
|
|
||||||
mask_type = ApplyMask::MaskType::LAYER;
|
|
||||||
} else {
|
|
||||||
mask_type = ApplyMask::MaskType::CHANNELS;
|
|
||||||
}
|
|
||||||
ApplyMask operation(definition, mask_type);
|
|
||||||
operation.SetLinkIndex(0);
|
|
||||||
return operation;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace cl
|
|
||||||
} // namespace gpu
|
|
||||||
} // namespace tflite
|
|
@ -1,63 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_APPLY_MASK_H_
|
|
||||||
#define TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_APPLY_MASK_H_
|
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/cl_kernel.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/types.h"
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
namespace gpu {
|
|
||||||
namespace cl {
|
|
||||||
|
|
||||||
class ApplyMask : public ElementwiseOperation {
|
|
||||||
public:
|
|
||||||
// Move only
|
|
||||||
ApplyMask(ApplyMask&& operation);
|
|
||||||
ApplyMask& operator=(ApplyMask&& operation);
|
|
||||||
ApplyMask(const ApplyMask&) = delete;
|
|
||||||
ApplyMask& operator=(const ApplyMask&) = delete;
|
|
||||||
|
|
||||||
void SetLinkIndex(int index) override;
|
|
||||||
std::string GetCoreCode(const LinkingContext& context) const override;
|
|
||||||
std::string GetArgsDeclaration() const override;
|
|
||||||
Status BindArguments(CLKernel* kernel) override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend ApplyMask CreateApplyMask(const OperationDef& definition,
|
|
||||||
const BHWC& src_shape,
|
|
||||||
const BHWC& mask_shape);
|
|
||||||
|
|
||||||
enum class MaskType { LAYER, CHANNELS, TENSOR };
|
|
||||||
|
|
||||||
explicit ApplyMask(const OperationDef& definition, MaskType mask_type)
|
|
||||||
: ElementwiseOperation(definition), mask_type_(mask_type) {}
|
|
||||||
|
|
||||||
MaskType mask_type_;
|
|
||||||
int link_index_;
|
|
||||||
};
|
|
||||||
|
|
||||||
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
|
|
||||||
const BHWC& mask_shape);
|
|
||||||
|
|
||||||
} // namespace cl
|
|
||||||
} // namespace gpu
|
|
||||||
} // namespace tflite
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_DELEGATES_GPU_CL_KERNELS_APPLY_MASK_H_
|
|
@ -1,127 +0,0 @@
|
|||||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
|
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/apply_mask.h"
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include <gmock/gmock.h>
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/cl_test.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
|
||||||
|
|
||||||
using ::testing::FloatNear;
|
|
||||||
using ::testing::Pointwise;
|
|
||||||
|
|
||||||
namespace tflite {
|
|
||||||
namespace gpu {
|
|
||||||
namespace cl {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
TEST_F(OpenCLOperationTest, ApplyMaskOneChannel) {
|
|
||||||
TensorFloat32 src_tensor;
|
|
||||||
src_tensor.shape = BHWC(1, 2, 2, 2);
|
|
||||||
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
|
|
||||||
TensorFloat32 mask_tensor;
|
|
||||||
mask_tensor.shape = BHWC(1, 2, 2, 1);
|
|
||||||
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f};
|
|
||||||
|
|
||||||
for (auto storage : env_.GetSupportedStorages()) {
|
|
||||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
|
||||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
|
||||||
OperationDef op_def;
|
|
||||||
op_def.precision = precision;
|
|
||||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
|
||||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
|
||||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
|
||||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
|
||||||
TensorFloat32 dst_tensor;
|
|
||||||
ApplyMask operation =
|
|
||||||
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
|
|
||||||
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
|
|
||||||
creation_context_, &operation,
|
|
||||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
|
||||||
EXPECT_THAT(dst_tensor.data,
|
|
||||||
Pointwise(FloatNear(eps), {-8.0f, -6.0f, -0.5f, 0.0f, 1.0f,
|
|
||||||
3.0f, 0.0f, 0.0f}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(OpenCLOperationTest, ApplyMaskEqualSizes) {
|
|
||||||
TensorFloat32 src_tensor;
|
|
||||||
src_tensor.shape = BHWC(1, 2, 2, 2);
|
|
||||||
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
|
|
||||||
TensorFloat32 mask_tensor;
|
|
||||||
mask_tensor.shape = BHWC(1, 2, 2, 2);
|
|
||||||
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f, 2.0f, 0.5f, 1.0f, 0.0f};
|
|
||||||
|
|
||||||
for (auto storage : env_.GetSupportedStorages()) {
|
|
||||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
|
||||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
|
||||||
OperationDef op_def;
|
|
||||||
op_def.precision = precision;
|
|
||||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
|
||||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
|
||||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
|
||||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
|
||||||
TensorFloat32 dst_tensor;
|
|
||||||
ApplyMask operation =
|
|
||||||
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
|
|
||||||
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
|
|
||||||
creation_context_, &operation,
|
|
||||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
|
||||||
EXPECT_THAT(dst_tensor.data,
|
|
||||||
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -1.0f, 0.0f, 2.0f,
|
|
||||||
1.5f, 4.0f, 0.0f}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(OpenCLOperationTest, ApplyMaskVector) {
|
|
||||||
TensorFloat32 src_tensor;
|
|
||||||
src_tensor.shape = BHWC(1, 2, 2, 2);
|
|
||||||
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
|
|
||||||
TensorFloat32 mask_tensor;
|
|
||||||
mask_tensor.shape = BHWC(1, 1, 1, 2);
|
|
||||||
mask_tensor.data = {2.0f, 0.5f};
|
|
||||||
|
|
||||||
for (auto storage : env_.GetSupportedStorages()) {
|
|
||||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
|
||||||
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
|
||||||
OperationDef op_def;
|
|
||||||
op_def.precision = precision;
|
|
||||||
auto data_type = DeduceDataTypeFromPrecision(precision);
|
|
||||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
|
||||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
|
||||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
|
||||||
TensorFloat32 dst_tensor;
|
|
||||||
ApplyMask operation =
|
|
||||||
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
|
|
||||||
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
|
|
||||||
creation_context_, &operation,
|
|
||||||
BHWC(1, 2, 2, 2), &dst_tensor));
|
|
||||||
EXPECT_THAT(dst_tensor.data,
|
|
||||||
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -2.0f, 0.0f, 2.0f,
|
|
||||||
1.5f, 8.0f, 3.0f}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace cl
|
|
||||||
} // namespace gpu
|
|
||||||
} // namespace tflite
|
|
@ -105,7 +105,7 @@ Status MultiplyAdd::BindArguments(CLKernel* kernel) {
|
|||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MultiplyAdd::UploadMul(const MultiplyScalarAttributes& attr,
|
Status MultiplyAdd::UploadMul(const MultiplyAttributes& attr,
|
||||||
CalculationsPrecision scalar_precision,
|
CalculationsPrecision scalar_precision,
|
||||||
CLContext* context) {
|
CLContext* context) {
|
||||||
auto mul = absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
|
auto mul = absl::get_if<::tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(
|
||||||
@ -135,8 +135,7 @@ Status MultiplyAdd::UploadAdd(const AddAttributes& attr,
|
|||||||
|
|
||||||
Status CreateMultiplyAdd(const CreationContext& creation_context,
|
Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||||
const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
const MultiplyScalarAttributes& attr,
|
const MultiplyAttributes& attr, MultiplyAdd* result) {
|
||||||
MultiplyAdd* result) {
|
|
||||||
const auto scalar_precision = creation_context.device->IsPowerVR()
|
const auto scalar_precision = creation_context.device->IsPowerVR()
|
||||||
? CalculationsPrecision::F32
|
? CalculationsPrecision::F32
|
||||||
: definition.precision;
|
: definition.precision;
|
||||||
@ -162,7 +161,7 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
|
|||||||
|
|
||||||
Status CreateMultiplyAdd(const CreationContext& creation_context,
|
Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||||
const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
const MultiplyScalarAttributes& mul_attr,
|
const MultiplyAttributes& mul_attr,
|
||||||
const AddAttributes& add_attr, MultiplyAdd* result) {
|
const AddAttributes& add_attr, MultiplyAdd* result) {
|
||||||
const auto scalar_precision = creation_context.device->IsPowerVR()
|
const auto scalar_precision = creation_context.device->IsPowerVR()
|
||||||
? CalculationsPrecision::F32
|
? CalculationsPrecision::F32
|
||||||
@ -176,6 +175,76 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
|
|||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ApplyMask::ApplyMask(ApplyMask&& operation)
|
||||||
|
: ElementwiseOperation(std::move(operation)),
|
||||||
|
mask_type_(operation.mask_type_),
|
||||||
|
link_index_(operation.link_index_) {}
|
||||||
|
|
||||||
|
ApplyMask& ApplyMask::operator=(ApplyMask&& operation) {
|
||||||
|
if (this != &operation) {
|
||||||
|
mask_type_ = operation.mask_type_;
|
||||||
|
link_index_ = operation.link_index_;
|
||||||
|
ElementwiseOperation::operator=(std::move(operation));
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ApplyMask::SetLinkIndex(int index) { link_index_ = index; }
|
||||||
|
|
||||||
|
std::string ApplyMask::GetCoreCode(const LinkingContext& context) const {
|
||||||
|
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
|
||||||
|
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
|
||||||
|
TensorCodeGenerator mask(
|
||||||
|
tensor_name,
|
||||||
|
WHSPoint{size_name + ".x", size_name + ".y", size_name + ".z"},
|
||||||
|
definition_.src_tensors[1]);
|
||||||
|
switch (mask_type_) {
|
||||||
|
case MaskType::TENSOR:
|
||||||
|
return context.var_name + " *= " +
|
||||||
|
mask.ReadWHS(context.x_coord, context.y_coord, context.s_coord) +
|
||||||
|
";\n";
|
||||||
|
case MaskType::CHANNELS:
|
||||||
|
return context.var_name +
|
||||||
|
" *= " + mask.ReadWHS("0", "0", context.s_coord) + ";\n";
|
||||||
|
case MaskType::LAYER:
|
||||||
|
return context.var_name +
|
||||||
|
" *= " + mask.ReadWHS(context.x_coord, context.y_coord, "0") +
|
||||||
|
".x;\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string ApplyMask::GetArgsDeclaration() const {
|
||||||
|
std::string args;
|
||||||
|
const std::string tensor_name = absl::StrCat("mask_data_op", link_index_);
|
||||||
|
absl::StrAppend(&args, ",\n",
|
||||||
|
GetTensorDeclaration(AccessType::READ, tensor_name,
|
||||||
|
definition_.src_tensors[1]));
|
||||||
|
const std::string size_name = "mask_size_op" + std::to_string(link_index_);
|
||||||
|
absl::StrAppend(&args, ",\n int4 ", size_name);
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ApplyMask::BindArguments(CLKernel* kernel) {
|
||||||
|
RETURN_IF_ERROR(kernel->SetMemoryAuto(src_[1]->GetMemoryPtr()));
|
||||||
|
RETURN_IF_ERROR(kernel->SetBytesAuto(src_[1]->GetWBatchedHSB()));
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
|
||||||
|
const BHWC& mask_shape) {
|
||||||
|
ApplyMask::MaskType mask_type;
|
||||||
|
if (mask_shape == src_shape) {
|
||||||
|
mask_type = ApplyMask::MaskType::TENSOR;
|
||||||
|
} else if (mask_shape.c == 1) {
|
||||||
|
mask_type = ApplyMask::MaskType::LAYER;
|
||||||
|
} else {
|
||||||
|
mask_type = ApplyMask::MaskType::CHANNELS;
|
||||||
|
}
|
||||||
|
ApplyMask operation(definition, mask_type);
|
||||||
|
operation.SetLinkIndex(0);
|
||||||
|
return operation;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace cl
|
} // namespace cl
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -40,7 +40,7 @@ class MultiplyAdd : public ElementwiseOperation {
|
|||||||
MultiplyAdd(const MultiplyAdd&) = delete;
|
MultiplyAdd(const MultiplyAdd&) = delete;
|
||||||
MultiplyAdd& operator=(const MultiplyAdd&) = delete;
|
MultiplyAdd& operator=(const MultiplyAdd&) = delete;
|
||||||
|
|
||||||
Status UploadMul(const MultiplyScalarAttributes& attr,
|
Status UploadMul(const MultiplyAttributes& attr,
|
||||||
CalculationsPrecision scalar_precision, CLContext* context);
|
CalculationsPrecision scalar_precision, CLContext* context);
|
||||||
Status UploadAdd(const AddAttributes& attr,
|
Status UploadAdd(const AddAttributes& attr,
|
||||||
CalculationsPrecision scalar_precision, CLContext* context);
|
CalculationsPrecision scalar_precision, CLContext* context);
|
||||||
@ -61,7 +61,7 @@ class MultiplyAdd : public ElementwiseOperation {
|
|||||||
|
|
||||||
friend Status CreateMultiplyAdd(const CreationContext& creation_context,
|
friend Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||||
const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
const MultiplyScalarAttributes& attr,
|
const MultiplyAttributes& attr,
|
||||||
MultiplyAdd* result);
|
MultiplyAdd* result);
|
||||||
|
|
||||||
friend Status CreateMultiplyAdd(const CreationContext& creation_context,
|
friend Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||||
@ -71,7 +71,7 @@ class MultiplyAdd : public ElementwiseOperation {
|
|||||||
|
|
||||||
friend Status CreateMultiplyAdd(const CreationContext& creation_context,
|
friend Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||||
const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
const MultiplyScalarAttributes& mul_attr,
|
const MultiplyAttributes& mul_attr,
|
||||||
const AddAttributes& add_attr,
|
const AddAttributes& add_attr,
|
||||||
MultiplyAdd* result);
|
MultiplyAdd* result);
|
||||||
|
|
||||||
@ -91,8 +91,7 @@ class MultiplyAdd : public ElementwiseOperation {
|
|||||||
|
|
||||||
Status CreateMultiplyAdd(const CreationContext& creation_context,
|
Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||||
const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
const MultiplyScalarAttributes& attr,
|
const MultiplyAttributes& attr, MultiplyAdd* result);
|
||||||
MultiplyAdd* result);
|
|
||||||
|
|
||||||
Status CreateMultiplyAdd(const CreationContext& creation_context,
|
Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||||
const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
@ -100,7 +99,7 @@ Status CreateMultiplyAdd(const CreationContext& creation_context,
|
|||||||
|
|
||||||
Status CreateMultiplyAdd(const CreationContext& creation_context,
|
Status CreateMultiplyAdd(const CreationContext& creation_context,
|
||||||
const OperationDef& definition,
|
const OperationDef& definition,
|
||||||
const MultiplyScalarAttributes& mul_attr,
|
const MultiplyAttributes& mul_attr,
|
||||||
const AddAttributes& add_attr, MultiplyAdd* result);
|
const AddAttributes& add_attr, MultiplyAdd* result);
|
||||||
|
|
||||||
template <DataType T>
|
template <DataType T>
|
||||||
@ -127,6 +126,36 @@ Status MultiplyAdd::UploadAdd(const ::tflite::gpu::Tensor<Linear, T>& add,
|
|||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class ApplyMask : public ElementwiseOperation {
|
||||||
|
public:
|
||||||
|
// Move only
|
||||||
|
ApplyMask(ApplyMask&& operation);
|
||||||
|
ApplyMask& operator=(ApplyMask&& operation);
|
||||||
|
ApplyMask(const ApplyMask&) = delete;
|
||||||
|
ApplyMask& operator=(const ApplyMask&) = delete;
|
||||||
|
|
||||||
|
void SetLinkIndex(int index) override;
|
||||||
|
std::string GetCoreCode(const LinkingContext& context) const override;
|
||||||
|
std::string GetArgsDeclaration() const override;
|
||||||
|
Status BindArguments(CLKernel* kernel) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend ApplyMask CreateApplyMask(const OperationDef& definition,
|
||||||
|
const BHWC& src_shape,
|
||||||
|
const BHWC& mask_shape);
|
||||||
|
|
||||||
|
enum class MaskType { LAYER, CHANNELS, TENSOR };
|
||||||
|
|
||||||
|
explicit ApplyMask(const OperationDef& definition, MaskType mask_type)
|
||||||
|
: ElementwiseOperation(definition), mask_type_(mask_type) {}
|
||||||
|
|
||||||
|
MaskType mask_type_;
|
||||||
|
int link_index_;
|
||||||
|
};
|
||||||
|
|
||||||
|
ApplyMask CreateApplyMask(const OperationDef& definition, const BHWC& src_shape,
|
||||||
|
const BHWC& mask_shape);
|
||||||
|
|
||||||
} // namespace cl
|
} // namespace cl
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -37,7 +37,7 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMul) {
|
|||||||
src_tensor.shape = BHWC(1, 2, 1, 2);
|
src_tensor.shape = BHWC(1, 2, 1, 2);
|
||||||
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
|
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||||
|
|
||||||
MultiplyScalarAttributes attr;
|
MultiplyAttributes attr;
|
||||||
::tflite::gpu::Tensor<Linear, DataType::FLOAT32> parameters;
|
::tflite::gpu::Tensor<Linear, DataType::FLOAT32> parameters;
|
||||||
parameters.shape = Linear(2);
|
parameters.shape = Linear(2);
|
||||||
parameters.data = {0.5f, 2.0f};
|
parameters.data = {0.5f, 2.0f};
|
||||||
@ -97,7 +97,7 @@ TEST_F(OpenCLOperationTest, MultiplyAddScalarMul) {
|
|||||||
src_tensor.shape = BHWC(1, 2, 1, 2);
|
src_tensor.shape = BHWC(1, 2, 1, 2);
|
||||||
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
|
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||||
|
|
||||||
MultiplyScalarAttributes attr;
|
MultiplyAttributes attr;
|
||||||
attr.param = 0.5f;
|
attr.param = 0.5f;
|
||||||
|
|
||||||
for (auto storage : env_.GetSupportedStorages()) {
|
for (auto storage : env_.GetSupportedStorages()) {
|
||||||
@ -151,7 +151,7 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMad) {
|
|||||||
src_tensor.shape = BHWC(1, 2, 1, 2);
|
src_tensor.shape = BHWC(1, 2, 1, 2);
|
||||||
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
|
src_tensor.data = {0.0f, 1.0f, 2.0f, 3.0f};
|
||||||
|
|
||||||
MultiplyScalarAttributes mul_attr;
|
MultiplyAttributes mul_attr;
|
||||||
::tflite::gpu::Tensor<Linear, DataType::FLOAT32> parameters;
|
::tflite::gpu::Tensor<Linear, DataType::FLOAT32> parameters;
|
||||||
parameters.shape = Linear(2);
|
parameters.shape = Linear(2);
|
||||||
parameters.data = {0.5f, 2.0f};
|
parameters.data = {0.5f, 2.0f};
|
||||||
@ -181,6 +181,96 @@ TEST_F(OpenCLOperationTest, MultiplyAddVectorMad) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OpenCLOperationTest, ApplyMaskOneChannel) {
|
||||||
|
TensorFloat32 src_tensor;
|
||||||
|
src_tensor.shape = BHWC(1, 2, 2, 2);
|
||||||
|
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
|
||||||
|
TensorFloat32 mask_tensor;
|
||||||
|
mask_tensor.shape = BHWC(1, 2, 2, 1);
|
||||||
|
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f};
|
||||||
|
|
||||||
|
for (auto storage : env_.GetSupportedStorages()) {
|
||||||
|
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||||
|
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||||
|
OperationDef op_def;
|
||||||
|
op_def.precision = precision;
|
||||||
|
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||||
|
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||||
|
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||||
|
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||||
|
TensorFloat32 dst_tensor;
|
||||||
|
ApplyMask operation =
|
||||||
|
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
|
||||||
|
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
|
||||||
|
creation_context_, &operation,
|
||||||
|
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||||
|
EXPECT_THAT(dst_tensor.data,
|
||||||
|
Pointwise(FloatNear(eps), {-8.0f, -6.0f, -0.5f, 0.0f, 1.0f,
|
||||||
|
3.0f, 0.0f, 0.0f}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OpenCLOperationTest, ApplyMaskEqualSizes) {
|
||||||
|
TensorFloat32 src_tensor;
|
||||||
|
src_tensor.shape = BHWC(1, 2, 2, 2);
|
||||||
|
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
|
||||||
|
TensorFloat32 mask_tensor;
|
||||||
|
mask_tensor.shape = BHWC(1, 2, 2, 2);
|
||||||
|
mask_tensor.data = {2.0f, 0.5f, 1.0f, 0.0f, 2.0f, 0.5f, 1.0f, 0.0f};
|
||||||
|
|
||||||
|
for (auto storage : env_.GetSupportedStorages()) {
|
||||||
|
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||||
|
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||||
|
OperationDef op_def;
|
||||||
|
op_def.precision = precision;
|
||||||
|
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||||
|
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||||
|
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||||
|
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||||
|
TensorFloat32 dst_tensor;
|
||||||
|
ApplyMask operation =
|
||||||
|
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
|
||||||
|
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
|
||||||
|
creation_context_, &operation,
|
||||||
|
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||||
|
EXPECT_THAT(dst_tensor.data,
|
||||||
|
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -1.0f, 0.0f, 2.0f,
|
||||||
|
1.5f, 4.0f, 0.0f}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(OpenCLOperationTest, ApplyMaskVector) {
|
||||||
|
TensorFloat32 src_tensor;
|
||||||
|
src_tensor.shape = BHWC(1, 2, 2, 2);
|
||||||
|
src_tensor.data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 3.0f, 4.0f, 6.0f};
|
||||||
|
TensorFloat32 mask_tensor;
|
||||||
|
mask_tensor.shape = BHWC(1, 1, 1, 2);
|
||||||
|
mask_tensor.data = {2.0f, 0.5f};
|
||||||
|
|
||||||
|
for (auto storage : env_.GetSupportedStorages()) {
|
||||||
|
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||||
|
const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
|
||||||
|
OperationDef op_def;
|
||||||
|
op_def.precision = precision;
|
||||||
|
auto data_type = DeduceDataTypeFromPrecision(precision);
|
||||||
|
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||||
|
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||||
|
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||||
|
TensorFloat32 dst_tensor;
|
||||||
|
ApplyMask operation =
|
||||||
|
CreateApplyMask(op_def, src_tensor.shape, mask_tensor.shape);
|
||||||
|
ASSERT_OK(ExecuteGPUOperation({src_tensor, mask_tensor},
|
||||||
|
creation_context_, &operation,
|
||||||
|
BHWC(1, 2, 2, 2), &dst_tensor));
|
||||||
|
EXPECT_THAT(dst_tensor.data,
|
||||||
|
Pointwise(FloatNear(eps), {-8.0f, -1.5f, -2.0f, 0.0f, 2.0f,
|
||||||
|
1.5f, 8.0f, 3.0f}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace cl
|
} // namespace cl
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
@ -103,7 +103,6 @@ cc_library(
|
|||||||
hdrs = ["simple_selectors.h"],
|
hdrs = ["simple_selectors.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/lite/delegates/gpu/cl/kernels:add",
|
"//tensorflow/lite/delegates/gpu/cl/kernels:add",
|
||||||
"//tensorflow/lite/delegates/gpu/cl/kernels:apply_mask",
|
|
||||||
"//tensorflow/lite/delegates/gpu/cl/kernels:concat_xy",
|
"//tensorflow/lite/delegates/gpu/cl/kernels:concat_xy",
|
||||||
"//tensorflow/lite/delegates/gpu/cl/kernels:concat_z",
|
"//tensorflow/lite/delegates/gpu/cl/kernels:concat_z",
|
||||||
"//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
|
"//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
|
||||||
|
@ -68,11 +68,6 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
|
|||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case OperationType::APPLY_MASK: {
|
|
||||||
SelectApplyMask(op_def, inputs[0]->tensor.shape, inputs[1]->tensor.shape,
|
|
||||||
gpu_op);
|
|
||||||
return OkStatus();
|
|
||||||
}
|
|
||||||
case OperationType::CONCAT: {
|
case OperationType::CONCAT: {
|
||||||
auto attr = absl::any_cast<ConcatAttributes>(node.operation.attributes);
|
auto attr = absl::any_cast<ConcatAttributes>(node.operation.attributes);
|
||||||
std::vector<int> channels(inputs.size());
|
std::vector<int> channels(inputs.size());
|
||||||
@ -119,10 +114,17 @@ Status GPUOperationFromNode(const CreationContext& creation_context,
|
|||||||
auto attr = absl::any_cast<MeanAttributes>(node.operation.attributes);
|
auto attr = absl::any_cast<MeanAttributes>(node.operation.attributes);
|
||||||
return SelectMean(attr, op_def, gpu_op);
|
return SelectMean(attr, op_def, gpu_op);
|
||||||
}
|
}
|
||||||
case OperationType::MULTIPLY_SCALAR: {
|
case OperationType::MUL: {
|
||||||
auto attr =
|
if (node.operation.attributes.has_value()) {
|
||||||
absl::any_cast<MultiplyScalarAttributes>(node.operation.attributes);
|
auto attr =
|
||||||
return SelectMultiplyScalar(attr, creation_context, op_def, gpu_op);
|
absl::any_cast<MultiplyAttributes>(node.operation.attributes);
|
||||||
|
|
||||||
|
return SelectMultiplyScalar(attr, creation_context, op_def, gpu_op);
|
||||||
|
} else {
|
||||||
|
SelectApplyMask(op_def, inputs[0]->tensor.shape,
|
||||||
|
inputs[1]->tensor.shape, gpu_op);
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case OperationType::PAD: {
|
case OperationType::PAD: {
|
||||||
auto attr = absl::any_cast<PadAttributes>(node.operation.attributes);
|
auto attr = absl::any_cast<PadAttributes>(node.operation.attributes);
|
||||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/add.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/add.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/apply_mask.h"
|
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/concat_xy.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/concat_z.h"
|
||||||
#include "tensorflow/lite/delegates/gpu/cl/kernels/lstm.h"
|
#include "tensorflow/lite/delegates/gpu/cl/kernels/lstm.h"
|
||||||
@ -155,7 +154,7 @@ Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def,
|
|||||||
return OkStatus();
|
return OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SelectMultiplyScalar(const MultiplyScalarAttributes& attr,
|
Status SelectMultiplyScalar(const MultiplyAttributes& attr,
|
||||||
const CreationContext& creation_context,
|
const CreationContext& creation_context,
|
||||||
const OperationDef& op_def,
|
const OperationDef& op_def,
|
||||||
std::unique_ptr<GPUOperation>* ptr) {
|
std::unique_ptr<GPUOperation>* ptr) {
|
||||||
|
@ -73,7 +73,7 @@ void SelectStridedSlice(const SliceAttributes& attr, const OperationDef& op_def,
|
|||||||
Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def,
|
Status SelectMean(const MeanAttributes& attr, const OperationDef& op_def,
|
||||||
std::unique_ptr<GPUOperation>* ptr);
|
std::unique_ptr<GPUOperation>* ptr);
|
||||||
|
|
||||||
Status SelectMultiplyScalar(const MultiplyScalarAttributes& attr,
|
Status SelectMultiplyScalar(const MultiplyAttributes& attr,
|
||||||
const CreationContext& creation_context,
|
const CreationContext& creation_context,
|
||||||
const OperationDef& op_def,
|
const OperationDef& op_def,
|
||||||
std::unique_ptr<GPUOperation>* ptr);
|
std::unique_ptr<GPUOperation>* ptr);
|
||||||
|
@ -1395,8 +1395,11 @@ class MulOperationParser : public TFLiteOperationParser {
|
|||||||
const bool runtime_tensor0 = !constant_tensor0;
|
const bool runtime_tensor0 = !constant_tensor0;
|
||||||
const bool runtime_tensor1 = !constant_tensor1;
|
const bool runtime_tensor1 = !constant_tensor1;
|
||||||
|
|
||||||
// Parse for APPLY_MASK. The "larger" input tensor must be bound to 1st
|
Node* node = graph->NewNode();
|
||||||
// input and the "smaller" input tensor ("mask") must be bound to 2nd input.
|
node->operation.type = ToString(OperationType::MUL);
|
||||||
|
|
||||||
|
// The "larger" input tensor must be bound to 1st input and the "smaller"
|
||||||
|
// input tensor ("mask") must be bound to 2nd input.
|
||||||
if (runtime_tensor0 && runtime_tensor1) {
|
if (runtime_tensor0 && runtime_tensor1) {
|
||||||
BHWC shape0;
|
BHWC shape0;
|
||||||
RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
|
RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
|
||||||
@ -1409,11 +1412,11 @@ class MulOperationParser : public TFLiteOperationParser {
|
|||||||
input_tensor0 = 1;
|
input_tensor0 = 1;
|
||||||
input_tensor1 = 0;
|
input_tensor1 = 0;
|
||||||
}
|
}
|
||||||
return ParseApplyMask(input_tensor0, input_tensor1, graph, reader);
|
return ParseApplyMask(node, input_tensor0, input_tensor1, graph, reader);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse for MULTIPLY_SCALAR. The runtime input tensor must be bound to 1st
|
// The runtime input tensor must be bound to 1st input and the constant
|
||||||
// input and the constant input tensor must be bound to 2nd input.
|
// input tensor must be bound to 2nd input.
|
||||||
int runtime_tensor = 0;
|
int runtime_tensor = 0;
|
||||||
int constant_tensor = 1;
|
int constant_tensor = 1;
|
||||||
TfLiteIntArray* constant_dims = input1->dims;
|
TfLiteIntArray* constant_dims = input1->dims;
|
||||||
@ -1422,27 +1425,24 @@ class MulOperationParser : public TFLiteOperationParser {
|
|||||||
constant_tensor = 0;
|
constant_tensor = 0;
|
||||||
constant_dims = input0->dims;
|
constant_dims = input0->dims;
|
||||||
}
|
}
|
||||||
return ParseMultiplyScalar(runtime_tensor, constant_tensor, constant_dims,
|
return ParseMultiplyScalar(node, runtime_tensor, constant_tensor,
|
||||||
graph, reader);
|
constant_dims, graph, reader);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Status ParseApplyMask(int input_tensor0, int input_tensor1,
|
Status ParseApplyMask(Node* node, int input_tensor0, int input_tensor1,
|
||||||
GraphFloat32* graph, ObjectReader* reader) {
|
GraphFloat32* graph, ObjectReader* reader) {
|
||||||
Node* node = graph->NewNode();
|
|
||||||
node->operation.type = ToString(OperationType::APPLY_MASK);
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
|
RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
|
RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
|
||||||
return reader->AddOutputs(node);
|
return reader->AddOutputs(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ParseMultiplyScalar(int runtime_tensor, int constant_tensor,
|
Status ParseMultiplyScalar(Node* node, int runtime_tensor,
|
||||||
|
int constant_tensor,
|
||||||
const TfLiteIntArray* constant_dims,
|
const TfLiteIntArray* constant_dims,
|
||||||
GraphFloat32* graph, ObjectReader* reader) {
|
GraphFloat32* graph, ObjectReader* reader) {
|
||||||
Node* node = graph->NewNode();
|
|
||||||
node->operation.type = ToString(OperationType::MULTIPLY_SCALAR);
|
|
||||||
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
|
RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
|
||||||
MultiplyScalarAttributes attr;
|
MultiplyAttributes attr;
|
||||||
if (constant_dims->size <= 0) {
|
if (constant_dims->size <= 0) {
|
||||||
Tensor<Scalar, DataType::FLOAT32> tensor;
|
Tensor<Scalar, DataType::FLOAT32> tensor;
|
||||||
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
|
||||||
|
@ -72,8 +72,6 @@ std::string ToString(enum OperationType op) {
|
|||||||
return "abs";
|
return "abs";
|
||||||
case OperationType::ADD:
|
case OperationType::ADD:
|
||||||
return "add";
|
return "add";
|
||||||
case OperationType::APPLY_MASK:
|
|
||||||
return "apply_mask";
|
|
||||||
case OperationType::BATCH_NORMALIZATION:
|
case OperationType::BATCH_NORMALIZATION:
|
||||||
return "batch_normalization";
|
return "batch_normalization";
|
||||||
case OperationType::BATCH_TO_SPACE:
|
case OperationType::BATCH_TO_SPACE:
|
||||||
@ -106,8 +104,6 @@ std::string ToString(enum OperationType op) {
|
|||||||
return "mean";
|
return "mean";
|
||||||
case OperationType::MUL:
|
case OperationType::MUL:
|
||||||
return "mul";
|
return "mul";
|
||||||
case OperationType::MULTIPLY_SCALAR:
|
|
||||||
return "multiply_scalar";
|
|
||||||
case OperationType::PAD:
|
case OperationType::PAD:
|
||||||
return "pad";
|
return "pad";
|
||||||
case OperationType::POOLING_2D:
|
case OperationType::POOLING_2D:
|
||||||
@ -157,7 +153,6 @@ OperationType OperationTypeFromString(const std::string& name) {
|
|||||||
new std::unordered_map<std::string, OperationType>({
|
new std::unordered_map<std::string, OperationType>({
|
||||||
{"abs", OperationType::ABS},
|
{"abs", OperationType::ABS},
|
||||||
{"add", OperationType::ADD},
|
{"add", OperationType::ADD},
|
||||||
{"apply_mask", OperationType::APPLY_MASK},
|
|
||||||
{"batch_normalization", OperationType::BATCH_NORMALIZATION},
|
{"batch_normalization", OperationType::BATCH_NORMALIZATION},
|
||||||
{"concat", OperationType::CONCAT},
|
{"concat", OperationType::CONCAT},
|
||||||
{"const", OperationType::CONST},
|
{"const", OperationType::CONST},
|
||||||
@ -173,7 +168,6 @@ OperationType OperationTypeFromString(const std::string& name) {
|
|||||||
{"max_unpooling", OperationType::MAX_UNPOOLING_2D},
|
{"max_unpooling", OperationType::MAX_UNPOOLING_2D},
|
||||||
{"mean", OperationType::MEAN},
|
{"mean", OperationType::MEAN},
|
||||||
{"mul", OperationType::MUL},
|
{"mul", OperationType::MUL},
|
||||||
{"multiply_scalar", OperationType::MULTIPLY_SCALAR},
|
|
||||||
{"pad", OperationType::PAD},
|
{"pad", OperationType::PAD},
|
||||||
{"pooling_2d", OperationType::POOLING_2D},
|
{"pooling_2d", OperationType::POOLING_2D},
|
||||||
{"pow", OperationType::POW},
|
{"pow", OperationType::POW},
|
||||||
|
@ -34,8 +34,6 @@ enum class OperationType {
|
|||||||
UNKNOWN = 0,
|
UNKNOWN = 0,
|
||||||
ABS,
|
ABS,
|
||||||
ADD,
|
ADD,
|
||||||
// TODO(eignasheva): remove APPLY_MASK operation, is should be just MUL
|
|
||||||
APPLY_MASK,
|
|
||||||
BATCH_TO_SPACE,
|
BATCH_TO_SPACE,
|
||||||
BATCH_NORMALIZATION,
|
BATCH_NORMALIZATION,
|
||||||
CONCAT,
|
CONCAT,
|
||||||
@ -52,7 +50,6 @@ enum class OperationType {
|
|||||||
MAX_UNPOOLING_2D,
|
MAX_UNPOOLING_2D,
|
||||||
MEAN,
|
MEAN,
|
||||||
MUL,
|
MUL,
|
||||||
MULTIPLY_SCALAR,
|
|
||||||
PAD,
|
PAD,
|
||||||
POOLING_2D,
|
POOLING_2D,
|
||||||
POW,
|
POW,
|
||||||
@ -354,7 +351,7 @@ struct LstmAttributes {
|
|||||||
LstmKernelType kernel_type = LstmKernelType::BASIC;
|
LstmKernelType kernel_type = LstmKernelType::BASIC;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MultiplyScalarAttributes {
|
struct MultiplyAttributes {
|
||||||
absl::variant<absl::monostate, Tensor<Linear, DataType::FLOAT32>, float>
|
absl::variant<absl::monostate, Tensor<Linear, DataType::FLOAT32>, float>
|
||||||
param;
|
param;
|
||||||
};
|
};
|
||||||
|
@ -33,15 +33,14 @@ class MergeConvolutionWithMul : public SequenceTransformation {
|
|||||||
GraphFloat32* graph) final {
|
GraphFloat32* graph) final {
|
||||||
auto& conv_node = *sequence[0];
|
auto& conv_node = *sequence[0];
|
||||||
auto& mul_node = *sequence[1];
|
auto& mul_node = *sequence[1];
|
||||||
if (mul_node.operation.type != ToString(OperationType::MUL) &&
|
if (mul_node.operation.type != ToString(OperationType::MUL) ||
|
||||||
mul_node.operation.type != ToString(OperationType::MULTIPLY_SCALAR)) {
|
!mul_node.operation.attributes.has_value()) {
|
||||||
return {TransformStatus::SKIPPED, ""};
|
return {TransformStatus::SKIPPED, ""};
|
||||||
}
|
}
|
||||||
|
|
||||||
MultiplyScalarAttributes mul_attr =
|
MultiplyAttributes mul_attr =
|
||||||
absl::any_cast<MultiplyScalarAttributes>(mul_node.operation.attributes);
|
absl::any_cast<MultiplyAttributes>(mul_node.operation.attributes);
|
||||||
if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(
|
if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param) &&
|
||||||
&mul_attr.param) &&
|
|
||||||
!absl::get_if<float>(&mul_attr.param)) {
|
!absl::get_if<float>(&mul_attr.param)) {
|
||||||
return {
|
return {
|
||||||
TransformStatus::DECLINED,
|
TransformStatus::DECLINED,
|
||||||
@ -93,13 +92,13 @@ class MergeMulWithConvolution : public SequenceTransformation {
|
|||||||
GraphFloat32* graph) final {
|
GraphFloat32* graph) final {
|
||||||
auto& conv_node = *sequence[1];
|
auto& conv_node = *sequence[1];
|
||||||
auto& mul_node = *sequence[0];
|
auto& mul_node = *sequence[0];
|
||||||
if (mul_node.operation.type != ToString(OperationType::MUL) &&
|
if (mul_node.operation.type != ToString(OperationType::MUL) ||
|
||||||
mul_node.operation.type != ToString(OperationType::MULTIPLY_SCALAR)) {
|
!mul_node.operation.attributes.has_value()) {
|
||||||
return {TransformStatus::SKIPPED, ""};
|
return {TransformStatus::SKIPPED, ""};
|
||||||
}
|
}
|
||||||
|
|
||||||
MultiplyScalarAttributes mul_attr =
|
MultiplyAttributes mul_attr =
|
||||||
absl::any_cast<MultiplyScalarAttributes>(mul_node.operation.attributes);
|
absl::any_cast<MultiplyAttributes>(mul_node.operation.attributes);
|
||||||
if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(
|
if (!absl::get_if<Tensor<Linear, DataType::FLOAT32>>(
|
||||||
&mul_attr.param) &&
|
&mul_attr.param) &&
|
||||||
!absl::get_if<float>(&mul_attr.param)) {
|
!absl::get_if<float>(&mul_attr.param)) {
|
||||||
@ -155,7 +154,7 @@ std::unique_ptr<SequenceTransformation> NewMergeMulWithConvolution() {
|
|||||||
return absl::make_unique<MergeMulWithConvolution>();
|
return absl::make_unique<MergeMulWithConvolution>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr,
|
void FuseConvolution2DWithMultiply(const MultiplyAttributes& mul_attr,
|
||||||
Convolution2DAttributes* attr) {
|
Convolution2DAttributes* attr) {
|
||||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||||
@ -176,7 +175,7 @@ void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void FuseDepthwiseConvolution2DWithMultiply(
|
void FuseDepthwiseConvolution2DWithMultiply(
|
||||||
const MultiplyScalarAttributes& mul_attr,
|
const MultiplyAttributes& mul_attr,
|
||||||
DepthwiseConvolution2DAttributes* attr) {
|
DepthwiseConvolution2DAttributes* attr) {
|
||||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||||
@ -198,8 +197,7 @@ void FuseDepthwiseConvolution2DWithMultiply(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void FuseConvolutionTransposedWithMultiply(
|
void FuseConvolutionTransposedWithMultiply(
|
||||||
const MultiplyScalarAttributes& mul_attr,
|
const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr) {
|
||||||
ConvolutionTransposedAttributes* attr) {
|
|
||||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||||
for (int d = 0; d < attr->weights.shape.o; ++d) {
|
for (int d = 0; d < attr->weights.shape.o; ++d) {
|
||||||
@ -218,7 +216,7 @@ void FuseConvolutionTransposedWithMultiply(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr,
|
void FuseFullyConnectedWithMultiply(const MultiplyAttributes& mul_attr,
|
||||||
FullyConnectedAttributes* attr) {
|
FullyConnectedAttributes* attr) {
|
||||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||||
@ -234,7 +232,7 @@ void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr,
|
void FuseMultiplyWithConvolution2D(const MultiplyAttributes& mul_attr,
|
||||||
Convolution2DAttributes* attr) {
|
Convolution2DAttributes* attr) {
|
||||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||||
@ -252,7 +250,7 @@ void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void FuseMultiplyWithDepthwiseConvolution2D(
|
void FuseMultiplyWithDepthwiseConvolution2D(
|
||||||
const MultiplyScalarAttributes& mul_attr,
|
const MultiplyAttributes& mul_attr,
|
||||||
DepthwiseConvolution2DAttributes* attr) {
|
DepthwiseConvolution2DAttributes* attr) {
|
||||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||||
@ -270,8 +268,7 @@ void FuseMultiplyWithDepthwiseConvolution2D(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void FuseMultiplyWithConvolutionTransposed(
|
void FuseMultiplyWithConvolutionTransposed(
|
||||||
const MultiplyScalarAttributes& mul_attr,
|
const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr) {
|
||||||
ConvolutionTransposedAttributes* attr) {
|
|
||||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||||
for (int s = 0; s < attr->weights.shape.i; ++s) {
|
for (int s = 0; s < attr->weights.shape.i; ++s) {
|
||||||
@ -287,7 +284,7 @@ void FuseMultiplyWithConvolutionTransposed(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuseMultiplyWithFullyConnected(const MultiplyScalarAttributes& mul_attr,
|
void FuseMultiplyWithFullyConnected(const MultiplyAttributes& mul_attr,
|
||||||
FullyConnectedAttributes* attr) {
|
FullyConnectedAttributes* attr) {
|
||||||
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
auto mul = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&mul_attr.param);
|
||||||
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
auto mul_scalar = absl::get_if<float>(&mul_attr.param);
|
||||||
|
@ -38,53 +38,49 @@ std::unique_ptr<SequenceTransformation> NewMergeMulWithConvolution();
|
|||||||
// Modify Convolution2DAttributes so that after making convolution with
|
// Modify Convolution2DAttributes so that after making convolution with
|
||||||
// modified attributes we will have the same result as convolution
|
// modified attributes we will have the same result as convolution
|
||||||
// with old attributes and following multiply operation.
|
// with old attributes and following multiply operation.
|
||||||
void FuseConvolution2DWithMultiply(const MultiplyScalarAttributes& mul_attr,
|
void FuseConvolution2DWithMultiply(const MultiplyAttributes& mul_attr,
|
||||||
Convolution2DAttributes* attr);
|
Convolution2DAttributes* attr);
|
||||||
|
|
||||||
// Modify DepthwiseConvolution2DAttributes so that after making depth wise
|
// Modify DepthwiseConvolution2DAttributes so that after making depth wise
|
||||||
// convolution with modified attributes we will have the same result as depth
|
// convolution with modified attributes we will have the same result as depth
|
||||||
// wise convolution with old attributes and following multiply operation.
|
// wise convolution with old attributes and following multiply operation.
|
||||||
void FuseDepthwiseConvolution2DWithMultiply(
|
void FuseDepthwiseConvolution2DWithMultiply(
|
||||||
const MultiplyScalarAttributes& mul_attr,
|
const MultiplyAttributes& mul_attr, DepthwiseConvolution2DAttributes* attr);
|
||||||
DepthwiseConvolution2DAttributes* attr);
|
|
||||||
|
|
||||||
// Modify ConvolutionTransposedAttributes so that after making convolution
|
// Modify ConvolutionTransposedAttributes so that after making convolution
|
||||||
// transposed with modified attributes we will have the same result as
|
// transposed with modified attributes we will have the same result as
|
||||||
// convolution transposed with old attributes and following multiply operation.
|
// convolution transposed with old attributes and following multiply operation.
|
||||||
void FuseConvolutionTransposedWithMultiply(
|
void FuseConvolutionTransposedWithMultiply(
|
||||||
const MultiplyScalarAttributes& mul_attr,
|
const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr);
|
||||||
ConvolutionTransposedAttributes* attr);
|
|
||||||
|
|
||||||
// Modify FullyConnectedAttributes so that after making fully connected with
|
// Modify FullyConnectedAttributes so that after making fully connected with
|
||||||
// modified attributes we will have the same result as fully connected
|
// modified attributes we will have the same result as fully connected
|
||||||
// with old attributes and following multiply operation.
|
// with old attributes and following multiply operation.
|
||||||
void FuseFullyConnectedWithMultiply(const MultiplyScalarAttributes& mul_attr,
|
void FuseFullyConnectedWithMultiply(const MultiplyAttributes& mul_attr,
|
||||||
FullyConnectedAttributes* attr);
|
FullyConnectedAttributes* attr);
|
||||||
|
|
||||||
// Modify Convolution2DAttributes so that after making convolution with
|
// Modify Convolution2DAttributes so that after making convolution with
|
||||||
// modified attributes we will have the same result as multiply operation and
|
// modified attributes we will have the same result as multiply operation and
|
||||||
// convolution with old attributes
|
// convolution with old attributes
|
||||||
void FuseMultiplyWithConvolution2D(const MultiplyScalarAttributes& mul_attr,
|
void FuseMultiplyWithConvolution2D(const MultiplyAttributes& mul_attr,
|
||||||
Convolution2DAttributes* attr);
|
Convolution2DAttributes* attr);
|
||||||
|
|
||||||
// Modify DepthwiseConvolution2DAttributes so that after making depth wise
|
// Modify DepthwiseConvolution2DAttributes so that after making depth wise
|
||||||
// convolution with modified attributes we will have the same result as multiply
|
// convolution with modified attributes we will have the same result as multiply
|
||||||
// operation and depth wise convolution with old attributes
|
// operation and depth wise convolution with old attributes
|
||||||
void FuseMultiplyWithDepthwiseConvolution2D(
|
void FuseMultiplyWithDepthwiseConvolution2D(
|
||||||
const MultiplyScalarAttributes& mul_attr,
|
const MultiplyAttributes& mul_attr, DepthwiseConvolution2DAttributes* attr);
|
||||||
DepthwiseConvolution2DAttributes* attr);
|
|
||||||
|
|
||||||
// Modify ConvolutionTransposedAttributes so that after making convolution
|
// Modify ConvolutionTransposedAttributes so that after making convolution
|
||||||
// transposed with modified attributes we will have the same result as multiply
|
// transposed with modified attributes we will have the same result as multiply
|
||||||
// operation and convolution transposed with old attributes
|
// operation and convolution transposed with old attributes
|
||||||
void FuseMultiplyWithConvolutionTransposed(
|
void FuseMultiplyWithConvolutionTransposed(
|
||||||
const MultiplyScalarAttributes& mul_attr,
|
const MultiplyAttributes& mul_attr, ConvolutionTransposedAttributes* attr);
|
||||||
ConvolutionTransposedAttributes* attr);
|
|
||||||
|
|
||||||
// Modify FullyConnectedAttributes so that after making fully connected
|
// Modify FullyConnectedAttributes so that after making fully connected
|
||||||
// with modified attributes we will have the same result as multiply
|
// with modified attributes we will have the same result as multiply
|
||||||
// operation and fully connected with old attributes
|
// operation and fully connected with old attributes
|
||||||
void FuseMultiplyWithFullyConnected(const MultiplyScalarAttributes& mul_attr,
|
void FuseMultiplyWithFullyConnected(const MultiplyAttributes& mul_attr,
|
||||||
FullyConnectedAttributes* attr);
|
FullyConnectedAttributes* attr);
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
@ -46,7 +46,7 @@ TEST(MergeConvolutionWithMulTest, Smoke) {
|
|||||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||||
mul_tensor.shape = Linear(16);
|
mul_tensor.shape = Linear(16);
|
||||||
mul_tensor.data.resize(16);
|
mul_tensor.data.resize(16);
|
||||||
MultiplyScalarAttributes mul_attr;
|
MultiplyAttributes mul_attr;
|
||||||
mul_attr.param = mul_tensor;
|
mul_attr.param = mul_tensor;
|
||||||
|
|
||||||
auto conv_node = graph.NewNode();
|
auto conv_node = graph.NewNode();
|
||||||
@ -87,7 +87,7 @@ TEST(MergeMulWithConvolutionTest, Smoke) {
|
|||||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||||
mul_tensor.shape = Linear(8);
|
mul_tensor.shape = Linear(8);
|
||||||
mul_tensor.data.resize(8);
|
mul_tensor.data.resize(8);
|
||||||
MultiplyScalarAttributes mul_attr;
|
MultiplyAttributes mul_attr;
|
||||||
mul_attr.param = mul_tensor;
|
mul_attr.param = mul_tensor;
|
||||||
|
|
||||||
Convolution2DAttributes conv_attr;
|
Convolution2DAttributes conv_attr;
|
||||||
@ -140,7 +140,7 @@ TEST(FuseMulAfterConvolution2DTest, Smoke) {
|
|||||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||||
mul_tensor.shape = Linear(2);
|
mul_tensor.shape = Linear(2);
|
||||||
mul_tensor.data = {0.5f, 2.0f};
|
mul_tensor.data = {0.5f, 2.0f};
|
||||||
MultiplyScalarAttributes mul_attr;
|
MultiplyAttributes mul_attr;
|
||||||
mul_attr.param = mul_tensor;
|
mul_attr.param = mul_tensor;
|
||||||
|
|
||||||
FuseConvolution2DWithMultiply(mul_attr, &attr);
|
FuseConvolution2DWithMultiply(mul_attr, &attr);
|
||||||
@ -161,7 +161,7 @@ TEST(FuseMulAfterDepthwiseConvolution2DTest, Smoke) {
|
|||||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||||
mul_tensor.shape = Linear(4);
|
mul_tensor.shape = Linear(4);
|
||||||
mul_tensor.data = {0.5f, 2.0f, 4.0f, 0.25f};
|
mul_tensor.data = {0.5f, 2.0f, 4.0f, 0.25f};
|
||||||
MultiplyScalarAttributes mul_attr;
|
MultiplyAttributes mul_attr;
|
||||||
mul_attr.param = mul_tensor;
|
mul_attr.param = mul_tensor;
|
||||||
|
|
||||||
FuseDepthwiseConvolution2DWithMultiply(mul_attr, &attr);
|
FuseDepthwiseConvolution2DWithMultiply(mul_attr, &attr);
|
||||||
@ -183,7 +183,7 @@ TEST(FuseMulAfterConvolutionTransposedTest, Smoke) {
|
|||||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||||
mul_tensor.shape = Linear(2);
|
mul_tensor.shape = Linear(2);
|
||||||
mul_tensor.data = {0.5f, 2.0f};
|
mul_tensor.data = {0.5f, 2.0f};
|
||||||
MultiplyScalarAttributes mul_attr;
|
MultiplyAttributes mul_attr;
|
||||||
mul_attr.param = mul_tensor;
|
mul_attr.param = mul_tensor;
|
||||||
|
|
||||||
FuseConvolutionTransposedWithMultiply(mul_attr, &attr);
|
FuseConvolutionTransposedWithMultiply(mul_attr, &attr);
|
||||||
@ -204,7 +204,7 @@ TEST(FuseMulAfterFullyConnectedTest, Smoke) {
|
|||||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||||
mul_tensor.shape = Linear(2);
|
mul_tensor.shape = Linear(2);
|
||||||
mul_tensor.data = {0.5f, 2.0f};
|
mul_tensor.data = {0.5f, 2.0f};
|
||||||
MultiplyScalarAttributes mul_attr;
|
MultiplyAttributes mul_attr;
|
||||||
mul_attr.param = mul_tensor;
|
mul_attr.param = mul_tensor;
|
||||||
|
|
||||||
FuseFullyConnectedWithMultiply(mul_attr, &attr);
|
FuseFullyConnectedWithMultiply(mul_attr, &attr);
|
||||||
@ -224,7 +224,7 @@ TEST(FuseMulBeforeConvolution2DTest, Smoke) {
|
|||||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||||
mul_tensor.shape = Linear(2);
|
mul_tensor.shape = Linear(2);
|
||||||
mul_tensor.data = {0.5f, 2.0f};
|
mul_tensor.data = {0.5f, 2.0f};
|
||||||
MultiplyScalarAttributes mul_attr;
|
MultiplyAttributes mul_attr;
|
||||||
mul_attr.param = mul_tensor;
|
mul_attr.param = mul_tensor;
|
||||||
|
|
||||||
FuseMultiplyWithConvolution2D(mul_attr, &attr);
|
FuseMultiplyWithConvolution2D(mul_attr, &attr);
|
||||||
@ -245,7 +245,7 @@ TEST(FuseMulBeforeDepthwiseConvolution2DTest, Smoke) {
|
|||||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||||
mul_tensor.shape = Linear(4);
|
mul_tensor.shape = Linear(4);
|
||||||
mul_tensor.data = {0.5f, 2.0f, 4.0f, 0.25f};
|
mul_tensor.data = {0.5f, 2.0f, 4.0f, 0.25f};
|
||||||
MultiplyScalarAttributes mul_attr;
|
MultiplyAttributes mul_attr;
|
||||||
mul_attr.param = mul_tensor;
|
mul_attr.param = mul_tensor;
|
||||||
|
|
||||||
FuseMultiplyWithDepthwiseConvolution2D(mul_attr, &attr);
|
FuseMultiplyWithDepthwiseConvolution2D(mul_attr, &attr);
|
||||||
@ -267,7 +267,7 @@ TEST(FuseMulBeforeConvolutionTransposedTest, Smoke) {
|
|||||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||||
mul_tensor.shape = Linear(2);
|
mul_tensor.shape = Linear(2);
|
||||||
mul_tensor.data = {0.5f, 2.0f};
|
mul_tensor.data = {0.5f, 2.0f};
|
||||||
MultiplyScalarAttributes mul_attr;
|
MultiplyAttributes mul_attr;
|
||||||
mul_attr.param = mul_tensor;
|
mul_attr.param = mul_tensor;
|
||||||
|
|
||||||
FuseMultiplyWithConvolutionTransposed(mul_attr, &attr);
|
FuseMultiplyWithConvolutionTransposed(mul_attr, &attr);
|
||||||
@ -288,7 +288,7 @@ TEST(FuseMulBeforeFullyConnectedTest, Smoke) {
|
|||||||
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
Tensor<Linear, DataType::FLOAT32> mul_tensor;
|
||||||
mul_tensor.shape = Linear(2);
|
mul_tensor.shape = Linear(2);
|
||||||
mul_tensor.data = {0.5f, 2.0f};
|
mul_tensor.data = {0.5f, 2.0f};
|
||||||
MultiplyScalarAttributes mul_attr;
|
MultiplyAttributes mul_attr;
|
||||||
mul_attr.param = mul_tensor;
|
mul_attr.param = mul_tensor;
|
||||||
|
|
||||||
FuseMultiplyWithFullyConnected(mul_attr, &attr);
|
FuseMultiplyWithFullyConnected(mul_attr, &attr);
|
||||||
|
@ -29,115 +29,116 @@ limitations under the License.
|
|||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
namespace gl {
|
namespace gl {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class ApplyMask : public NodeShader {
|
bool IsApplyMaskSupported(const NodeShader::GenerationContext& ctx) {
|
||||||
public:
|
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||||
static bool IsSupported(const GenerationContext& ctx) {
|
if (inputs.size() != 2) return false;
|
||||||
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
const auto& shape0 = inputs[0]->tensor.shape;
|
||||||
if (inputs.size() != 2) return false;
|
const auto& shape1 = inputs[1]->tensor.shape;
|
||||||
const auto& shape0 = inputs[0]->tensor.shape;
|
|
||||||
const auto& shape1 = inputs[1]->tensor.shape;
|
|
||||||
|
|
||||||
// [H, W, C] x [H, W, 0][0]
|
// [H, W, C] x [H, W, 0][0]
|
||||||
if (shape1.c == 1) return true;
|
if (shape0.h == shape1.h && shape0.w == shape1.w && shape1.c == 1) {
|
||||||
|
return true;
|
||||||
if (shape0.c != shape1.c) return false;
|
|
||||||
|
|
||||||
// [H, W, C] x [H, W, C]
|
|
||||||
if (shape0.h == shape1.h && shape0.w == shape1.w) return true;
|
|
||||||
|
|
||||||
// [H, W, C] x [0, 0, C]
|
|
||||||
return shape1.h == 1 && shape1.w == 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GenerateCode(const GenerationContext& ctx,
|
// [H, W, C] x [H, W, C]
|
||||||
GeneratedCode* generated_code) const final {
|
if (shape0 == shape1) {
|
||||||
if (!IsSupported(ctx)) {
|
return true;
|
||||||
return InvalidArgumentError(
|
}
|
||||||
"This case is not supported by apply mask operation");
|
|
||||||
}
|
|
||||||
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
|
||||||
const auto& shape0 = inputs[0]->tensor.shape;
|
|
||||||
const auto& shape1 = inputs[1]->tensor.shape;
|
|
||||||
|
|
||||||
std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * ";
|
// [H, W, C] x [0, 0, C]
|
||||||
if (shape1.c == 1) {
|
return shape1.h == 1 && shape1.w == 1 && shape0.c == shape1.c;
|
||||||
// [H, W, C] x [H, W, 0][0]
|
}
|
||||||
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;");
|
|
||||||
} else if (shape0.h == shape1.h && shape0.w == shape1.w) {
|
|
||||||
// [H, W, C] x [H, W, C]
|
|
||||||
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;");
|
|
||||||
} else {
|
|
||||||
// [H, W, C] x [0, 0, C]
|
|
||||||
absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
Status GenerateApplyMaskCode(const NodeShader::GenerationContext& ctx,
|
||||||
|
GeneratedCode* generated_code) {
|
||||||
|
const auto inputs = ctx.graph->FindInputs(ctx.node->id);
|
||||||
|
const auto& shape0 = inputs[0]->tensor.shape;
|
||||||
|
const auto& shape1 = inputs[1]->tensor.shape;
|
||||||
|
|
||||||
|
std::string source = "value_0 = $input_data_0[gid.x, gid.y, gid.z]$ * ";
|
||||||
|
if (shape1.c == 1) {
|
||||||
|
// [H, W, C] x [H, W, 0][0]
|
||||||
|
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, 0]$.x;");
|
||||||
|
} else if (shape0.h == shape1.h && shape0.w == shape1.w) {
|
||||||
|
// [H, W, C] x [H, W, C]
|
||||||
|
absl::StrAppend(&source, "$input_data_1[gid.x, gid.y, gid.z]$;");
|
||||||
|
} else {
|
||||||
|
// [H, W, C] x [0, 0, C]
|
||||||
|
absl::StrAppend(&source, "$input_data_1[0, 0, gid.z]$;");
|
||||||
|
}
|
||||||
|
|
||||||
|
*generated_code = {
|
||||||
|
/*parameters=*/{},
|
||||||
|
/*objects=*/{},
|
||||||
|
/*shared_variables=*/{},
|
||||||
|
/*workload=*/uint3(),
|
||||||
|
/*workgroup=*/uint3(),
|
||||||
|
/*source_code=*/std::move(source),
|
||||||
|
/*input=*/IOStructure::ONLY_DEFINITIONS,
|
||||||
|
/*output=*/IOStructure::AUTO,
|
||||||
|
};
|
||||||
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GenerateMultiplyScalarCode(const NodeShader::GenerationContext& ctx,
|
||||||
|
GeneratedCode* generated_code) {
|
||||||
|
auto attr =
|
||||||
|
absl::any_cast<MultiplyAttributes>(ctx.node->operation.attributes);
|
||||||
|
auto muls = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
|
||||||
|
auto scalar = absl::get_if<float>(&attr.param);
|
||||||
|
|
||||||
|
if (scalar) {
|
||||||
*generated_code = {
|
*generated_code = {
|
||||||
/*parameters=*/{},
|
/*parameters=*/{{"scalar", *scalar}},
|
||||||
/*objects=*/{},
|
/*objects=*/{},
|
||||||
/*shared_variables=*/{},
|
/*shared_variables=*/{},
|
||||||
/*workload=*/uint3(),
|
/*workload=*/uint3(),
|
||||||
/*workgroup=*/uint3(),
|
/*workgroup=*/uint3(),
|
||||||
/*source_code=*/std::move(source),
|
/*source_code=*/"value_0 *= $scalar$;",
|
||||||
/*input=*/IOStructure::ONLY_DEFINITIONS,
|
/*input=*/IOStructure::AUTO,
|
||||||
|
/*output=*/IOStructure::AUTO,
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
if (!muls) {
|
||||||
|
return InvalidArgumentError("Empty parameters for Multiplication.");
|
||||||
|
}
|
||||||
|
auto shape = ctx.graph->FindInputs(ctx.node->id)[0]->tensor.shape;
|
||||||
|
*generated_code = {
|
||||||
|
/*parameters=*/{},
|
||||||
|
/*objects=*/{{"mul_buffer", MakeReadonlyObject(muls->data)}},
|
||||||
|
/*shared_variables=*/{},
|
||||||
|
// Declare workload explicitly because shader depends on gid.z.
|
||||||
|
/*workload=*/
|
||||||
|
uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)),
|
||||||
|
/*workgroup=*/uint3(),
|
||||||
|
/*source_code=*/"value_0 *= $mul_buffer[gid.z]$;",
|
||||||
|
/*input=*/IOStructure::AUTO,
|
||||||
/*output=*/IOStructure::AUTO,
|
/*output=*/IOStructure::AUTO,
|
||||||
};
|
};
|
||||||
return OkStatus();
|
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
class MultiplyScalar : public NodeShader {
|
return OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
class Multiply : public NodeShader {
|
||||||
public:
|
public:
|
||||||
Status GenerateCode(const GenerationContext& ctx,
|
Status GenerateCode(const GenerationContext& ctx,
|
||||||
GeneratedCode* generated_code) const final {
|
GeneratedCode* generated_code) const final {
|
||||||
auto attr = absl::any_cast<MultiplyScalarAttributes>(
|
if (IsApplyMaskSupported(ctx)) {
|
||||||
ctx.node->operation.attributes);
|
return GenerateApplyMaskCode(ctx, generated_code);
|
||||||
auto muls = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
|
|
||||||
auto scalar = absl::get_if<float>(&attr.param);
|
|
||||||
|
|
||||||
if (scalar) {
|
|
||||||
*generated_code = {
|
|
||||||
/*parameters=*/{{"scalar", *scalar}},
|
|
||||||
/*objects=*/{},
|
|
||||||
/*shared_variables=*/{},
|
|
||||||
/*workload=*/uint3(),
|
|
||||||
/*workgroup=*/uint3(),
|
|
||||||
/*source_code=*/"value_0 *= $scalar$;",
|
|
||||||
/*input=*/IOStructure::AUTO,
|
|
||||||
/*output=*/IOStructure::AUTO,
|
|
||||||
};
|
|
||||||
} else {
|
} else {
|
||||||
if (!muls) {
|
return GenerateMultiplyScalarCode(ctx, generated_code);
|
||||||
return InvalidArgumentError("Empty parameters for Multiplication.");
|
|
||||||
}
|
|
||||||
auto shape = ctx.graph->FindInputs(ctx.node->id)[0]->tensor.shape;
|
|
||||||
*generated_code = {
|
|
||||||
/*parameters=*/{},
|
|
||||||
/*objects=*/{{"mul_buffer", MakeReadonlyObject(muls->data)}},
|
|
||||||
/*shared_variables=*/{},
|
|
||||||
// Declare workload explicitly because shader depends on gid.z.
|
|
||||||
/*workload=*/
|
|
||||||
uint3(shape.w, shape.h, IntegralDivideRoundUp(shape.c, 4)),
|
|
||||||
/*workgroup=*/uint3(),
|
|
||||||
/*source_code=*/"value_0 *= $mul_buffer[gid.z]$;",
|
|
||||||
/*input=*/IOStructure::AUTO,
|
|
||||||
/*output=*/IOStructure::AUTO,
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return OkStatus();
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<NodeShader> NewApplyMaskNodeShader() {
|
std::unique_ptr<NodeShader> NewMultiplyNodeShader() {
|
||||||
return absl::make_unique<ApplyMask>();
|
return absl::make_unique<Multiply>();
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<NodeShader> NewMultiplyScalarNodeShader() {
|
|
||||||
return absl::make_unique<MultiplyScalar>();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gl
|
} // namespace gl
|
||||||
|
@ -25,9 +25,7 @@ namespace tflite {
|
|||||||
namespace gpu {
|
namespace gpu {
|
||||||
namespace gl {
|
namespace gl {
|
||||||
|
|
||||||
std::unique_ptr<NodeShader> NewApplyMaskNodeShader();
|
std::unique_ptr<NodeShader> NewMultiplyNodeShader();
|
||||||
|
|
||||||
std::unique_ptr<NodeShader> NewMultiplyScalarNodeShader();
|
|
||||||
|
|
||||||
} // namespace gl
|
} // namespace gl
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
@ -41,13 +41,12 @@ TEST(MulTest, Scalar) {
|
|||||||
output.ref = 1;
|
output.ref = 1;
|
||||||
output.shape = BHWC(1, 2, 2, 1);
|
output.shape = BHWC(1, 2, 2, 1);
|
||||||
|
|
||||||
MultiplyScalarAttributes attr;
|
MultiplyAttributes attr;
|
||||||
attr.param = 2.f;
|
attr.param = 2.f;
|
||||||
|
|
||||||
// TODO(eignasheva): change to MULTIPLY_SCALAR
|
|
||||||
SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
|
SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
|
||||||
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
|
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
|
||||||
ASSERT_OK(model.Invoke(*NewMultiplyScalarNodeShader()));
|
ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
|
||||||
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 6, 8}));
|
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 6, 8}));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,21 +61,20 @@ TEST(MulTest, Linear) {
|
|||||||
output.ref = 1;
|
output.ref = 1;
|
||||||
output.shape = BHWC(1, 1, 2, 2);
|
output.shape = BHWC(1, 1, 2, 2);
|
||||||
|
|
||||||
MultiplyScalarAttributes attr;
|
MultiplyAttributes attr;
|
||||||
Tensor<Linear, DataType::FLOAT32> tensor;
|
Tensor<Linear, DataType::FLOAT32> tensor;
|
||||||
tensor.shape.v = 2;
|
tensor.shape.v = 2;
|
||||||
tensor.id = 1;
|
tensor.id = 1;
|
||||||
tensor.data = {2, 3};
|
tensor.data = {2, 3};
|
||||||
attr.param = std::move(tensor);
|
attr.param = std::move(tensor);
|
||||||
|
|
||||||
// TODO(eignasheva): change to MULTIPLY_SCALAR
|
|
||||||
SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
|
SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
|
||||||
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
|
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
|
||||||
ASSERT_OK(model.Invoke(*NewMultiplyScalarNodeShader()));
|
ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
|
||||||
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 6, 6, 12}));
|
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 6, 6, 12}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ApplyMaskTest, MaskChannel1) {
|
TEST(MulTest, MaskChannel1) {
|
||||||
TensorRef<BHWC> input;
|
TensorRef<BHWC> input;
|
||||||
input.type = DataType::FLOAT32;
|
input.type = DataType::FLOAT32;
|
||||||
input.ref = 0;
|
input.ref = 0;
|
||||||
@ -92,15 +90,15 @@ TEST(ApplyMaskTest, MaskChannel1) {
|
|||||||
output.ref = 2;
|
output.ref = 2;
|
||||||
output.shape = BHWC(1, 1, 2, 2);
|
output.shape = BHWC(1, 1, 2, 2);
|
||||||
|
|
||||||
SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask},
|
SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask},
|
||||||
{output});
|
{output});
|
||||||
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
|
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
|
||||||
ASSERT_TRUE(model.PopulateTensor(1, {2, 3}));
|
ASSERT_TRUE(model.PopulateTensor(1, {2, 3}));
|
||||||
ASSERT_OK(model.Invoke(*NewApplyMaskNodeShader()));
|
ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
|
||||||
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 9, 12}));
|
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {2, 4, 9, 12}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ApplyMaskTest, MaskChannelEqualsToInputChannel) {
|
TEST(MulTest, MaskChannelEqualsToInputChannel) {
|
||||||
TensorRef<BHWC> input;
|
TensorRef<BHWC> input;
|
||||||
input.type = DataType::FLOAT32;
|
input.type = DataType::FLOAT32;
|
||||||
input.ref = 0;
|
input.ref = 0;
|
||||||
@ -116,11 +114,11 @@ TEST(ApplyMaskTest, MaskChannelEqualsToInputChannel) {
|
|||||||
output.ref = 2;
|
output.ref = 2;
|
||||||
output.shape = BHWC(1, 1, 2, 2);
|
output.shape = BHWC(1, 1, 2, 2);
|
||||||
|
|
||||||
SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask},
|
SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask},
|
||||||
{output});
|
{output});
|
||||||
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
|
ASSERT_TRUE(model.PopulateTensor(0, {1, 2, 3, 4}));
|
||||||
ASSERT_TRUE(model.PopulateTensor(1, {1, 2, 3, 4}));
|
ASSERT_TRUE(model.PopulateTensor(1, {1, 2, 3, 4}));
|
||||||
ASSERT_OK(model.Invoke(*NewApplyMaskNodeShader()));
|
ASSERT_OK(model.Invoke(*NewMultiplyNodeShader()));
|
||||||
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1, 4, 9, 16}));
|
EXPECT_THAT(model.GetOutput(0), Pointwise(FloatNear(1e-6), {1, 4, 9, 16}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,7 +71,6 @@ class Registry : public NodeShader {
|
|||||||
};
|
};
|
||||||
|
|
||||||
insert_op(Type::ADD, NewAddNodeShader);
|
insert_op(Type::ADD, NewAddNodeShader);
|
||||||
insert_op(Type::APPLY_MASK, NewApplyMaskNodeShader);
|
|
||||||
insert_op(Type::CONCAT, NewAlignedConcatNodeShader);
|
insert_op(Type::CONCAT, NewAlignedConcatNodeShader);
|
||||||
insert_op(Type::CONCAT, NewFlatConcatNodeShader);
|
insert_op(Type::CONCAT, NewFlatConcatNodeShader);
|
||||||
insert_op(Type::CONCAT, NewConcatNodeShader);
|
insert_op(Type::CONCAT, NewConcatNodeShader);
|
||||||
@ -82,7 +81,7 @@ class Registry : public NodeShader {
|
|||||||
insert_op(Type::FULLY_CONNECTED, NewFullyConnectedNodeShader);
|
insert_op(Type::FULLY_CONNECTED, NewFullyConnectedNodeShader);
|
||||||
insert_op(Type::LSTM, NewLstmNodeShader);
|
insert_op(Type::LSTM, NewLstmNodeShader);
|
||||||
insert_op(Type::MEAN, NewMeanNodeShader);
|
insert_op(Type::MEAN, NewMeanNodeShader);
|
||||||
insert_op(Type::MULTIPLY_SCALAR, NewMultiplyScalarNodeShader);
|
insert_op(Type::MUL, NewMultiplyNodeShader);
|
||||||
insert_op(Type::PAD, NewPadNodeShader);
|
insert_op(Type::PAD, NewPadNodeShader);
|
||||||
insert_op(Type::POOLING_2D, NewPoolingNodeShader);
|
insert_op(Type::POOLING_2D, NewPoolingNodeShader);
|
||||||
insert_op(Type::PRELU, NewPReLUNodeShader);
|
insert_op(Type::PRELU, NewPReLUNodeShader);
|
||||||
|
@ -199,11 +199,15 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
|||||||
*tasks = Mean(node_id, inputs[0], outputs[0],
|
*tasks = Mean(node_id, inputs[0], outputs[0],
|
||||||
absl::any_cast<MeanAttributes>(node->operation.attributes));
|
absl::any_cast<MeanAttributes>(node->operation.attributes));
|
||||||
break;
|
break;
|
||||||
case OperationType::MULTIPLY_SCALAR:
|
case OperationType::MUL:
|
||||||
*tasks = Multiply(
|
if (node->operation.attributes.has_value()) {
|
||||||
node_id, inputs[0], outputs[0],
|
*tasks = Multiply(
|
||||||
absl::any_cast<MultiplyScalarAttributes>(node->operation.attributes),
|
node_id, inputs[0], outputs[0],
|
||||||
options);
|
absl::any_cast<MultiplyAttributes>(node->operation.attributes),
|
||||||
|
options);
|
||||||
|
} else {
|
||||||
|
*tasks = ApplyMask(node_id, inputs[0], inputs[1], outputs[0], options);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case OperationType::PAD: {
|
case OperationType::PAD: {
|
||||||
auto attr = absl::any_cast<PadAttributes>(node->operation.attributes);
|
auto attr = absl::any_cast<PadAttributes>(node->operation.attributes);
|
||||||
@ -268,12 +272,10 @@ Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
|||||||
case OperationType::SQUARED_DIFF:
|
case OperationType::SQUARED_DIFF:
|
||||||
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type);
|
*tasks = ElementwiseWithTwoInputs(node_id, inputs, outputs[0], op_type);
|
||||||
break;
|
break;
|
||||||
case OperationType::APPLY_MASK:
|
|
||||||
case OperationType::BATCH_NORMALIZATION:
|
case OperationType::BATCH_NORMALIZATION:
|
||||||
case OperationType::BATCH_TO_SPACE:
|
case OperationType::BATCH_TO_SPACE:
|
||||||
case OperationType::CONST:
|
case OperationType::CONST:
|
||||||
case OperationType::LSTM:
|
case OperationType::LSTM:
|
||||||
case OperationType::MUL:
|
|
||||||
case OperationType::SPACE_TO_BATCH:
|
case OperationType::SPACE_TO_BATCH:
|
||||||
case OperationType::TRANSPOSE:
|
case OperationType::TRANSPOSE:
|
||||||
case OperationType::UNKNOWN:
|
case OperationType::UNKNOWN:
|
||||||
|
@ -128,9 +128,10 @@ std::vector<ComputeTaskDescriptorPtr> ApplyMask(int id, ValueId input_id_0,
|
|||||||
return {desc};
|
return {desc};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ComputeTaskDescriptorPtr> Multiply(
|
std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
|
||||||
int id, ValueId input_id, ValueId output_id,
|
ValueId output_id,
|
||||||
const MultiplyScalarAttributes& attr, const RuntimeOptions& options) {
|
const MultiplyAttributes& attr,
|
||||||
|
const RuntimeOptions& options) {
|
||||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||||
desc->id = id;
|
desc->id = id;
|
||||||
desc->is_linkable = true;
|
desc->is_linkable = true;
|
||||||
|
@ -26,9 +26,10 @@ namespace gpu {
|
|||||||
namespace metal {
|
namespace metal {
|
||||||
|
|
||||||
// Multiply operation, supports scalar and vector broadcast.
|
// Multiply operation, supports scalar and vector broadcast.
|
||||||
std::vector<ComputeTaskDescriptorPtr> Multiply(
|
std::vector<ComputeTaskDescriptorPtr> Multiply(int id, ValueId input_id,
|
||||||
int id, ValueId input_id, ValueId output_id,
|
ValueId output_id,
|
||||||
const MultiplyScalarAttributes& attr, const RuntimeOptions& options);
|
const MultiplyAttributes& attr,
|
||||||
|
const RuntimeOptions& options);
|
||||||
|
|
||||||
std::vector<ComputeTaskDescriptorPtr> ApplyMask(int id, ValueId input_id_0,
|
std::vector<ComputeTaskDescriptorPtr> ApplyMask(int id, ValueId input_id_0,
|
||||||
ValueId input_id_1,
|
ValueId input_id_1,
|
||||||
|
@ -31,7 +31,7 @@ limitations under the License.
|
|||||||
using ::tflite::gpu::DataType;
|
using ::tflite::gpu::DataType;
|
||||||
using ::tflite::gpu::BHWC;
|
using ::tflite::gpu::BHWC;
|
||||||
using ::tflite::gpu::Linear;
|
using ::tflite::gpu::Linear;
|
||||||
using ::tflite::gpu::MultiplyScalarAttributes;
|
using ::tflite::gpu::MultiplyAttributes;
|
||||||
using ::tflite::gpu::OperationType;
|
using ::tflite::gpu::OperationType;
|
||||||
using ::tflite::gpu::Tensor;
|
using ::tflite::gpu::Tensor;
|
||||||
using ::tflite::gpu::TensorRef;
|
using ::tflite::gpu::TensorRef;
|
||||||
@ -57,10 +57,10 @@ using ::tflite::gpu::metal::SingleOpModel;
|
|||||||
output.ref = 1;
|
output.ref = 1;
|
||||||
output.shape = BHWC(1, 2, 2, 1);
|
output.shape = BHWC(1, 2, 2, 1);
|
||||||
|
|
||||||
MultiplyScalarAttributes attr;
|
MultiplyAttributes attr;
|
||||||
attr.param = 2;
|
attr.param = 2;
|
||||||
|
|
||||||
SingleOpModel model({ToString(OperationType::MULTIPLY_SCALAR), attr}, {input}, {output});
|
SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
|
||||||
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
||||||
auto status = model.Invoke();
|
auto status = model.Invoke();
|
||||||
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
||||||
@ -79,14 +79,14 @@ using ::tflite::gpu::metal::SingleOpModel;
|
|||||||
output.ref = 1;
|
output.ref = 1;
|
||||||
output.shape = BHWC(1, 1, 2, 2);
|
output.shape = BHWC(1, 1, 2, 2);
|
||||||
|
|
||||||
MultiplyScalarAttributes attr;
|
MultiplyAttributes attr;
|
||||||
Tensor<Linear, DataType::FLOAT32> tensor;
|
Tensor<Linear, DataType::FLOAT32> tensor;
|
||||||
tensor.shape.v = 2;
|
tensor.shape.v = 2;
|
||||||
tensor.id = 1;
|
tensor.id = 1;
|
||||||
tensor.data = {2, 3};
|
tensor.data = {2, 3};
|
||||||
attr.param = std::move(tensor);
|
attr.param = std::move(tensor);
|
||||||
|
|
||||||
SingleOpModel model({ToString(OperationType::MULTIPLY_SCALAR), attr}, {input}, {output});
|
SingleOpModel model({ToString(OperationType::MUL), attr}, {input}, {output});
|
||||||
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
||||||
auto status = model.Invoke();
|
auto status = model.Invoke();
|
||||||
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
||||||
@ -111,7 +111,7 @@ using ::tflite::gpu::metal::SingleOpModel;
|
|||||||
output.ref = 2;
|
output.ref = 2;
|
||||||
output.shape = BHWC(1, 1, 2, 2);
|
output.shape = BHWC(1, 1, 2, 2);
|
||||||
|
|
||||||
SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask}, {output});
|
SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output});
|
||||||
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
||||||
XCTAssertTrue(model.PopulateTensor(1, {2, 3}));
|
XCTAssertTrue(model.PopulateTensor(1, {2, 3}));
|
||||||
auto status = model.Invoke();
|
auto status = model.Invoke();
|
||||||
@ -136,13 +136,12 @@ using ::tflite::gpu::metal::SingleOpModel;
|
|||||||
output.ref = 2;
|
output.ref = 2;
|
||||||
output.shape = BHWC(1, 1, 2, 2);
|
output.shape = BHWC(1, 1, 2, 2);
|
||||||
|
|
||||||
SingleOpModel model({ToString(OperationType::APPLY_MASK), {}}, {input, mask}, {output});
|
SingleOpModel model({ToString(OperationType::MUL), {}}, {input, mask}, {output});
|
||||||
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
XCTAssertTrue(model.PopulateTensor(0, {1, 2, 3, 4}));
|
||||||
XCTAssertTrue(model.PopulateTensor(1, {1, 2, 3, 4}));
|
XCTAssertTrue(model.PopulateTensor(1, {1, 2, 3, 4}));
|
||||||
auto status = model.Invoke();
|
auto status = model.Invoke();
|
||||||
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
||||||
// Disable test for now.
|
status = CompareVectors({1, 4, 9, 16}, model.GetOutput(0), 1e-6f);
|
||||||
// status = CompareVectors({1, 4, 9, 16}, model.GetOutput(0), 1e-6f);
|
|
||||||
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
XCTAssertTrue(status.ok(), @"%s", status.error_message().c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user