Experimental feature: use special fused kernels for inference improvements. Adds a new hint (ModelHints::kAllowSpecialKernels) to control the behavior.
PiperOrigin-RevId: 322465118
Change-Id: I62c2a3ddc75907f2d9e455b7454e1de8c54a9881
parent 777bdc36a2
commit 2a150a026a
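In practice the feature is opt-in: a client sets the new hint on the create-info before building the inference context. A minimal sketch, mirroring the testing-sample change at the end of this commit (a fragment, not a complete program; an initialized cl::Environment `env` and the surrounding graph setup are assumed):

    // Sketch only; assumes an initialized cl::Environment `env`.
    InferenceContext::CreateInferenceInfo create_info;
    create_info.precision = CalculationsPrecision::F16;
    create_info.storage_type = GetFastestStorageType(env.device());
    // New in this commit: allow experimental fused "special" kernels.
    create_info.hints.Add(ModelHints::kAllowSpecialKernels);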
tensorflow/lite/delegates/gpu/cl/BUILD
@@ -366,6 +366,7 @@ cc_library(
         ":tensor_type",
         "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
         "//tensorflow/lite/delegates/gpu/cl/selectors:operation_selector",
+        "//tensorflow/lite/delegates/gpu/cl/selectors:special_selector",
         "//tensorflow/lite/delegates/gpu/common:data_type",
         "//tensorflow/lite/delegates/gpu/common:memory_management",
         "//tensorflow/lite/delegates/gpu/common:model",
tensorflow/lite/delegates/gpu/cl/inference_context.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/lite/delegates/gpu/cl/model_hints.h"
 #include "tensorflow/lite/delegates/gpu/cl/precision.h"
 #include "tensorflow/lite/delegates/gpu/cl/selectors/operation_selector.h"
+#include "tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h"
 #include "tensorflow/lite/delegates/gpu/cl/storage_type_util.h"
 #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
@@ -261,6 +262,12 @@ void InferenceContext::ReserveGraphTensors(
 absl::Status InferenceContext::ConvertOperations(
     const CreationContext& creation_context, const GraphFloat32& graph,
     ModelHints hints) {
+  std::map<ValueId, TensorDescriptor> tensor_descriptors;
+  const auto values = graph.values();
+  for (auto value : values) {
+    tensor_descriptors[value->id] = tensor_reserver_.Get(value->id).descriptor;
+  }
+  std::set<NodeId> consumed_nodes;
   std::vector<Node*> graph_nodes = graph.nodes();
   std::map<ValueId, int>
       tensor_usages;  // keeps latest index of operation that updated tensor
@@ -270,45 +277,54 @@ absl::Status InferenceContext::ConvertOperations(
   }
   for (int i = 0; i < graph_nodes.size(); ++i) {
     const Node& node = *graph_nodes[i];
-    auto inputs = graph.FindInputs(node.id);
-    auto outputs = graph.FindOutputs(node.id);
-
-    // Reordering of input ids and updating of temporary tensors_usage struct.
-    // This stage is necessary because we are building OperationDef that relies
-    // on the order of input ids. But we also should have the input id in first
-    // position that potentially can be a "linking" tensor and as a result
-    // eliminated (unused). We apply it only for the ADD operation, because of
-    // ADD associativity and because ADD can be linked.
-    // In the current approach the "linking" tensor can only be the latest
-    // written tensor (in linear order of execution) among input tensors.
-    if (IsGenericAdd(node, inputs, outputs)) {
-      int latest_written_tensor_index = 0;
-      int last_usage = tensor_usages[inputs[0]->id];
-      for (int j = 1; j < inputs.size(); ++j) {
-        if (tensor_usages[inputs[j]->id] > last_usage) {
-          last_usage = tensor_usages[inputs[j]->id];
-          latest_written_tensor_index = j;
-        }
-      }
-      std::swap(inputs[0], inputs[latest_written_tensor_index]);
-    }
-    for (const auto& out_id : outputs) {
-      tensor_usages[out_id->id] = i;
-    }
-
-    OperationDef op_def;
-    op_def.precision = precision_;
-    for (int j = 0; j < inputs.size(); ++j) {
-      op_def.src_tensors.push_back(
-          tensor_reserver_.Get(inputs[j]->id).descriptor);
-    }
-    for (int j = 0; j < outputs.size(); ++j) {
-      op_def.dst_tensors.push_back(
-          tensor_reserver_.Get(outputs[j]->id).descriptor);
+    if (consumed_nodes.find(node.id) != consumed_nodes.end()) {
+      continue;
     }
     GPUOperationsSubgraph gpu_subgraph;
-    RETURN_IF_ERROR(GPUOperationFromNode(creation_context, op_def, hints,
-                                         inputs, outputs, node, &gpu_subgraph));
+    if (hints.Check(ModelHints::kAllowSpecialKernels) &&
+        GPUSubgraphFromGraph(creation_context, precision_, graph, node.id,
+                             tensor_descriptors, &consumed_nodes, &gpu_subgraph)
+            .ok()) {
+      // Mapping of a subgraph (a set of nodes) to GPU operations. Should
+      // happen before the straightforward mapping.
+    } else {
+      // Straightforward mapping of one graph node to GPU operations.
+      auto inputs = graph.FindInputs(node.id);
+      auto outputs = graph.FindOutputs(node.id);
+      // Reordering of input ids and updating of temporary tensors_usage
+      // struct. This stage is necessary because we are building OperationDef
+      // that relies on the order of input ids. But we also should have the
+      // input id in first position that potentially can be a "linking" tensor
+      // and as a result eliminated (unused). We apply it only for the ADD
+      // operation, because of ADD associativity and because ADD can be
+      // linked. In the current approach the "linking" tensor can only be the
+      // latest written tensor (in linear order of execution) among input
+      // tensors.
+      if (IsGenericAdd(node, inputs, outputs)) {
+        int latest_written_tensor_index = 0;
+        int last_usage = tensor_usages[inputs[0]->id];
+        for (int j = 1; j < inputs.size(); ++j) {
+          if (tensor_usages[inputs[j]->id] > last_usage) {
+            last_usage = tensor_usages[inputs[j]->id];
+            latest_written_tensor_index = j;
+          }
+        }
+        std::swap(inputs[0], inputs[latest_written_tensor_index]);
+      }
+      consumed_nodes.insert(node.id);
+      OperationDef op_def;
+      op_def.precision = precision_;
+      for (int j = 0; j < inputs.size(); ++j) {
+        op_def.src_tensors.push_back(
+            tensor_reserver_.Get(inputs[j]->id).descriptor);
+      }
+      for (int j = 0; j < outputs.size(); ++j) {
+        op_def.dst_tensors.push_back(
+            tensor_reserver_.Get(outputs[j]->id).descriptor);
+      }
+      RETURN_IF_ERROR(GPUOperationFromNode(creation_context, op_def, hints,
+                                           inputs, outputs, node,
+                                           &gpu_subgraph));
+    }
     std::unordered_map<int, ValueId> mapping_to_global_ids;
     for (int j = 0; j < gpu_subgraph.new_tensors.size(); ++j) {
       const auto& t = gpu_subgraph.new_tensors[j];
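The ADD input-reordering kept in the else-branch above deserves a worked example: among the inputs of a generic ADD, the tensor most recently written (highest tensor_usages index) is swapped to position 0 so it can serve as the elidable "linking" tensor. A standalone sketch of just that selection, using plain ints with hypothetical ids and usage indices in place of graph values:

    #include <map>
    #include <utility>
    #include <vector>

    int main() {
      // Hypothetical data: tensor id -> index of the last op that wrote it.
      std::map<int, int> tensor_usages = {{7, 0}, {9, 3}, {4, 1}};
      std::vector<int> input_ids = {7, 9, 4};  // inputs of a generic ADD

      // Same scan as in ConvertOperations: find the latest-written input...
      int latest_written_tensor_index = 0;
      int last_usage = tensor_usages[input_ids[0]];
      for (size_t j = 1; j < input_ids.size(); ++j) {
        if (tensor_usages[input_ids[j]] > last_usage) {
          last_usage = tensor_usages[input_ids[j]];
          latest_written_tensor_index = static_cast<int>(j);
        }
      }
      // ...and swap it into first position: input_ids becomes {9, 7, 4},
      // because tensor 9 was written last (by op 3).
      std::swap(input_ids[0], input_ids[latest_written_tensor_index]);
    }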
@@ -324,7 +340,7 @@ absl::Status InferenceContext::ConvertOperations(
     for (int j = 0; j < gpu_op.input_ids.size(); ++j) {
       int id = gpu_op.input_ids[j];
       if (id >= 0) {
-        cl_node.inputs[j] = inputs[id]->id;
+        cl_node.inputs[j] = id;
       } else {
         cl_node.inputs[j] = mapping_to_global_ids[-(id + 1)];
       }
@@ -333,7 +349,8 @@ absl::Status InferenceContext::ConvertOperations(
     for (int j = 0; j < gpu_op.output_ids.size(); ++j) {
       int id = gpu_op.output_ids[j];
       if (id >= 0) {
-        cl_node.outputs[j] = outputs[id]->id;
+        cl_node.outputs[j] = id;
+        tensor_usages[id] = i;
       } else {
         cl_node.outputs[j] = mapping_to_global_ids[-(id + 1)];
       }
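A note on the id convention in these two hunks: after this change, non-negative entries in gpu_op.input_ids/output_ids are already global ValueIds (hence `cl_node.inputs[j] = id;`), while negative entries refer to tensors created inside the subgraph and are resolved through mapping_to_global_ids via -(id + 1). A small self-contained sketch of that resolution, with hypothetical ids:

    #include <iostream>
    #include <unordered_map>
    #include <vector>

    int main() {
      // Hypothetical: new-tensor index -> global id assigned for it.
      std::unordered_map<int, int> mapping_to_global_ids = {{0, 100}, {1, 101}};
      // One global tensor id (42) and two subgraph-internal tensors (-1, -2).
      std::vector<int> ids = {42, -1, -2};

      for (int id : ids) {
        int global = (id >= 0) ? id : mapping_to_global_ids[-(id + 1)];
        std::cout << global << " ";  // prints: 42 100 101
      }
      std::cout << std::endl;
    }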
tensorflow/lite/delegates/gpu/cl/model_hints.h
@@ -25,13 +25,18 @@ namespace cl {
 struct ModelHints {
   using ModelHint = uint64_t;
 
-  // By default we want the fastest inference
+  // By default we want the fastest inference.
   static constexpr ModelHint kFastestInference = 0x00000000;
-  // Can improve compilation time, but inference can be slower
+  // Can improve compilation time, but inference can be slower.
   static constexpr ModelHint kReduceKernelsCount = 0x00000001;
-  // Can improve tuning time, but inference can be slower
+  // Can improve tuning time, but inference can be slower.
   static constexpr ModelHint kFastTuning = 0x00000002;
 
+  // Experimental.
+  // Can improve performance and memory consumption, but can slow down
+  // initialization a lot and create more kernels.
+  static constexpr ModelHint kAllowSpecialKernels = 0x00000004;
+
   void Add(ModelHint hint) {
     if (hint == kFastestInference) {
       hints = kFastestInference;
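Hints form a uint64_t bitmask: Add() folds a flag into the mask (the diff shows the zero-valued kFastestInference resetting it; the OR for non-zero flags is assumed here, as the hunk is cut off), and Check() — used by ConvertOperations above — tests a bit. A minimal runnable re-sketch of those semantics:

    #include <cstdint>
    #include <iostream>

    // Re-sketch of ModelHints for illustration; OR semantics in Add() and the
    // body of Check() are assumptions consistent with the hunks above.
    struct ModelHints {
      using ModelHint = uint64_t;
      static constexpr ModelHint kFastestInference = 0x00000000;
      static constexpr ModelHint kReduceKernelsCount = 0x00000001;
      static constexpr ModelHint kFastTuning = 0x00000002;
      static constexpr ModelHint kAllowSpecialKernels = 0x00000004;

      void Add(ModelHint hint) {
        if (hint == kFastestInference) {
          hints = kFastestInference;  // the zero flag resets to the default
        } else {
          hints |= hint;  // assumed: other flags accumulate
        }
      }
      bool Check(ModelHint hint) const { return (hints & hint) != 0; }

      ModelHint hints = kFastestInference;
    };

    int main() {
      ModelHints h;
      h.Add(ModelHints::kAllowSpecialKernels);
      std::cout << h.Check(ModelHints::kAllowSpecialKernels) << "\n";  // 1
      std::cout << h.Check(ModelHints::kFastTuning) << "\n";           // 0
    }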
tensorflow/lite/delegates/gpu/cl/selectors/BUILD
@@ -152,6 +152,26 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "special_selector",
+    srcs = ["special_selector.cc"],
+    hdrs = ["special_selector.h"],
+    deps = [
+        ":subgraph",
+        "//tensorflow/lite/delegates/gpu/cl:cl_device",
+        "//tensorflow/lite/delegates/gpu/cl:tensor_type",
+        "//tensorflow/lite/delegates/gpu/cl/kernels:gpu_operation",
+        "//tensorflow/lite/delegates/gpu/cl/kernels/special:depthwise_conv_plus_1x1_conv",
+        "//tensorflow/lite/delegates/gpu/common:data_type",
+        "//tensorflow/lite/delegates/gpu/common:model",
+        "//tensorflow/lite/delegates/gpu/common:operations",
+        "//tensorflow/lite/delegates/gpu/common:shape",
+        "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common:tensor",
+        "@com_google_absl//absl/types:any",
+    ],
+)
+
 cc_library(
     name = "subgraph",
     srcs = ["subgraph.cc"],
tensorflow/lite/delegates/gpu/cl/selectors/special_selector.cc (new file, 111 lines)
@@ -0,0 +1,111 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h"
+
+#include "absl/types/any.h"
+#include "tensorflow/lite/delegates/gpu/cl/cl_device.h"
+#include "tensorflow/lite/delegates/gpu/cl/kernels/special/depthwise_conv_plus_1x1_conv.h"
+#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
+#include "tensorflow/lite/delegates/gpu/common/data_type.h"
+#include "tensorflow/lite/delegates/gpu/common/operations.h"
+#include "tensorflow/lite/delegates/gpu/common/shape.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/tensor.h"
+
+namespace tflite {
+namespace gpu {
+namespace cl {
+namespace {
+absl::Status TryDepthwiseConvPlus1x1Conv(
+    const CreationContext& creation_context, CalculationsPrecision precision,
+    const GraphFloat32& graph, NodeId first_node_id,
+    const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
+    std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) {
+  auto* dw_node = graph.GetNode(first_node_id);
+  if (OperationTypeFromString(dw_node->operation.type) !=
+      OperationType::DEPTHWISE_CONVOLUTION) {
+    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
+  }
+  auto dw_outputs = graph.FindOutputs(dw_node->id);
+  auto consumers = graph.FindConsumers(dw_outputs[0]->id);
+  if (consumers.size() != 1) {
+    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
+  }
+  auto* conv_node = consumers[0];
+  if (consumed_nodes->find(conv_node->id) != consumed_nodes->end()) {
+    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
+  }
+  if (OperationTypeFromString(conv_node->operation.type) !=
+      OperationType::CONVOLUTION_2D) {
+    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
+  }
+  if (graph.FindInputs(conv_node->id).size() != 1) {
+    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
+  }
+  auto dw_attr = absl::any_cast<DepthwiseConvolution2DAttributes>(
+      dw_node->operation.attributes);
+  auto conv_attr =
+      absl::any_cast<Convolution2DAttributes>(conv_node->operation.attributes);
+  auto dw_inputs = graph.FindInputs(dw_node->id);
+  auto conv_outputs = graph.FindOutputs(conv_node->id);
+  OperationDef op_def;
+  op_def.precision = precision;
+  auto it = tensor_descriptors.find(dw_inputs[0]->id);
+  if (it != tensor_descriptors.end()) {
+    op_def.src_tensors.push_back(it->second);
+  }
+  it = tensor_descriptors.find(conv_outputs[0]->id);
+  if (it != tensor_descriptors.end()) {
+    op_def.dst_tensors.push_back(it->second);
+  }
+  if (!IsDepthwiseConvPlus1x1ConvSupported(*creation_context.device, op_def,
+                                           dw_attr, conv_attr)) {
+    return absl::NotFoundError("DepthwiseConvPlus1x1Conv not suitable.");
+  }
+  std::unique_ptr<GPUOperation>* gpu_op =
+      InitSingleOpSubgraph(dw_inputs, conv_outputs, gpu_subgraph);
+  DepthwiseConvPlus1x1Conv operation;
+  RETURN_IF_ERROR(CreateDepthwiseConvPlus1x1Conv(
+      creation_context, op_def, dw_attr, conv_attr, &operation));
+  *gpu_op = absl::make_unique<DepthwiseConvPlus1x1Conv>(std::move(operation));
+  consumed_nodes->insert(dw_node->id);
+  consumed_nodes->insert(conv_node->id);
+  return absl::OkStatus();
+}
+}  // namespace
+
+absl::Status GPUSubgraphFromGraph(
+    const CreationContext& creation_context, CalculationsPrecision precision,
+    const GraphFloat32& graph, NodeId first_node_id,
+    const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
+    std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph) {
+  if (!creation_context.device->IsNvidia()) {
+    return absl::NotFoundError(
+        "Experimental feature, enabled for NVidia only, but device is not "
+        "nvidia gpu.");
+  }
+  if (TryDepthwiseConvPlus1x1Conv(creation_context, precision, graph,
+                                  first_node_id, tensor_descriptors,
+                                  consumed_nodes, gpu_subgraph)
+          .ok()) {
+    return absl::OkStatus();
+  }
+  return absl::NotFoundError("No special combination.");
+}
+
+}  // namespace cl
+}  // namespace gpu
+}  // namespace tflite
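The selector follows a try-and-fall-back protocol: a matcher returns absl::NotFoundError when its pattern does not apply, and the caller (GPUSubgraphFromGraph here, ConvertOperations above) treats any non-ok status purely as "use the generic path". A minimal self-contained sketch of that convention with a hypothetical matcher (assumes Abseil's status library is linked):

    #include <iostream>
    #include "absl/status/status.h"

    // Hypothetical matcher following the convention above: NotFoundError
    // means "pattern not applicable here", not a hard failure.
    absl::Status TryFusePattern(bool pattern_matches) {
      if (!pattern_matches) {
        return absl::NotFoundError("Pattern not suitable.");
      }
      // ... build the fused GPU operation here ...
      return absl::OkStatus();
    }

    int main() {
      // The caller only commits when .ok(); otherwise it falls back.
      if (TryFusePattern(/*pattern_matches=*/false).ok()) {
        std::cout << "used fused kernel\n";
      } else {
        std::cout << "fell back to one-node mapping\n";
      }
    }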
tensorflow/lite/delegates/gpu/cl/selectors/special_selector.h (new file, 43 lines)
@@ -0,0 +1,43 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SPECIAL_SELECTOR_H_
+#define TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SPECIAL_SELECTOR_H_
+
+#include <map>
+#include <set>
+#include <vector>
+
+#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
+#include "tensorflow/lite/delegates/gpu/cl/selectors/subgraph.h"
+#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
+#include "tensorflow/lite/delegates/gpu/common/model.h"
+#include "tensorflow/lite/delegates/gpu/common/status.h"
+
+namespace tflite {
+namespace gpu {
+namespace cl {
+
+absl::Status GPUSubgraphFromGraph(
+    const CreationContext& creation_context, CalculationsPrecision precision,
+    const GraphFloat32& graph, NodeId first_node_id,
+    const std::map<ValueId, TensorDescriptor>& tensor_descriptors,
+    std::set<NodeId>* consumed_nodes, GPUOperationsSubgraph* gpu_subgraph);
+
+}  // namespace cl
+}  // namespace gpu
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_DELEGATES_GPU_CL_SELECTORS_SPECIAL_SELECTOR_H_
tensorflow/lite/delegates/gpu/cl/selectors/subgraph.cc
@@ -32,10 +32,10 @@ std::unique_ptr<GPUOperation>* InitSingleOpSubgraph(
   gpu_subgraph->new_tensors.clear();
   gpu_subgraph->operations.push_back({});
   for (int i = 0; i < inputs.size(); ++i) {
-    gpu_subgraph->operations[0].input_ids.push_back(i);
+    gpu_subgraph->operations[0].input_ids.push_back(inputs[i]->id);
   }
   for (int i = 0; i < outputs.size(); ++i) {
-    gpu_subgraph->operations[0].output_ids.push_back(i);
+    gpu_subgraph->operations[0].output_ids.push_back(outputs[i]->id);
   }
 
   return &gpu_subgraph->operations[0].operation;
tensorflow/lite/delegates/gpu/cl/testing/performance_profiling.cc
@@ -44,6 +44,7 @@ absl::Status RunModelSample(const std::string& model_name) {
           ? CalculationsPrecision::F16
           : CalculationsPrecision::F32;
   create_info.storage_type = GetFastestStorageType(env.device());
+  create_info.hints.Add(ModelHints::kAllowSpecialKernels);
   std::cout << "Precision: " << ToString(create_info.precision) << std::endl;
   std::cout << "Storage type: " << ToString(create_info.storage_type)
             << std::endl;