Removed CreationContext from some ops.
PiperOrigin-RevId: 327494182 Change-Id: I70a8d649b51c891d720b789db75befed359ec08f
This commit is contained in:
parent
9c828254cd
commit
7c64157c36
@ -160,68 +160,68 @@ GPUOperation CreateElementwiseOneRuntimeOneScalar(
|
||||
|
||||
// Creates simple two input(first input is runtime tensor and second input is
|
||||
// constant linear tensor) operation, for example sub, div and etc.
|
||||
absl::Status CreateElementwiseTwoInput(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
GPUOperation CreateElementwiseTwoInput(
|
||||
const DeviceInfo& device_info, const OperationDef& definition,
|
||||
const OperationType& op_type,
|
||||
const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& constant_tensor,
|
||||
bool swap_inputs, GPUOperation* result) {
|
||||
bool swap_inputs) {
|
||||
const BHWC shape = BHWC(1, 1, 1, constant_tensor.shape.v);
|
||||
TensorStorageType storage_type = SelectBestStorageType(
|
||||
creation_context.device->info_, shape, definition.GetPrimaryStorageType(),
|
||||
device_info, shape, definition.GetPrimaryStorageType(),
|
||||
definition.GetDataType(), Layout::HWC);
|
||||
TensorDescriptor desc{definition.GetDataType(), storage_type, Layout::HWC};
|
||||
desc.UploadData(constant_tensor);
|
||||
|
||||
*result = GPUOperation(definition);
|
||||
result->elementwise_ = true;
|
||||
result->args_.AddObject("second_tensor",
|
||||
absl::make_unique<TensorDescriptor>(std::move(desc)));
|
||||
GPUOperation result(definition);
|
||||
result.elementwise_ = true;
|
||||
result.args_.AddObject("second_tensor",
|
||||
absl::make_unique<TensorDescriptor>(std::move(desc)));
|
||||
const std::string s_coord = shape.c == 1 ? "0" : "S_COORD";
|
||||
result->code_ = absl::StrCat(
|
||||
result.code_ = absl::StrCat(
|
||||
"FLT4 second_val = args.second_tensor.Read(0, 0, ", s_coord, ");\n");
|
||||
if (shape.c == 1) {
|
||||
result->code_ += " second_val.y = second_val.x;\n";
|
||||
result->code_ += " second_val.z = second_val.x;\n";
|
||||
result->code_ += " second_val.w = second_val.x;\n";
|
||||
result.code_ += " second_val.y = second_val.x;\n";
|
||||
result.code_ += " second_val.z = second_val.x;\n";
|
||||
result.code_ += " second_val.w = second_val.x;\n";
|
||||
}
|
||||
result->code_ += GetTwoInputCode(op_type, "in_out_value", "in_out_value",
|
||||
"second_val", swap_inputs);
|
||||
return absl::OkStatus();
|
||||
result.code_ += GetTwoInputCode(op_type, "in_out_value", "in_out_value",
|
||||
"second_val", swap_inputs);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Creates simple two input(first input is runtime tensor and second input is
|
||||
// constant HWC tensor) operation, for example sub, div and etc.
|
||||
absl::Status CreateElementwiseTwoInput(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
GPUOperation CreateElementwiseTwoInput(
|
||||
const DeviceInfo& device_info, const OperationDef& definition,
|
||||
const OperationType& op_type,
|
||||
const tflite::gpu::Tensor<HWC, DataType::FLOAT32>& constant_tensor,
|
||||
bool swap_inputs, GPUOperation* result) {
|
||||
bool swap_inputs) {
|
||||
const BHWC shape = BHWC(1, constant_tensor.shape.h, constant_tensor.shape.w,
|
||||
constant_tensor.shape.c);
|
||||
TensorStorageType storage_type = SelectBestStorageType(
|
||||
creation_context.device->info_, shape, definition.GetPrimaryStorageType(),
|
||||
device_info, shape, definition.GetPrimaryStorageType(),
|
||||
definition.GetDataType(), Layout::HWC);
|
||||
TensorDescriptor desc{definition.GetDataType(), storage_type, Layout::HWC};
|
||||
desc.UploadData(constant_tensor);
|
||||
|
||||
*result = GPUOperation(definition);
|
||||
result->elementwise_ = true;
|
||||
result->args_.AddObject("second_tensor",
|
||||
absl::make_unique<TensorDescriptor>(std::move(desc)));
|
||||
GPUOperation result(definition);
|
||||
result.elementwise_ = true;
|
||||
result.args_.AddObject("second_tensor",
|
||||
absl::make_unique<TensorDescriptor>(std::move(desc)));
|
||||
const std::string x_coord = shape.w == 1 ? "0" : "X_COORD";
|
||||
const std::string y_coord = shape.h == 1 ? "0" : "Y_COORD";
|
||||
const std::string s_coord = shape.c == 1 ? "0" : "S_COORD";
|
||||
result->code_ = absl::StrCat("FLT4 second_val = args.second_tensor.Read(",
|
||||
x_coord, ", ", y_coord, ", ", s_coord, ");\n");
|
||||
result.code_ = absl::StrCat("FLT4 second_val = args.second_tensor.Read(",
|
||||
x_coord, ", ", y_coord, ", ", s_coord, ");\n");
|
||||
if (shape.c == 1) {
|
||||
result->code_ += " second_val.y = second_val.x;\n";
|
||||
result->code_ += " second_val.z = second_val.x;\n";
|
||||
result->code_ += " second_val.w = second_val.x;\n";
|
||||
result.code_ += " second_val.y = second_val.x;\n";
|
||||
result.code_ += " second_val.z = second_val.x;\n";
|
||||
result.code_ += " second_val.w = second_val.x;\n";
|
||||
}
|
||||
result->code_ += GetTwoInputCode(op_type, "in_out_value", "in_out_value",
|
||||
"second_val", swap_inputs);
|
||||
result.code_ += GetTwoInputCode(op_type, "in_out_value", "in_out_value",
|
||||
"second_val", swap_inputs);
|
||||
|
||||
return absl::OkStatus();
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -234,11 +234,10 @@ GPUOperation CreateElementwiseOneInput(const OperationDef& definition,
|
||||
return op;
|
||||
}
|
||||
|
||||
absl::Status CreateElementwise(const CreationContext& creation_context,
|
||||
GPUOperation CreateElementwise(const DeviceInfo& device_info,
|
||||
const OperationDef& definition,
|
||||
const OperationType& op_type,
|
||||
const ElementwiseAttributes& attr,
|
||||
GPUOperation* result) {
|
||||
const ElementwiseAttributes& attr) {
|
||||
const float* scalar = absl::get_if<float>(&attr.param);
|
||||
const auto* linear_tensor =
|
||||
absl::get_if<tflite::gpu::Tensor<Linear, DataType::FLOAT32>>(&attr.param);
|
||||
@ -246,20 +245,19 @@ absl::Status CreateElementwise(const CreationContext& creation_context,
|
||||
absl::get_if<tflite::gpu::Tensor<HWC, DataType::FLOAT32>>(&attr.param);
|
||||
|
||||
if (scalar) {
|
||||
*result = CreateElementwiseOneRuntimeOneScalar(
|
||||
definition, op_type, *scalar, attr.runtime_tensor_is_second);
|
||||
return absl::OkStatus();
|
||||
return CreateElementwiseOneRuntimeOneScalar(definition, op_type, *scalar,
|
||||
attr.runtime_tensor_is_second);
|
||||
} else if (linear_tensor) {
|
||||
return CreateElementwiseTwoInput(creation_context, definition, op_type,
|
||||
return CreateElementwiseTwoInput(device_info, definition, op_type,
|
||||
*linear_tensor,
|
||||
attr.runtime_tensor_is_second, result);
|
||||
attr.runtime_tensor_is_second);
|
||||
} else if (hwc_tensor) {
|
||||
return CreateElementwiseTwoInput(creation_context, definition, op_type,
|
||||
*hwc_tensor, attr.runtime_tensor_is_second,
|
||||
result);
|
||||
return CreateElementwiseTwoInput(device_info, definition, op_type,
|
||||
*hwc_tensor,
|
||||
attr.runtime_tensor_is_second);
|
||||
} else {
|
||||
return GPUOperation(definition);
|
||||
}
|
||||
return absl::UnimplementedError(
|
||||
"No elementwise implementation for this case");
|
||||
}
|
||||
|
||||
GPUOperation CreateElementwiseTwoInput(const OperationDef& definition,
|
||||
|
@ -33,11 +33,10 @@ GPUOperation CreateElementwiseOneInput(const OperationDef& definition,
|
||||
|
||||
// Creates simple two input(first input is runtime tensor and second input is
|
||||
// constant or linear/hwc tensor) operation, for example sub, div and etc.
|
||||
absl::Status CreateElementwise(const CreationContext& creation_context,
|
||||
GPUOperation CreateElementwise(const DeviceInfo& device_info,
|
||||
const OperationDef& definition,
|
||||
const OperationType& op_type,
|
||||
const ElementwiseAttributes& attr,
|
||||
GPUOperation* result);
|
||||
const ElementwiseAttributes& attr);
|
||||
|
||||
// Creates simple two input(2 runtime tensors) operation, for example
|
||||
// sub, div and etc.
|
||||
|
@ -546,9 +546,9 @@ TEST_F(OpenCLOperationTest, MaximumWithScalar) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation;
|
||||
ASSERT_OK(CreateElementwise(creation_context_, op_def,
|
||||
OperationType::MAXIMUM, attr, &operation));
|
||||
GPUOperation operation =
|
||||
CreateElementwise(creation_context_.GetDeviceInfo(), op_def,
|
||||
OperationType::MAXIMUM, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
|
||||
BHWC(1, 4, 1, 1), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -577,9 +577,9 @@ TEST_F(OpenCLOperationTest, MaximumWithConstantLinearTensor) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation;
|
||||
ASSERT_OK(CreateElementwise(creation_context_, op_def,
|
||||
OperationType::MAXIMUM, attr, &operation));
|
||||
GPUOperation operation =
|
||||
CreateElementwise(creation_context_.GetDeviceInfo(), op_def,
|
||||
OperationType::MAXIMUM, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
|
||||
BHWC(1, 2, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -608,9 +608,9 @@ TEST_F(OpenCLOperationTest, MaximumWithConstantHWCTensor) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation;
|
||||
ASSERT_OK(CreateElementwise(creation_context_, op_def,
|
||||
OperationType::MAXIMUM, attr, &operation));
|
||||
GPUOperation operation =
|
||||
CreateElementwise(creation_context_.GetDeviceInfo(), op_def,
|
||||
OperationType::MAXIMUM, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
|
||||
BHWC(1, 2, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -638,9 +638,9 @@ TEST_F(OpenCLOperationTest, MaximumWithConstantHWCTensorBroadcastChannels) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation;
|
||||
ASSERT_OK(CreateElementwise(creation_context_, op_def,
|
||||
OperationType::MAXIMUM, attr, &operation));
|
||||
GPUOperation operation =
|
||||
CreateElementwise(creation_context_.GetDeviceInfo(), op_def,
|
||||
OperationType::MAXIMUM, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
|
||||
BHWC(1, 2, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -694,9 +694,9 @@ TEST_F(OpenCLOperationTest, MinimumWithScalar) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation;
|
||||
ASSERT_OK(CreateElementwise(creation_context_, op_def,
|
||||
OperationType::MINIMUM, attr, &operation));
|
||||
GPUOperation operation =
|
||||
CreateElementwise(creation_context_.GetDeviceInfo(), op_def,
|
||||
OperationType::MINIMUM, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
|
||||
BHWC(1, 4, 1, 1), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -807,9 +807,8 @@ TEST_F(OpenCLOperationTest, SubWithScalarAtFirstPosition) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation;
|
||||
ASSERT_OK(CreateElementwise(creation_context_, op_def, OperationType::SUB,
|
||||
attr, &operation));
|
||||
GPUOperation operation = CreateElementwise(
|
||||
creation_context_.GetDeviceInfo(), op_def, OperationType::SUB, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor_0, creation_context_, &operation,
|
||||
BHWC(1, 4, 1, 1), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
|
@ -26,7 +26,7 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
GPUOperation CreateQuantizeAndDequantize(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const OperationDef& definition,
|
||||
const QuantizeAndDequantizeAttributes& attr) {
|
||||
QuantizeAndDequantizeAttributes adjusted_attr = attr;
|
||||
const bool is_fp16 = definition.precision == CalculationsPrecision::F16 ||
|
||||
|
@ -44,7 +44,7 @@ namespace cl {
|
||||
// NOTE: We do not need to nudge min/max values in this op, since they would
|
||||
// already be adjusted while generating the quantized model.
|
||||
GPUOperation CreateQuantizeAndDequantize(
|
||||
const CreationContext& creation_context, const OperationDef& definition,
|
||||
const OperationDef& definition,
|
||||
const QuantizeAndDequantizeAttributes& attr);
|
||||
|
||||
} // namespace cl
|
||||
|
@ -56,8 +56,7 @@ TEST_F(OpenCLOperationTest, QuantAndDequant_Dim2Bits8) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation =
|
||||
CreateQuantizeAndDequantize(creation_context_, op_def, attr);
|
||||
GPUOperation operation = CreateQuantizeAndDequantize(op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 3, 2, 1), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -91,8 +90,7 @@ TEST_F(OpenCLOperationTest, QuantAndDequant_Dim3Bits8_NegativeRange) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation =
|
||||
CreateQuantizeAndDequantize(creation_context_, op_def, attr);
|
||||
GPUOperation operation = CreateQuantizeAndDequantize(op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 3, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -126,8 +124,7 @@ TEST_F(OpenCLOperationTest, QuantAndDequant_Dim3Bits16) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation =
|
||||
CreateQuantizeAndDequantize(creation_context_, op_def, attr);
|
||||
GPUOperation operation = CreateQuantizeAndDequantize(op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 3, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -161,8 +158,7 @@ TEST_F(OpenCLOperationTest, QuantAndDequant_Dim2Bits16_NegativeRange) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation =
|
||||
CreateQuantizeAndDequantize(creation_context_, op_def, attr);
|
||||
GPUOperation operation = CreateQuantizeAndDequantize(op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 3, 2, 1), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
|
@ -21,8 +21,7 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
GPUOperation CreateReLU(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
GPUOperation CreateReLU(const OperationDef& definition,
|
||||
const ReLUAttributes& attr) {
|
||||
GPUOperation op(definition);
|
||||
op.elementwise_ = true;
|
||||
|
@ -25,8 +25,7 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
GPUOperation CreateReLU(const CreationContext& creation_context,
|
||||
const OperationDef& definition,
|
||||
GPUOperation CreateReLU(const OperationDef& definition,
|
||||
const ReLUAttributes& attr);
|
||||
|
||||
} // namespace cl
|
||||
|
@ -49,7 +49,7 @@ TEST_F(OpenCLOperationTest, ReLUNoClipNoAlpha) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation = CreateReLU(creation_context_, op_def, attr);
|
||||
GPUOperation operation = CreateReLU(op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 2, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -76,7 +76,7 @@ TEST_F(OpenCLOperationTest, ReLUClip) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation = CreateReLU(creation_context_, op_def, attr);
|
||||
GPUOperation operation = CreateReLU(op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 2, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -103,7 +103,7 @@ TEST_F(OpenCLOperationTest, ReLUAlpha) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation = CreateReLU(creation_context_, op_def, attr);
|
||||
GPUOperation operation = CreateReLU(op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 2, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
@ -130,7 +130,7 @@ TEST_F(OpenCLOperationTest, ReLUAlphaClip) {
|
||||
op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
|
||||
TensorFloat32 dst_tensor;
|
||||
GPUOperation operation = CreateReLU(creation_context_, op_def, attr);
|
||||
GPUOperation operation = CreateReLU(op_def, attr);
|
||||
ASSERT_OK(ExecuteGPUOperation(src_tensor, creation_context_, &operation,
|
||||
BHWC(1, 2, 1, 2), &dst_tensor));
|
||||
EXPECT_THAT(dst_tensor.data,
|
||||
|
@ -156,9 +156,8 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
|
||||
} else if (inputs.size() == 1 && node.operation.attributes.has_value()) {
|
||||
auto attr =
|
||||
absl::any_cast<ElementwiseAttributes>(node.operation.attributes);
|
||||
GPUOperation operation;
|
||||
RETURN_IF_ERROR(CreateElementwise(creation_context, op_def, op_type,
|
||||
attr, &operation));
|
||||
GPUOperation operation = CreateElementwise(
|
||||
creation_context.GetDeviceInfo(), op_def, op_type, attr);
|
||||
*gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
@ -286,12 +285,12 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
|
||||
case OperationType::QUANTIZE_AND_DEQUANTIZE: {
|
||||
auto attr = absl::any_cast<QuantizeAndDequantizeAttributes>(
|
||||
node.operation.attributes);
|
||||
SelectQuantizeAndDequantize(attr, creation_context, op_def, gpu_op);
|
||||
*gpu_op = SelectQuantizeAndDequantize(attr, op_def);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
case OperationType::RELU: {
|
||||
auto attr = absl::any_cast<ReLUAttributes>(node.operation.attributes);
|
||||
SelectReLU(creation_context, attr, op_def, gpu_op);
|
||||
*gpu_op = SelectReLU(attr, op_def);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
case OperationType::RESHAPE: {
|
||||
@ -357,9 +356,8 @@ absl::Status GPUOperationFromNode(const CreationContext& creation_context,
|
||||
} else if (inputs.size() == 1 && node.operation.attributes.has_value()) {
|
||||
auto attr =
|
||||
absl::any_cast<ElementwiseAttributes>(node.operation.attributes);
|
||||
GPUOperation operation;
|
||||
RETURN_IF_ERROR(CreateElementwise(creation_context, op_def, op_type,
|
||||
attr, &operation));
|
||||
GPUOperation operation = CreateElementwise(
|
||||
creation_context.GetDeviceInfo(), op_def, op_type, attr);
|
||||
*gpu_op = absl::make_unique<GPUOperation>(std::move(operation));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -51,11 +51,9 @@ void SelectLSTM(const OperationDef& op_def, const DeviceInfo& device_info,
|
||||
*ptr = absl::make_unique<LSTM>(std::move(operation));
|
||||
}
|
||||
|
||||
void SelectReLU(const CreationContext& creation_context,
|
||||
const ReLUAttributes& attr, const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr) {
|
||||
GPUOperation relu = CreateReLU(creation_context, op_def, attr);
|
||||
*ptr = absl::make_unique<GPUOperation>(std::move(relu));
|
||||
std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr,
|
||||
const OperationDef& op_def) {
|
||||
return absl::make_unique<GPUOperation>(CreateReLU(op_def, attr));
|
||||
}
|
||||
|
||||
absl::Status SelectPReLU(const PReLUAttributes& attr,
|
||||
@ -193,13 +191,10 @@ std::unique_ptr<GPUOperation> SelectWinograd36To4x4(
|
||||
CreateWinograd36To4x4(device_info, op_def, biases));
|
||||
}
|
||||
|
||||
void SelectQuantizeAndDequantize(const QuantizeAndDequantizeAttributes& attr,
|
||||
const CreationContext& creation_context,
|
||||
const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr) {
|
||||
GPUOperation operation =
|
||||
CreateQuantizeAndDequantize(creation_context, op_def, attr);
|
||||
*ptr = absl::make_unique<GPUOperation>(std::move(operation));
|
||||
std::unique_ptr<GPUOperation> SelectQuantizeAndDequantize(
|
||||
const QuantizeAndDequantizeAttributes& attr, const OperationDef& op_def) {
|
||||
return absl::make_unique<GPUOperation>(
|
||||
CreateQuantizeAndDequantize(op_def, attr));
|
||||
}
|
||||
|
||||
} // namespace cl
|
||||
|
@ -31,9 +31,8 @@ namespace cl {
|
||||
void SelectLSTM(const OperationDef& op_def, const DeviceInfo& device_info,
|
||||
std::unique_ptr<GPUOperation>* ptr);
|
||||
|
||||
void SelectReLU(const CreationContext& creation_context,
|
||||
const ReLUAttributes& attr, const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr);
|
||||
std::unique_ptr<GPUOperation> SelectReLU(const ReLUAttributes& attr,
|
||||
const OperationDef& op_def);
|
||||
|
||||
absl::Status SelectPReLU(const PReLUAttributes& attr,
|
||||
const CreationContext& creation_context,
|
||||
@ -93,10 +92,8 @@ std::unique_ptr<GPUOperation> SelectWinograd36To4x4(
|
||||
const DeviceInfo& device_info, const OperationDef& op_def,
|
||||
const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases);
|
||||
|
||||
void SelectQuantizeAndDequantize(const QuantizeAndDequantizeAttributes& attr,
|
||||
const CreationContext& creation_context,
|
||||
const OperationDef& op_def,
|
||||
std::unique_ptr<GPUOperation>* ptr);
|
||||
std::unique_ptr<GPUOperation> SelectQuantizeAndDequantize(
|
||||
const QuantizeAndDequantizeAttributes& attr, const OperationDef& op_def);
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
|
Loading…
x
Reference in New Issue
Block a user