Merging of nodes done as in OpenCL inference context.
PiperOrigin-RevId: 351217823 Change-Id: I00e795dc5bd0f2382e521799501ed23e2530a89a
This commit is contained in:
parent
ae83b4fa29
commit
6e2d5e6c55
@ -104,7 +104,6 @@ using ComputeTaskDescriptorPtr = std::shared_ptr<ComputeTaskDescriptor>;
|
||||
|
||||
struct NodeDescriptor {
|
||||
ComputeTaskDescriptorPtr task;
|
||||
// Unique ID to match the graph compilation errors.
|
||||
int id;
|
||||
std::string description;
|
||||
std::vector<ValueId> src_tensors_ids;
|
||||
|
@ -38,6 +38,17 @@ namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
namespace {
|
||||
|
||||
bool IsReady(const std::set<ValueId>& ready_tensors,
|
||||
const NodeDescriptor& node) {
|
||||
for (const ValueId in_id : node.src_tensors_ids) {
|
||||
if (ready_tensors.find(in_id) == ready_tensors.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void AddUsage(ValueId id, int task_index,
|
||||
std::map<ValueId, int2>* usage_records) {
|
||||
auto it = usage_records->find(id);
|
||||
@ -82,346 +93,6 @@ bool IsGenericAdd(const Node& node, const std::vector<Value*>& inputs,
|
||||
return true;
|
||||
}
|
||||
|
||||
// Allows to get result about the graph compilation to validate graph. This
|
||||
// information helps to find a cause of performance degradation, like misfusing.
|
||||
struct OptimizationInfo {
|
||||
// Initial operations count before compilation.
|
||||
int operations_count;
|
||||
// GPU tasks count after fusion and splitting complex operations into few GPU
|
||||
// subtasks.
|
||||
int gpu_tasks_count;
|
||||
// Some operations are not used due to dependencies of the graph.
|
||||
std::vector<int> unused_operations;
|
||||
// Used inputs.
|
||||
std::vector<ValueId> input_buffer_ids;
|
||||
// Unused inputs. Requested outputs do not require this inputs to be used.
|
||||
std::vector<ValueId> unused_input_buffer_ids;
|
||||
// The outputs are deducted by the graph but not requested by user.
|
||||
std::vector<ValueId> extra_output_buffer_ids;
|
||||
// Outputs that are requested but can't be calculated by the graph.
|
||||
std::vector<ValueId> missing_output_buffer_ids;
|
||||
};
|
||||
|
||||
using FusionSequence = std::vector<NodeDescriptor>;
|
||||
|
||||
bool Contains(const std::vector<ValueId>& container, ValueId value) {
|
||||
return std::find(container.begin(), container.end(), value) !=
|
||||
container.end();
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool Contains(const std::vector<T>& container, ValueId value) {
|
||||
for (const auto& buffer : container) {
|
||||
if (buffer.id == value) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Checks if all elements of the narrow vector exist in the wide vector. Vectors
|
||||
// are expected to be unsorted.
|
||||
bool Contains(const std::vector<ValueId>& wide,
|
||||
const std::vector<ValueId>& narrow) {
|
||||
if (narrow.empty() || narrow.size() > wide.size()) {
|
||||
return false;
|
||||
}
|
||||
std::set<ValueId> wide_sorted;
|
||||
wide_sorted.insert(wide.begin(), wide.end());
|
||||
for (auto element : narrow) {
|
||||
if (std::find(wide.begin(), wide.end(), element) == wide.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
uint32_t BufferUseCount(ValueId id,
|
||||
const std::list<NodeDescriptor>& descriptors,
|
||||
std::list<FusionSequence>* chains) {
|
||||
uint32_t use_count = 0;
|
||||
// Buffer may be read by both processed and not processed operations.
|
||||
for (auto& desc : descriptors) {
|
||||
if (Contains(desc.src_tensors_ids, id)) {
|
||||
use_count++;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& chain : *chains) {
|
||||
for (auto& desc : chain) {
|
||||
if (Contains(desc.src_tensors_ids, id)) {
|
||||
use_count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return use_count;
|
||||
}
|
||||
|
||||
// Examines if the second operation can be linked to the first one. Linking may
|
||||
// be skipped in the situation when conflict may happen: if first operation's
|
||||
// output is used by more than 1 other operation.
|
||||
bool CanFuseOperations(const NodeDescriptor& first,
|
||||
const NodeDescriptor& second,
|
||||
const std::vector<ValueId>& output_ids,
|
||||
const std::list<NodeDescriptor>& descriptors,
|
||||
std::list<FusionSequence>* chains) {
|
||||
return second.task->is_linkable &&
|
||||
!Contains(output_ids, first.dst_tensors_ids[0]) &&
|
||||
BufferUseCount(first.dst_tensors_ids[0], descriptors, chains) == 1;
|
||||
}
|
||||
|
||||
// Takes an unsorted list of task descriptors, builds a list of chains. Each
|
||||
// chain is a list of task descriptors that can be fused into a single GPU task.
|
||||
// Building is started from the input IDs and building statistic is filled.
|
||||
void BuildFusableChains(const std::vector<ValueId>& input_ids,
|
||||
const std::vector<ValueId>& output_ids,
|
||||
std::list<NodeDescriptor>* descriptors,
|
||||
std::list<FusionSequence>* chains,
|
||||
std::vector<int>* unused_ids) {
|
||||
// Proxy tasks for inputs - only output is valid on this elements.
|
||||
for (auto input_id : input_ids) {
|
||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||
desc->is_linkable = true;
|
||||
desc->AddDstTensor("", {});
|
||||
NodeDescriptor node;
|
||||
node.task = desc;
|
||||
node.dst_tensors_ids = {input_id};
|
||||
chains->push_back({node});
|
||||
}
|
||||
|
||||
if (descriptors->empty()) return;
|
||||
// Get all possible operations - grow-up chains.
|
||||
bool added;
|
||||
do {
|
||||
// At least one element must be added to any chain at this step.
|
||||
added = false;
|
||||
for (auto it = descriptors->begin(); it != descriptors->end();) {
|
||||
const NodeDescriptor& task_descriptor = *it;
|
||||
|
||||
// Gather all outputs of all chains to check with.
|
||||
std::vector<ValueId> ready_buffer_ids;
|
||||
ready_buffer_ids.reserve(chains->size());
|
||||
for (const auto& chain : *chains) {
|
||||
ready_buffer_ids.push_back(chain.back().dst_tensors_ids[0]);
|
||||
}
|
||||
|
||||
// Check if all inputs of this operation are ready.
|
||||
if (Contains(ready_buffer_ids, task_descriptor.src_tensors_ids)) {
|
||||
// Now find a chain to fuse with.
|
||||
bool fused = false;
|
||||
for (auto& chain : *chains) {
|
||||
// We can fuse only single output for now.
|
||||
const bool can_link = task_descriptor.src_tensors_ids[0] ==
|
||||
chain.back().dst_tensors_ids[0];
|
||||
if (can_link && CanFuseOperations(chain.back(), task_descriptor,
|
||||
output_ids, *descriptors, chains)) {
|
||||
chain.push_back(task_descriptor);
|
||||
fused = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!fused) {
|
||||
chains->push_back({task_descriptor});
|
||||
}
|
||||
|
||||
// Remove operation from original list and start from the beginning.
|
||||
descriptors->erase(it);
|
||||
added = true;
|
||||
break;
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
} while (!descriptors->empty() && added);
|
||||
|
||||
unused_ids->reserve(descriptors->size());
|
||||
for (const auto& desc : *descriptors) {
|
||||
unused_ids->push_back(desc.id);
|
||||
}
|
||||
}
|
||||
|
||||
// Accepts unsorted list of chains and returns sorted list with the order of GPU
|
||||
// task execution.
|
||||
std::list<FusionSequence> SortChains(
|
||||
const std::vector<ValueId>& graph_input_ids,
|
||||
std::list<FusionSequence>* chains) {
|
||||
std::list<FusionSequence> sorted_chains;
|
||||
while (!chains->empty()) {
|
||||
// Collect ready buffers.
|
||||
std::vector<ValueId> ready_buffer_ids;
|
||||
ready_buffer_ids.reserve(graph_input_ids.size() + sorted_chains.size());
|
||||
ready_buffer_ids.insert(ready_buffer_ids.begin(), graph_input_ids.begin(),
|
||||
graph_input_ids.end());
|
||||
for (auto& chain : sorted_chains) {
|
||||
ready_buffer_ids.push_back(chain.back().dst_tensors_ids[0]);
|
||||
}
|
||||
|
||||
for (auto it = chains->begin(); it != chains->end();) {
|
||||
const FusionSequence& chain = *it;
|
||||
|
||||
// If the input is also is the output in the same chain - eliminate
|
||||
// because it used internally inside sthis chain only.
|
||||
std::vector<ValueId> elements_output_buffer_ids;
|
||||
elements_output_buffer_ids.reserve(chain.size());
|
||||
for (const auto& element : chain) {
|
||||
elements_output_buffer_ids.push_back(element.dst_tensors_ids[0]);
|
||||
}
|
||||
|
||||
// Collect all inputs also for linked operations.
|
||||
std::vector<ValueId> elements_input_buffer_ids;
|
||||
for (const auto& element : chain) {
|
||||
for (const auto& id : element.src_tensors_ids) {
|
||||
if (!Contains(elements_output_buffer_ids, id)) {
|
||||
elements_input_buffer_ids.push_back(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (Contains(ready_buffer_ids, elements_input_buffer_ids)) {
|
||||
// All input buffers for all elements of this chain are ready.
|
||||
sorted_chains.push_back(chain);
|
||||
it = chains->erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
return sorted_chains;
|
||||
}
|
||||
|
||||
// If a graph structure contains unused outputs then it can lead to unused
|
||||
// operations and unused input buffers. It's not an error but some sort of
|
||||
// warning.
|
||||
std::vector<ValueId> GetUsedInputBufferIds(
|
||||
const std::list<FusionSequence>& sorted_chains) {
|
||||
// Match requested outputs with all outputs and intermediate buffers.
|
||||
std::vector<ValueId> output_and_intermediate_ids;
|
||||
output_and_intermediate_ids.reserve(sorted_chains.size());
|
||||
std::set<ValueId> input_and_intermediate_ids;
|
||||
for (auto it = sorted_chains.begin(); it != sorted_chains.end(); ++it) {
|
||||
output_and_intermediate_ids.push_back(it->back().dst_tensors_ids[0]);
|
||||
for (const auto& id : it->front().src_tensors_ids) {
|
||||
input_and_intermediate_ids.insert(id);
|
||||
}
|
||||
}
|
||||
std::vector<ValueId> input_ids;
|
||||
for (ValueId id : input_and_intermediate_ids) {
|
||||
if (!Contains(output_and_intermediate_ids, id)) {
|
||||
input_ids.push_back(id);
|
||||
}
|
||||
}
|
||||
return input_ids;
|
||||
}
|
||||
|
||||
// If a buffer is requested as output from the graph but the graph structure
|
||||
// can't provide this buffer by output (can't deduct), that means the graph
|
||||
// structure is incorrect.
|
||||
std::vector<ValueId> GetMissingOutputBufferIds(
|
||||
const std::vector<ValueId>& output_ids,
|
||||
const std::list<FusionSequence>& sorted_chains) {
|
||||
// Match requested outputs with all output and intermediate buffers.
|
||||
std::vector<ValueId> output_and_intermediate_ids;
|
||||
output_and_intermediate_ids.reserve(sorted_chains.size());
|
||||
for (auto it = sorted_chains.begin(); it != sorted_chains.end(); ++it) {
|
||||
output_and_intermediate_ids.push_back(it->back().dst_tensors_ids[0]);
|
||||
}
|
||||
std::vector<ValueId> missing_output_ids;
|
||||
for (ValueId id : output_ids) {
|
||||
if (!Contains(output_and_intermediate_ids, id)) {
|
||||
missing_output_ids.push_back(id);
|
||||
}
|
||||
}
|
||||
return missing_output_ids;
|
||||
}
|
||||
|
||||
// Graph may contain leafs with outputs that are not requested. It wastes GPU
|
||||
// computations.
|
||||
std::vector<ValueId> DeductOutputBufferIds(
|
||||
const std::vector<ValueId>& output_ids,
|
||||
const std::list<FusionSequence>& sorted_chains) {
|
||||
std::vector<ValueId> extra_output_ids;
|
||||
// Detect all unused output buffers - all outputs.
|
||||
for (auto it1 = sorted_chains.begin(); it1 != sorted_chains.end(); ++it1) {
|
||||
bool found_as_input = false;
|
||||
for (auto it2 = sorted_chains.begin(); it2 != sorted_chains.end(); ++it2) {
|
||||
if (it1 != it2) {
|
||||
std::vector<ValueId> input_ids;
|
||||
for (const auto& element : *it2) {
|
||||
for (const auto& id : element.src_tensors_ids) {
|
||||
input_ids.push_back(id);
|
||||
}
|
||||
}
|
||||
if (Contains(input_ids, it1->back().dst_tensors_ids[0])) {
|
||||
found_as_input = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!found_as_input) {
|
||||
if (!Contains(output_ids, it1->back().dst_tensors_ids[0])) {
|
||||
extra_output_ids.push_back(it1->back().dst_tensors_ids[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return extra_output_ids;
|
||||
}
|
||||
|
||||
// Delete all unused task descriptors that have non-requested outputs.
|
||||
// !delete not the whole chain but only the last element, then others.!
|
||||
std::vector<int> DeleteUnusedTasks(const std::vector<ValueId>& output_ids,
|
||||
std::list<FusionSequence>* chains) {
|
||||
std::vector<int> unused_operations;
|
||||
for (auto it1 = chains->rbegin(); it1 != chains->rend();) {
|
||||
// Don't delete if output is requested.
|
||||
if (Contains(output_ids, it1->back().dst_tensors_ids[0])) {
|
||||
++it1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Don't delete if some operation uses the output.
|
||||
bool output_used = false;
|
||||
for (auto it2 = chains->rbegin(); it2 != chains->rend(); ++it2) {
|
||||
std::vector<ValueId> input_ids;
|
||||
for (const auto& element : *it2) {
|
||||
for (const auto& id : element.src_tensors_ids) {
|
||||
input_ids.push_back(id);
|
||||
}
|
||||
}
|
||||
if (Contains(input_ids, it1->back().dst_tensors_ids[0])) {
|
||||
output_used = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (output_used) {
|
||||
++it1;
|
||||
continue;
|
||||
}
|
||||
// Delete if not used.
|
||||
unused_operations.push_back(it1->back().id);
|
||||
it1 = decltype(it1){chains->erase(std::next(it1).base())};
|
||||
}
|
||||
return unused_operations;
|
||||
}
|
||||
|
||||
// Returns unused input buffer IDs.
|
||||
void RemoveInputProxies(std::list<FusionSequence>* chains) {
|
||||
// Remove input proxy and sort items.
|
||||
for (auto it = chains->begin(); it != chains->end();) {
|
||||
auto& chain = *it;
|
||||
// Remove input proxy-operations.
|
||||
if (chain.front().src_tensors_ids.empty()) {
|
||||
chain.erase(chain.begin());
|
||||
}
|
||||
if (chain.empty()) {
|
||||
// Input proxy operation has been deleted and the chain is empty due to
|
||||
// unused input buffer.
|
||||
it = chains->erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
absl::Status MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst) {
|
||||
for (int j = 1; j < src->src_tensors_ids.size(); ++j) {
|
||||
dst->src_tensors_ids.push_back(src->src_tensors_ids[j]);
|
||||
@ -430,14 +101,6 @@ absl::Status MergeNodes(const NodeDescriptor* src, NodeDescriptor* dst) {
|
||||
dst->description += " linked : " + src->description;
|
||||
return dst->task->AddTask(src->task.get());
|
||||
}
|
||||
|
||||
absl::Status FuseChain(const FusionSequence& chain, NodeDescriptor* node_desc) {
|
||||
*node_desc = chain.front();
|
||||
for (int j = 1; j < chain.size(); ++j) {
|
||||
RETURN_IF_ERROR(MergeNodes(&chain[j], node_desc));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
absl::Status InferenceContext::InitFromGraph(
|
||||
@ -459,11 +122,8 @@ absl::Status InferenceContext::InitFromGraph(
|
||||
CompiledModel compiled_model;
|
||||
RETURN_IF_ERROR(
|
||||
Compile(graph, gpu_info, create_info.precision, &compiled_model));
|
||||
CompiledModel optimized_model;
|
||||
RETURN_IF_ERROR(ValidateOptimizeModel(input_ids_, output_ids_, compiled_model,
|
||||
&optimized_model));
|
||||
|
||||
RETURN_IF_ERROR(CompileModelWithDevice(device, optimized_model, input_ids_,
|
||||
RETURN_IF_ERROR(Merge(&compiled_model));
|
||||
RETURN_IF_ERROR(CompileModelWithDevice(device, compiled_model, input_ids_,
|
||||
output_ids_, create_info.precision));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
@ -500,18 +160,17 @@ absl::Status InferenceContext::Compile(const GraphFloat32& graph,
|
||||
// associativity and ADD can be linked. In current approach "linking"
|
||||
// tensor can be only latest written tensor(during linear order of
|
||||
// execution) among input tensors.
|
||||
// TODO(b/176397043) sorokin, check failure on segmentation model
|
||||
// if (IsGenericAdd(node, inputs, outputs)) {
|
||||
// int latest_written_tensor_index = 0;
|
||||
// int last_usage = tensor_usages[inputs[0]->id];
|
||||
// for (int j = 1; j < inputs.size(); ++j) {
|
||||
// if (tensor_usages[inputs[j]->id] > last_usage) {
|
||||
// last_usage = tensor_usages[inputs[j]->id];
|
||||
// latest_written_tensor_index = j;
|
||||
// }
|
||||
// }
|
||||
// std::swap(inputs[0], inputs[latest_written_tensor_index]);
|
||||
// }
|
||||
if (IsGenericAdd(node, inputs, outputs)) {
|
||||
int latest_written_tensor_index = 0;
|
||||
int last_usage = tensor_usages[inputs[0]->id];
|
||||
for (int j = 1; j < inputs.size(); ++j) {
|
||||
if (tensor_usages[inputs[j]->id] > last_usage) {
|
||||
last_usage = tensor_usages[inputs[j]->id];
|
||||
latest_written_tensor_index = j;
|
||||
}
|
||||
}
|
||||
std::swap(inputs[0], inputs[latest_written_tensor_index]);
|
||||
}
|
||||
DataType data_type = DeduceDataTypeFromPrecision(precision);
|
||||
TensorDescriptor tensor_descriptor =
|
||||
TensorDescriptor{data_type, TensorStorageType::BUFFER, Layout::HWC};
|
||||
@ -557,63 +216,53 @@ absl::Status InferenceContext::Compile(const GraphFloat32& graph,
|
||||
}
|
||||
metal_node.description =
|
||||
node.operation.type + " " + std::to_string(node.id);
|
||||
metal_node.id = i;
|
||||
compiled_model->nodes.push_back(std::move(metal_node));
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceContext::ValidateOptimizeModel(
|
||||
const std::vector<ValueId>& input_buffers,
|
||||
const std::vector<ValueId>& output_buffers,
|
||||
const CompiledModel& input_model, CompiledModel* output_model) {
|
||||
std::list<NodeDescriptor> input;
|
||||
input.insert(input.end(), input_model.nodes.begin(), input_model.nodes.end());
|
||||
OptimizationInfo info;
|
||||
info.operations_count = static_cast<int>(input.size());
|
||||
|
||||
// A chain is a sequence of fusable operations. All internal outputs are
|
||||
// consumed with the next element of the chain. The last element of each chain
|
||||
// contains outputs which are ready to be used as inputs. if a chain can't be
|
||||
// extended with linkable element then new chain is created.
|
||||
std::list<FusionSequence> unsorted_chains;
|
||||
BuildFusableChains(input_buffers, output_buffers, &input, &unsorted_chains,
|
||||
&info.unused_operations);
|
||||
|
||||
RemoveInputProxies(&unsorted_chains);
|
||||
std::list<FusionSequence> sorted_chains =
|
||||
SortChains(input_buffers, &unsorted_chains);
|
||||
|
||||
info.extra_output_buffer_ids =
|
||||
DeductOutputBufferIds(output_buffers, sorted_chains);
|
||||
info.unused_operations = DeleteUnusedTasks(output_buffers, &sorted_chains);
|
||||
info.input_buffer_ids = GetUsedInputBufferIds(sorted_chains);
|
||||
// find provided input buffers that has not being used
|
||||
for (ValueId id : input_buffers) {
|
||||
if (!Contains(info.input_buffer_ids, id)) {
|
||||
info.unused_input_buffer_ids.push_back(id);
|
||||
absl::Status InferenceContext::Merge(CompiledModel* model) {
|
||||
std::set<ValueId> ready_tensors;
|
||||
for (const auto& input_id : input_ids_) {
|
||||
ready_tensors.insert(input_id);
|
||||
}
|
||||
for (int i = 0; i < model->nodes.size(); ++i) {
|
||||
auto& node = model->nodes[i];
|
||||
for (const auto& out_id : node.dst_tensors_ids) {
|
||||
ready_tensors.insert(out_id);
|
||||
}
|
||||
if (node.dst_tensors_ids.size() != 1) {
|
||||
continue;
|
||||
}
|
||||
std::vector<int> next_nodes;
|
||||
int link_index = 0;
|
||||
for (int j = i + 1; j < model->nodes.size(); ++j) {
|
||||
for (int k = 0; k < model->nodes[j].src_tensors_ids.size(); ++k) {
|
||||
if (model->nodes[j].src_tensors_ids[k] == node.dst_tensors_ids[0]) {
|
||||
next_nodes.push_back(j);
|
||||
link_index = k;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (next_nodes.size() != 1 || link_index != 0) {
|
||||
continue;
|
||||
}
|
||||
auto& linkable_node = model->nodes[next_nodes[0]];
|
||||
if (!linkable_node.task->is_linkable ||
|
||||
linkable_node.dst_tensors_ids.size() != 1 ||
|
||||
!IsReady(ready_tensors, linkable_node)) {
|
||||
continue;
|
||||
}
|
||||
const auto& original_dst_def = node.task->definition.dst_tensors[0];
|
||||
const auto& link_dst_def = linkable_node.task->definition.dst_tensors[0];
|
||||
if (original_dst_def != link_dst_def) {
|
||||
continue;
|
||||
}
|
||||
RETURN_IF_ERROR(MergeNodes(&linkable_node, &node));
|
||||
model->nodes.erase(model->nodes.begin() + next_nodes[0]);
|
||||
i -= 1;
|
||||
}
|
||||
info.missing_output_buffer_ids =
|
||||
GetMissingOutputBufferIds(output_buffers, sorted_chains);
|
||||
info.gpu_tasks_count = static_cast<int>(sorted_chains.size());
|
||||
if (sorted_chains.empty()) {
|
||||
const std::string message =
|
||||
"No valid operations in the graph.\nInput operations count " +
|
||||
std::to_string(info.operations_count) + "\nUnused operations " +
|
||||
std::to_string(info.unused_operations.size()) + "\nUnused inputs " +
|
||||
std::to_string(info.unused_input_buffer_ids.size()) +
|
||||
"\nMissing output buffers " +
|
||||
std::to_string(info.missing_output_buffer_ids.size());
|
||||
return absl::InternalError(message);
|
||||
}
|
||||
for (const auto& chain : sorted_chains) {
|
||||
NodeDescriptor fused_node;
|
||||
RETURN_IF_ERROR(FuseChain(chain, &fused_node));
|
||||
output_model->nodes.push_back(std::move(fused_node));
|
||||
}
|
||||
output_model->tensor_shapes = input_model.tensor_shapes;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -106,6 +106,7 @@ class InferenceContext {
|
||||
const std::vector<ValueId>& output_ids,
|
||||
CalculationsPrecision precision);
|
||||
|
||||
absl::Status Merge(CompiledModel* model);
|
||||
absl::Status AllocateTensors(id<MTLDevice> device);
|
||||
absl::Status AllocateMemoryForBuffers(id<MTLDevice> device);
|
||||
void BindTensorsToOperations();
|
||||
|
Loading…
Reference in New Issue
Block a user